Skip to content

Commit b1ebbd2

Browse files
mcabbottmaleadt
authored andcommitted
Allow sorting tuples by avoiding calls to one/zero.
1 parent 94ba745 commit b1ebbd2

File tree

2 files changed

+15
-14
lines changed

2 files changed

+15
-14
lines changed

src/sorting.jl

Lines changed: 13 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -73,8 +73,10 @@ Uses block y index to decide which values to operate on.
7373
sync_threads()
7474
blockIdx_yz = (blockIdx().z - 1i32) * gridDim().y + blockIdx().y
7575
idx0 = lo + (blockIdx_yz - 1i32) * blockDim().x + threadIdx().x
76-
val = idx0 <= hi ? values[idx0] : one(eltype(values))
77-
comparison = flex_lt(pivot, val, parity, lt, by)
76+
if idx0 <= hi
77+
val = values[idx0]
78+
comparison = flex_lt(pivot, val, parity, lt, by)
79+
end
7880

7981
@inbounds if idx0 <= hi
8082
sums[threadIdx().x] = 1 & comparison
@@ -85,9 +87,11 @@ Uses block y index to decide which values to operate on.
8587

8688
cumsum!(sums)
8789

88-
dest_idx = @inbounds comparison ? blockDim().x - sums[end] + sums[threadIdx().x] : threadIdx().x - sums[threadIdx().x]
89-
@inbounds if idx0 <= hi && dest_idx <= length(swap)
90-
swap[dest_idx] = val
90+
@inbounds if idx0 <= hi
91+
dest_idx = @inbounds comparison ? blockDim().x - sums[end] + sums[threadIdx().x] : threadIdx().x - sums[threadIdx().x]
92+
if dest_idx <= length(swap)
93+
swap[dest_idx] = val
94+
end
9195
end
9296
sync_threads()
9397

@@ -180,10 +184,8 @@ Must only run on 1 SM.
180184
c = n_eff() - d
181185
to_move = min(b, c)
182186
sync_threads()
183-
swap = if threadIdx().x <= to_move
184-
vals[lo + a + threadIdx().x]
185-
else
186-
zero(eltype(vals)) # unused value
187+
if threadIdx().x <= to_move
188+
swap = vals[lo + a + threadIdx().x]
187189
end
188190
sync_threads()
189191
if threadIdx().x <= to_move
@@ -215,7 +217,6 @@ function bitonic_median(vals :: AbstractArray{T}, swap, lo, L, stride, lt::F1, b
215217

216218
@inbounds swap[threadIdx().x] = vals[lo + threadIdx().x * stride]
217219
sync_threads()
218-
old_val = zero(eltype(swap))
219220

220221
log_blockDim = begin
221222
out = 0
@@ -269,10 +270,8 @@ elements spaced by `stride`. Good for sampling pivot values as well as short sor
269270
for level in 0:L
270271
# get left/right neighbor depending on even/odd level
271272
buddy = threadIdx().x - 1i32 + 2i32 * (1i32 & (threadIdx().x % 2i32 != level % 2i32))
272-
buddy_val = if 1 <= buddy <= L && threadIdx().x <= L
273-
swap[buddy]
274-
else
275-
zero(eltype(swap)) # unused value
273+
if 1 <= buddy <= L && threadIdx().x <= L
274+
buddy_val = swap[buddy]
276275
end
277276
sync_threads()
278277
if 1 <= buddy <= L && threadIdx().x <= L

test/sorting.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -302,6 +302,7 @@ end
302302
@test check_sort!(Float64, 10000, x -> rand(Float64); alg=CUDA.QuickSort)
303303
@test check_sort!(Float32, 10000, x -> rand(Float32); alg=CUDA.QuickSort)
304304
@test check_sort!(Float16, 10000, x -> rand(Float16); alg=CUDA.QuickSort)
305+
@test check_sort!(Tuple{Int,Int}, 10000, x -> (rand(Int), rand(Int)); alg=CUDA.QuickSort)
305306

306307
# non-uniform distributions
307308
@test check_sort!(UInt8, 100000, x -> round(255 * rand() ^ 2); alg=CUDA.QuickSort)
@@ -345,6 +346,7 @@ end
345346
@test check_sort!(Float64, 10000, x -> rand(Float64); alg=CUDA.BitonicSort)
346347
@test check_sort!(Float32, 10000, x -> rand(Float32); alg=CUDA.BitonicSort)
347348
@test check_sort!(Float16, 10000, x -> rand(Float16); alg=CUDA.BitonicSort)
349+
@test check_sort!(Tuple{Int,Int}, 10000, x -> (rand(Int), rand(Int)); alg=CUDA.BitonicSort)
348350

349351
# test various sizes
350352
@test check_sort!(Float32, 1, x -> rand(Float32); alg=CUDA.BitonicSort)

0 commit comments

Comments
 (0)