Skip to content

Commit 07b50a4

Browse files
committed
allow tuples
1 parent ca9034d commit 07b50a4

File tree

1 file changed

+11
-4
lines changed

1 file changed

+11
-4
lines changed

src/sorting.jl

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,13 @@ using ..CUDA: i32
3737
(eq && a′ == b′) || lt(a′, b′)
3838
end
3939

40+
# To allow sorting tuples of numbers:
41+
# Fallback method: forwards directly to `Base.zero`.
@inline function _zero(x)
    return Base.zero(x)
end
42+
# Zero value for a fixed-length tuple *type*: a tuple holding `zero` of each element type.
# `fieldtype` is the public accessor for the element types (instead of the internal
# `T.parameters` field), and `Val(N)` lets `ntuple` unroll at compile time for every
# length, so the result type stays concrete even for tuples longer than 10 elements.
@inline function _zero(::Type{T}) where {N,T<:Tuple{Vararg{Any,N}}}
    return ntuple(i -> zero(fieldtype(T, i)), Val(N))
end
43+
44+
# Fallback method: forwards directly to `Base.one`.
@inline function _one(x)
    return Base.one(x)
end
45+
# One value for a fixed-length tuple *type*: a tuple holding `one` of each element type.
# `fieldtype` is the public accessor for the element types (instead of the internal
# `T.parameters` field), and `Val(N)` lets `ntuple` unroll at compile time for every
# length, so the result type stays concrete even for tuples longer than 10 elements.
@inline function _one(::Type{T}) where {N,T<:Tuple{Vararg{Any,N}}}
    return ntuple(i -> one(fieldtype(T, i)), Val(N))
end
46+
4047

4148
# Batch partitioning
4249
"""
@@ -73,7 +80,7 @@ Uses block y index to decide which values to operate on.
7380
sync_threads()
7481
blockIdx_yz = (blockIdx().z - 1i32) * gridDim().y + blockIdx().y
7582
idx0 = lo + (blockIdx_yz - 1i32) * blockDim().x + threadIdx().x
76-
val = idx0 <= hi ? values[idx0] : one(eltype(values))
83+
val = idx0 <= hi ? values[idx0] : _one(eltype(values))
7784
comparison = flex_lt(pivot, val, parity, lt, by)
7885

7986
@inbounds if idx0 <= hi
@@ -183,7 +190,7 @@ Must only run on 1 SM.
183190
swap = if threadIdx().x <= to_move
184191
vals[lo + a + threadIdx().x]
185192
else
186-
zero(eltype(vals)) # unused value
193+
_zero(eltype(vals)) # unused value
187194
end
188195
sync_threads()
189196
if threadIdx().x <= to_move
@@ -215,7 +222,7 @@ function bitonic_median(vals :: AbstractArray{T}, swap, lo, L, stride, lt::F1, b
215222

216223
@inbounds swap[threadIdx().x] = vals[lo + threadIdx().x * stride]
217224
sync_threads()
218-
old_val = zero(eltype(swap))
225+
old_val = _zero(eltype(swap))
219226

220227
log_blockDim = begin
221228
out = 0
@@ -272,7 +279,7 @@ elements spaced by `stride`. Good for sampling pivot values as well as short sor
272279
buddy_val = if 1 <= buddy <= L && threadIdx().x <= L
273280
swap[buddy]
274281
else
275-
zero(eltype(swap)) # unused value
282+
_zero(eltype(swap)) # unused value
276283
end
277284
sync_threads()
278285
if 1 <= buddy <= L && threadIdx().x <= L

0 commit comments

Comments
 (0)