@@ -37,6 +37,13 @@ using ..CUDA: i32
37
37
(eq && a′ == b′) || lt (a′, b′)
38
38
end
39
39
40
+ # To allow sorting tuples of numbers:
41
"""
    _zero(x)
    _zero(::Type{T}) where {T<:Tuple}

Placeholder "zero" value for element types being sorted.

Anything that is not a fixed-length tuple type is forwarded unchanged to
`Base.zero`. For a fixed-length tuple type `T`, the result is a tuple whose
`i`-th entry is `_zero` of the `i`-th field type of `T`. `Base.zero` has no
method for tuple types, so this is what allows tuples of numbers to be used;
because the per-field call goes back through `_zero`, nested tuple types such as
`Tuple{Int,Tuple{Float32,UInt8}}` are handled as well.
"""
@inline _zero(x) = Base.zero(x)
@inline function _zero(::Type{T}) where {N, T<:Tuple{Vararg{Any,N}}}
    # `fieldtype` is the public accessor for the i-th element type (instead of
    # reaching into `T.parameters`), and `Val(N)` passes the tuple length in the
    # type domain rather than as a runtime integer.
    return ntuple(i -> _zero(fieldtype(T, i)), Val(N))
end
43
+
44
"""
    _one(x)
    _one(::Type{T}) where {T<:Tuple}

Placeholder "one" value for element types being sorted.

Anything that is not a fixed-length tuple type is forwarded unchanged to
`Base.one`. For a fixed-length tuple type `T`, the result is a tuple whose
`i`-th entry is `_one` of the `i`-th field type of `T`. `Base.one` has no method
for tuple types, so this is what allows tuples of numbers to be used; because
the per-field call goes back through `_one`, nested tuple types such as
`Tuple{Int,Tuple{Float32,UInt8}}` are handled as well.
"""
@inline _one(x) = Base.one(x)
@inline function _one(::Type{T}) where {N, T<:Tuple{Vararg{Any,N}}}
    # `fieldtype` is the public accessor for the i-th element type (instead of
    # reaching into `T.parameters`), and `Val(N)` passes the tuple length in the
    # type domain rather than as a runtime integer.
    return ntuple(i -> _one(fieldtype(T, i)), Val(N))
end
46
+
40
47
41
48
# Batch partitioning
42
49
"""
@@ -73,7 +80,7 @@ Uses block y index to decide which values to operate on.
73
80
sync_threads ()
74
81
blockIdx_yz = (blockIdx (). z - 1 i32) * gridDim (). y + blockIdx (). y
75
82
idx0 = lo + (blockIdx_yz - 1 i32) * blockDim (). x + threadIdx (). x
76
- val = idx0 <= hi ? values[idx0] : one (eltype (values))
83
+ val = idx0 <= hi ? values[idx0] : _one (eltype (values))
77
84
comparison = flex_lt (pivot, val, parity, lt, by)
78
85
79
86
@inbounds if idx0 <= hi
@@ -183,7 +190,7 @@ Must only run on 1 SM.
183
190
swap = if threadIdx (). x <= to_move
184
191
vals[lo + a + threadIdx (). x]
185
192
else
186
- zero (eltype (vals)) # unused value
193
+ _zero (eltype (vals)) # unused value
187
194
end
188
195
sync_threads ()
189
196
if threadIdx (). x <= to_move
@@ -215,7 +222,7 @@ function bitonic_median(vals :: AbstractArray{T}, swap, lo, L, stride, lt::F1, b
215
222
216
223
@inbounds swap[threadIdx (). x] = vals[lo + threadIdx (). x * stride]
217
224
sync_threads ()
218
- old_val = zero (eltype (swap))
225
+ old_val = _zero (eltype (swap))
219
226
220
227
log_blockDim = begin
221
228
out = 0
@@ -272,7 +279,7 @@ elements spaced by `stride`. Good for sampling pivot values as well as short sor
272
279
buddy_val = if 1 <= buddy <= L && threadIdx (). x <= L
273
280
swap[buddy]
274
281
else
275
- zero (eltype (swap)) # unused value
282
+ _zero (eltype (swap)) # unused value
276
283
end
277
284
sync_threads ()
278
285
if 1 <= buddy <= L && threadIdx (). x <= L
0 commit comments