64
64
return lo, n
65
65
end
66
66
67
"""
    rev_lt(a::T, b::T, lt, rev::Val{R})

Apply the ordering predicate `lt` to `a` and `b`, swapping the two operands
when the compile-time flag `R` is `true`.
"""
@inline function rev_lt(a::T, b::T, lt, rev::Val{R}) where {T,R}
    # `R` is a type parameter, so only one arm survives specialization.
    return R ? lt(b, a) : lt(a, b)
end
74
+
75
"""
    rev_lt(a::Tuple{T,J}, b::Tuple{T,J}, lt, rev::Val{R})

Compare two `(key, index)` pairs.

When `R` is `false` the pairs are handed to `lt` unchanged. When `R` is `true`
only the key comparison is reversed: `a` sorts before `b` if `lt(b[1], a[1])`,
and keys that `lt` orders in neither direction are tie-broken by ascending
second element (`a[2] < b[2]`).

The tie is derived from `lt` rather than from `==`, so the result stays
consistent with the ordering actually being used. For example `isless`
distinguishes `-0.0` from `0.0` and treats two `NaN`s as equivalent, while `==`
does the opposite in both cases.
"""
@inline function rev_lt(a::Tuple{T,J}, b::Tuple{T,J}, lt, rev::Val{R}) where {T,J,R}
    if R
        key_a, key_b = a[1], b[1]
        # Strictly greater key wins; otherwise, if neither key precedes the
        # other under `lt`, fall back to the second element in ascending order.
        return lt(key_b, key_a) || (!lt(key_a, key_b) && a[2] < b[2])
    else
        return lt(a, b)
    end
end
86
+
67
87
# Functions specifically for "large" bitonic steps (those that cannot use shmem)
68
88
69
89
70
# Compare-and-swap for a plain array of values: exchange the two elements
# whenever the `rev_lt` result for their `by`-keys differs from `dir`.
@inline function compare!(
    vals::AbstractArray{T}, i1::I, i2::I, dir::Bool, by, lt, rev
) where {T,I}
    # both incoming indices are shifted by one before they are used
    first_idx = i1 + one(I)
    second_idx = i2 + one(I)
    @inbounds begin
        ordered = rev_lt(by(vals[first_idx]), by(vals[second_idx]), lt, rev)
        if dir != ordered
            vals[first_idx], vals[second_idx] = vals[second_idx], vals[first_idx]
        end
    end
end
76
96
77
# Compare-and-swap for a `(vals, inds)` pair: the values are looked up through
# `inds` and only the entries of `inds` are exchanged.
@inline function compare!(vals_inds::Tuple, i1::I, i2::I, dir::Bool, by, lt, rev) where I
    vals, inds = vals_inds
    # both incoming indices are shifted by one before they are used
    p = i1 + one(I)
    q = i2 + one(I)
    @inbounds begin
        idx_p = inds[p]
        idx_q = inds[q]
        # comparing tuples of (value, index) guarantees stability of sort
        key_p = (by(vals[idx_p]), idx_p)
        key_q = (by(vals[idx_q]), idx_q)
        if dir != rev_lt(key_p, key_q, lt, rev)
            inds[p], inds[q] = idx_q, idx_p
        end
    end
end
85
105
86
106
87
- @inline function get_range_part1 (n:: I , index:: I , k:: I ):: Tuple{I,I,Bool} where {I <: Integer }
107
+ @inline function get_range_part1 (n:: I , index:: I , k:: I ):: Tuple{I,I,Bool} where I
88
108
lo = zero (I)
89
109
dir = true
90
110
for iter = one (I): k- one (I)
@@ -130,7 +150,7 @@ Note that to avoid synchronization issues, only one thread from each pair of
130
150
indices being swapped will actually move data. This does mean half of the threads
131
151
do nothing, but it works for non-power2 arrays while allowing direct indexing.
132
152
"""
133
- function comparator_kernel (vals, length_vals:: I , k:: I , j:: I , by:: F1 , lt:: F2 ) where {I,F1,F2}
153
+ function comparator_kernel (vals, length_vals:: I , k:: I , j:: I , by:: F1 , lt:: F2 , rev ) where {I,F1,F2}
134
154
index = (blockDim (). x * (blockIdx (). x - one (I))) + threadIdx (). x - one (I)
135
155
136
156
lo, n, dir = get_range (length_vals, index, k, j)
@@ -139,7 +159,7 @@ function comparator_kernel(vals, length_vals::I, k::I, j::I, by::F1, lt::F2) whe
139
159
m = gp2lt (n)
140
160
if lo <= index < lo + n - m
141
161
i1, i2 = index, index + m
142
- @inbounds compare! (vals, i1, i2, dir, by, lt)
162
+ @inbounds compare! (vals, i1, i2, dir, by, lt, rev )
143
163
end
144
164
end
145
165
return
@@ -148,18 +168,18 @@ end
148
168
149
169
# Functions for "small" bitonic steps (those that can use shmem)
150
170
151
# Compare-and-swap used by the "small" bitonic steps on a plain array:
# exchange the two elements whenever the `rev_lt` result for their `by`-keys
# differs from `dir`.
@inline function compare_small!(
    vals::AbstractArray{T}, i1::I, i2::I, dir::Bool, by, lt, rev
) where {T,I}
    # both incoming indices are shifted by one before they are used
    shift = one(I)
    left, right = i1 + shift, i2 + shift
    @inbounds begin
        ordered = rev_lt(by(vals[left]), by(vals[right]), lt, rev)
        if dir != ordered
            vals[left], vals[right] = vals[right], vals[left]
        end
    end
end
157
177
158
- @inline function compare_small! (vals_inds:: Tuple , i1:: I , i2:: I , dir:: Bool , by, lt) where I
178
+ @inline function compare_small! (vals_inds:: Tuple , i1:: I , i2:: I , dir:: Bool , by, lt, rev ) where I
159
179
i1′, i2′ = i1 + one (I), i2 + one (I)
160
180
vals, inds = vals_inds
161
181
# comparing tuples of (value, index) guarantees stability of sort
162
- @inbounds if dir != lt ((by (vals[i1′]), inds[i1′]), (by (vals[i2′]), inds[i2′]))
182
+ @inbounds if dir != rev_lt ((by (vals[i1′]), inds[i1′]), (by (vals[i2′]), inds[i2′]), lt, rev )
163
183
vals[i1′], vals[i2′] = vals[i2′], vals[i1′]
164
184
inds[i1′], inds[i2′] = inds[i2′], inds[i1′]
165
185
end
@@ -172,7 +192,7 @@ all threads perform swaps accessible using shmem.
172
192
173
193
Various negative exit values just for debugging.
174
194
"""
175
- @inline function block_range (n:: I , block_index:: I , k:: I , j:: I ):: Tuple{I,I,Bool} where {I <: Integer }
195
+ @inline function block_range (n:: I , block_index:: I , k:: I , j:: I ):: Tuple{I,I,Bool} where I
176
196
lo = zero (I)
177
197
dir = true
178
198
tmp = block_index * two (I)
@@ -236,7 +256,7 @@ array. Each view is indexed along block x dim: one view per pseudo-block
236
256
vals_inds:: Tuple{AbstractArray{T},AbstractArray{J}} ,
237
257
index,
238
258
in_range,
239
- ) where {T,J<: Integer }
259
+ ) where {T,J}
240
260
# NB: I tried creating both shmem arrays with `initialize_shmem!`
241
261
# but the behavior changed - maybe it's necessary to alloc both before
242
262
# writing to either?
@@ -284,7 +304,7 @@ This is captured by `pseudo_block_idx`.
284
304
Note that this moves the array values copied within shmem, but doesn't copy them
285
305
back to global the way it does for indices.
286
306
"""
287
- function comparator_small_kernel (c, length_c:: I , k:: I , j_0:: I , j_f:: I , by:: F1 , lt:: F2 ) where {I,F1,F2}
307
+ function comparator_small_kernel (c, length_c:: I , k:: I , j_0:: I , j_f:: I , by:: F1 , lt:: F2 , rev ) where {I,F1,F2}
288
308
pseudo_block_idx = (blockIdx (). x - one (I)) * blockDim (). y + threadIdx (). y - one (I)
289
309
# immutable info about the range used by this kernel
290
310
_lo, _n, dir = block_range (length_c, pseudo_block_idx, k, j_0)
@@ -301,7 +321,7 @@ function comparator_small_kernel(c, length_c::I, k::I, j_0::I, j_f::I, by::F1, l
301
321
m = gp2lt (n)
302
322
if lo <= index < lo + n - m
303
323
i1, i2 = index - _lo, index - _lo + m
304
- compare_small! (swap, i1, i2, dir, by, lt)
324
+ compare_small! (swap, i1, i2, dir, by, lt, rev )
305
325
end
306
326
end
307
327
lo, n = bisect_range (index, lo, n)
@@ -322,7 +342,13 @@ function bitonic_shmem(c, threads)
322
342
return prod (threads) * sum (map (a -> sizeof (eltype (a)), c))
323
343
end
324
344
325
- function bitonic_sort! (c; by = identity, lt = isless) where {T}
345
+ """
346
+ Call bitonic sort on `c` which can be a CuArray of values to `sort!` or a tuple
347
+ of values and an index array for doing `sortperm!`. Cannot provide a stable
348
+ `sort!` although `sortperm!` is properly stable. To reverse, set `rev=true`
349
+ rather than `lt=!isless` (otherwise stability of sortperm breaks down).
350
+ """
351
+ function bitonic_sort! (c; by = identity, lt = isless, rev= false ) where {T}
326
352
c_len = if typeof (c) <: Tuple
327
353
length (c[1 ])
328
354
else
@@ -341,12 +367,12 @@ function bitonic_sort!(c; by = identity, lt = isless) where {T}
341
367
for j = 1 : j_final
342
368
343
369
# use Int32 args for indexing --> ~10% faster kernels
344
- args1 = (c, map (Int32, (c_len, k, j, j_final))... , by, lt)
370
+ args1 = (c, map (Int32, (c_len, k, j, j_final))... , by, lt, Val (rev) )
345
371
kernel1 = @cuda launch = false comparator_small_kernel (args1... )
346
372
config1 = launch_configuration (kernel1. fun, shmem = threads -> bitonic_shmem (c, threads))
347
373
threads1 = prevpow (2 , config1. threads)
348
374
349
- args2 = (c, map (Int32, (c_len, k, j))... , by, lt)
375
+ args2 = (c, map (Int32, (c_len, k, j))... , by, lt, Val (rev) )
350
376
kernel2 = @cuda launch = false comparator_kernel (args2... )
351
377
config2 = launch_configuration (kernel2. fun, shmem = threads -> bitonic_shmem (c, threads))
352
378
threads2 = prevpow (2 , config2. threads)
@@ -372,10 +398,3 @@ function bitonic_sort!(c; by = identity, lt = isless) where {T}
372
398
end
373
399
end
374
400
end
375
-
376
- # a = rand(Float32, 1_000_000)
377
- # c = CuArray(a)
378
- # I = CuArray(collect(1:length(c)))
379
- # bitonic_sort!((c, I))
380
- # synchronize()
381
- # @assert c[I] |> Array == sort(a)
0 commit comments