@@ -73,13 +73,10 @@ Uses block y index to decide which values to operate on.
73
73
sync_threads ()
74
74
blockIdx_yz = (blockIdx (). z - 1 i32) * gridDim (). y + blockIdx (). y
75
75
idx0 = lo + (blockIdx_yz - 1 i32) * blockDim (). x + threadIdx (). x
76
- val = if idx0 <= hi
77
- values[idx0]
78
- else
79
- Ref {eltype(values)} ()[] # undef
80
- # if idx0 > hi, val, comparison and dest_idx are unused
76
+ if idx0 <= hi
77
+ val = values[idx0]
78
+ comparison = flex_lt (pivot, val, parity, lt, by)
81
79
end
82
- comparison = flex_lt (pivot, val, parity, lt, by)
83
80
84
81
@inbounds if idx0 <= hi
85
82
sums[threadIdx (). x] = 1 & comparison
@@ -90,9 +87,11 @@ Uses block y index to decide which values to operate on.
90
87
91
88
cumsum! (sums)
92
89
93
- dest_idx = @inbounds comparison ? blockDim (). x - sums[end ] + sums[threadIdx (). x] : threadIdx (). x - sums[threadIdx (). x]
94
- @inbounds if idx0 <= hi && dest_idx <= length (swap)
95
- swap[dest_idx] = val
90
+ @inbounds if idx0 <= hi
91
+ dest_idx = @inbounds comparison ? blockDim (). x - sums[end ] + sums[threadIdx (). x] : threadIdx (). x - sums[threadIdx (). x]
92
+ if dest_idx <= length (swap)
93
+ swap[dest_idx] = val
94
+ end
96
95
end
97
96
sync_threads ()
98
97
@@ -185,10 +184,8 @@ Must only run on 1 SM.
185
184
c = n_eff () - d
186
185
to_move = min (b, c)
187
186
sync_threads ()
188
- swap = if threadIdx (). x <= to_move
189
- vals[lo + a + threadIdx (). x]
190
- else
191
- Ref {eltype(vals)} ()[] # undef
187
+ if threadIdx (). x <= to_move
188
+ swap = vals[lo + a + threadIdx (). x]
192
189
end
193
190
sync_threads ()
194
191
if threadIdx (). x <= to_move
@@ -242,10 +239,8 @@ function bitonic_median(vals :: AbstractArray{T}, swap, lo, L, stride, lt::F1, b
242
239
to_swap = (i & k) == 0 && bitonic_lt (l, i) || (i & k) != 0 && bitonic_lt (i, l)
243
240
to_swap = to_swap == (i < l)
244
241
245
- old_val = if to_swap
246
- @inbounds swap[l + 1 ]
247
- else
248
- Ref {eltype(swap)} ()[] # undef
242
+ if to_swap
243
+ old_val = @inbounds swap[l + 1 ]
249
244
end
250
245
sync_threads ()
251
246
if to_swap
@@ -275,10 +270,8 @@ elements spaced by `stride`. Good for sampling pivot values as well as short sor
275
270
for level in 0 : L
276
271
# get left/right neighbor depending on even/odd level
277
272
buddy = threadIdx (). x - 1 i32 + 2 i32 * (1 i32 & (threadIdx (). x % 2 i32 != level % 2 i32))
278
- buddy_val = if 1 <= buddy <= L && threadIdx (). x <= L
279
- swap[buddy]
280
- else
281
- Ref {eltype(swap)} ()[] # undef
273
+ if 1 <= buddy <= L && threadIdx (). x <= L
274
+ buddy_val = swap[buddy]
282
275
end
283
276
sync_threads ()
284
277
if 1 <= buddy <= L && threadIdx (). x <= L
0 commit comments