Skip to content

Commit 55ac2fc

Browse files
committed
Do the naive thing.
1 parent 13b655c commit 55ac2fc

File tree

1 file changed

+14
-21
lines changed

1 file changed

+14
-21
lines changed

src/sorting.jl

Lines changed: 14 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -73,13 +73,10 @@ Uses block y index to decide which values to operate on.
7373
sync_threads()
7474
blockIdx_yz = (blockIdx().z - 1i32) * gridDim().y + blockIdx().y
7575
idx0 = lo + (blockIdx_yz - 1i32) * blockDim().x + threadIdx().x
76-
val = if idx0 <= hi
77-
values[idx0]
78-
else
79-
Ref{eltype(values)}()[] # undef
80-
# if idx0 > hi, val, comparison and dest_idx are unused
76+
if idx0 <= hi
77+
val = values[idx0]
78+
comparison = flex_lt(pivot, val, parity, lt, by)
8179
end
82-
comparison = flex_lt(pivot, val, parity, lt, by)
8380

8481
@inbounds if idx0 <= hi
8582
sums[threadIdx().x] = 1 & comparison
@@ -90,9 +87,11 @@ Uses block y index to decide which values to operate on.
9087

9188
cumsum!(sums)
9289

93-
dest_idx = @inbounds comparison ? blockDim().x - sums[end] + sums[threadIdx().x] : threadIdx().x - sums[threadIdx().x]
94-
@inbounds if idx0 <= hi && dest_idx <= length(swap)
95-
swap[dest_idx] = val
90+
@inbounds if idx0 <= hi
91+
dest_idx = @inbounds comparison ? blockDim().x - sums[end] + sums[threadIdx().x] : threadIdx().x - sums[threadIdx().x]
92+
if dest_idx <= length(swap)
93+
swap[dest_idx] = val
94+
end
9695
end
9796
sync_threads()
9897

@@ -185,10 +184,8 @@ Must only run on 1 SM.
185184
c = n_eff() - d
186185
to_move = min(b, c)
187186
sync_threads()
188-
swap = if threadIdx().x <= to_move
189-
vals[lo + a + threadIdx().x]
190-
else
191-
Ref{eltype(vals)}()[] # undef
187+
if threadIdx().x <= to_move
188+
swap = vals[lo + a + threadIdx().x]
192189
end
193190
sync_threads()
194191
if threadIdx().x <= to_move
@@ -242,10 +239,8 @@ function bitonic_median(vals :: AbstractArray{T}, swap, lo, L, stride, lt::F1, b
242239
to_swap = (i & k) == 0 && bitonic_lt(l, i) || (i & k) != 0 && bitonic_lt(i, l)
243240
to_swap = to_swap == (i < l)
244241

245-
old_val = if to_swap
246-
@inbounds swap[l + 1]
247-
else
248-
Ref{eltype(swap)}()[] # undef
242+
if to_swap
243+
old_val = @inbounds swap[l + 1]
249244
end
250245
sync_threads()
251246
if to_swap
@@ -275,10 +270,8 @@ elements spaced by `stride`. Good for sampling pivot values as well as short sor
275270
for level in 0:L
276271
# get left/right neighbor depending on even/odd level
277272
buddy = threadIdx().x - 1i32 + 2i32 * (1i32 & (threadIdx().x % 2i32 != level % 2i32))
278-
buddy_val = if 1 <= buddy <= L && threadIdx().x <= L
279-
swap[buddy]
280-
else
281-
Ref{eltype(swap)}()[] # undef
273+
if 1 <= buddy <= L && threadIdx().x <= L
274+
buddy_val = swap[buddy]
282275
end
283276
sync_threads()
284277
if 1 <= buddy <= L && threadIdx().x <= L

0 commit comments

Comments
 (0)