Skip to content

Commit 30db77e

Browse files
committed
stable reverse
1 parent 8f8fb71 commit 30db77e

File tree

1 file changed

+44
-25
lines changed

1 file changed

+44
-25
lines changed

src/bitonic_sort/gpu.jl

Lines changed: 44 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -64,27 +64,47 @@ end
6464
return lo, n
6565
end
6666

67+
"""
    rev_lt(a, b, lt, rev::Val{R})

Compare two values of the same type with `lt`. When the compile-time flag `R`
is `true` the operands are swapped, i.e. the result is `lt(b, a)`; otherwise it
is `lt(a, b)`.
"""
@inline function rev_lt(a::T, b::T, lt, rev::Val{R}) where {T,R}
    # `R` is a type parameter, so this branch is resolved at compile time
    return R ? lt(b, a) : lt(a, b)
end
74+
75+
"""
    rev_lt(a::Tuple{T,J}, b::Tuple{T,J}, lt, rev::Val{R})

Compare two `(key, index)` pairs so that ties between keys are always broken by
ascending index, regardless of the sort direction.

The keys are compared with `lt` (operands swapped when `R` is `true`). Two keys
count as tied when neither is `lt` the other, in which case the result is
`a[2] < b[2]`.

`lt` is only ever applied to the keys, never to the tuples, so a user-supplied
`lt` needs to be defined on key values only. Ties are detected through `lt`
itself rather than `==`, so keys that `lt` treats as equivalent but that are not
`==` (e.g. `NaN` under `isless`) are still ordered by index.
"""
@inline function rev_lt(a::Tuple{T,J}, b::Tuple{T,J}, lt, rev::Val{R}) where {T,J,R}
    # orient the keys so that `lt(x, y)` means "a sorts before b"
    x, y = R ? (b[1], a[1]) : (a[1], b[1])
    lt(x, y) && return true
    lt(y, x) && return false
    # keys are equivalent under `lt`: fall back to the index
    return a[2] < b[2]
end
86+
6787
# Functions specifically for "large" bitonic steps (those that cannot use shmem)
6888

6989

70-
@inline function compare!(vals::AbstractArray{T}, i1::I, i2::I, dir::Bool, by, lt) where {T,I<:Integer}
90+
"""
    compare!(vals::AbstractArray, i1, i2, dir, by, lt, rev)

Compare-and-swap step on a plain value array. Both positions are shifted by one
before indexing. The two elements are exchanged when `dir` differs from
`rev_lt(by(first), by(second), lt, rev)`.
"""
@inline function compare!(
    vals::AbstractArray{T}, i1::I, i2::I, dir::Bool, by, lt, rev
) where {T,I}
    p = i1 + one(I)
    q = i2 + one(I)
    @inbounds begin
        ordered = rev_lt(by(vals[p]), by(vals[q]), lt, rev)
        if dir != ordered
            vals[p], vals[q] = vals[q], vals[p]
        end
    end
end
7696

77-
@inline function compare!(vals_inds::Tuple, i1::I, i2::I, dir::Bool, by, lt) where I
97+
"""
    compare!(vals_inds::Tuple, i1, i2, dir, by, lt, rev)

Compare-and-swap step for the `(vals, inds)` form. Values are looked up through
the index array and only the entries of `inds` are exchanged; `vals` is left
untouched. Both positions are shifted by one before indexing.
"""
@inline function compare!(vals_inds::Tuple, i1::I, i2::I, dir::Bool, by, lt, rev) where I
    vals, inds = vals_inds
    p, q = i1 + one(I), i2 + one(I)
    @inbounds begin
        ip, iq = inds[p], inds[q]
        # pairing each key with its index lets `rev_lt` break ties by index,
        # which is what keeps the sort stable
        if dir != rev_lt((by(vals[ip]), ip), (by(vals[iq]), iq), lt, rev)
            inds[p], inds[q] = iq, ip
        end
    end
end
85105

86106

87-
@inline function get_range_part1(n::I, index::I, k::I)::Tuple{I,I,Bool} where {I<:Integer}
107+
@inline function get_range_part1(n::I, index::I, k::I)::Tuple{I,I,Bool} where I
88108
lo = zero(I)
89109
dir = true
90110
for iter = one(I):k-one(I)
@@ -130,7 +150,7 @@ Note that to avoid synchronization issues, only one thread from each pair of
130150
indices being swapped will actually move data. This does mean half of the threads
131151
do nothing, but it works for non-power2 arrays while allowing direct indexing.
132152
"""
133-
function comparator_kernel(vals, length_vals::I, k::I, j::I, by::F1, lt::F2) where {I,F1,F2}
153+
function comparator_kernel(vals, length_vals::I, k::I, j::I, by::F1, lt::F2, rev) where {I,F1,F2}
134154
index = (blockDim().x * (blockIdx().x - one(I))) + threadIdx().x - one(I)
135155

136156
lo, n, dir = get_range(length_vals, index, k, j)
@@ -139,7 +159,7 @@ function comparator_kernel(vals, length_vals::I, k::I, j::I, by::F1, lt::F2) whe
139159
m = gp2lt(n)
140160
if lo <= index < lo + n - m
141161
i1, i2 = index, index + m
142-
@inbounds compare!(vals, i1, i2, dir, by, lt)
162+
@inbounds compare!(vals, i1, i2, dir, by, lt, rev)
143163
end
144164
end
145165
return
@@ -148,18 +168,18 @@ end
148168

149169
# Functions for "small" bitonic steps (those that can use shmem)
150170

151-
@inline function compare_small!(vals::AbstractArray{T}, i1::I, i2::I, dir::Bool, by, lt) where {T,I<:Integer}
171+
"""
    compare_small!(vals::AbstractArray, i1, i2, dir, by, lt, rev)

Compare-and-swap step used by the "small" bitonic steps. Both positions are
shifted by one before indexing, and the two elements of `vals` are exchanged
when `dir` differs from the `rev_lt` comparison of their `by` keys.
"""
@inline function compare_small!(
    vals::AbstractArray{T}, i1::I, i2::I, dir::Bool, by, lt, rev
) where {T,I}
    first_idx, second_idx = i1 + one(I), i2 + one(I)
    @inbounds begin
        key_first = by(vals[first_idx])
        key_second = by(vals[second_idx])
        if dir != rev_lt(key_first, key_second, lt, rev)
            vals[first_idx], vals[second_idx] = vals[second_idx], vals[first_idx]
        end
    end
end
157177

158-
@inline function compare_small!(vals_inds::Tuple, i1::I, i2::I, dir::Bool, by, lt) where I
178+
@inline function compare_small!(vals_inds::Tuple, i1::I, i2::I, dir::Bool, by, lt, rev) where I
159179
i1′, i2′ = i1 + one(I), i2 + one(I)
160180
vals, inds = vals_inds
161181
# comparing tuples of (value, index) guarantees stability of sort
162-
@inbounds if dir != lt((by(vals[i1′]), inds[i1′]), (by(vals[i2′]), inds[i2′]))
182+
@inbounds if dir != rev_lt((by(vals[i1′]), inds[i1′]), (by(vals[i2′]), inds[i2′]), lt, rev)
163183
vals[i1′], vals[i2′] = vals[i2′], vals[i1′]
164184
inds[i1′], inds[i2′] = inds[i2′], inds[i1′]
165185
end
@@ -172,7 +192,7 @@ all threads perform swaps accessible using shmem.
172192
173193
Various negative exit values just for debugging.
174194
"""
175-
@inline function block_range(n::I, block_index::I, k::I, j::I)::Tuple{I,I,Bool} where {I<:Integer}
195+
@inline function block_range(n::I, block_index::I, k::I, j::I)::Tuple{I,I,Bool} where I
176196
lo = zero(I)
177197
dir = true
178198
tmp = block_index * two(I)
@@ -236,7 +256,7 @@ array. Each view is indexed along block x dim: one view per pseudo-block
236256
vals_inds::Tuple{AbstractArray{T},AbstractArray{J}},
237257
index,
238258
in_range,
239-
) where {T,J<:Integer}
259+
) where {T,J}
240260
# NB: I tried creating both shmem arrays with `initialize_shmem!`
241261
# but the behavior changed - maybe it's necessary to alloc both before
242262
# writing to either?
@@ -284,7 +304,7 @@ This is captured by `pseudo_block_idx`.
284304
Note that this moves the array values copied within shmem, but doesn't copy them
285305
back to global the way it does for indices.
286306
"""
287-
function comparator_small_kernel(c, length_c::I, k::I, j_0::I, j_f::I, by::F1, lt::F2) where {I,F1,F2}
307+
function comparator_small_kernel(c, length_c::I, k::I, j_0::I, j_f::I, by::F1, lt::F2, rev) where {I,F1,F2}
288308
pseudo_block_idx = (blockIdx().x - one(I)) * blockDim().y + threadIdx().y - one(I)
289309
# immutable info about the range used by this kernel
290310
_lo, _n, dir = block_range(length_c, pseudo_block_idx, k, j_0)
@@ -301,7 +321,7 @@ function comparator_small_kernel(c, length_c::I, k::I, j_0::I, j_f::I, by::F1, l
301321
m = gp2lt(n)
302322
if lo <= index < lo + n - m
303323
i1, i2 = index - _lo, index - _lo + m
304-
compare_small!(swap, i1, i2, dir, by, lt)
324+
compare_small!(swap, i1, i2, dir, by, lt, rev)
305325
end
306326
end
307327
lo, n = bisect_range(index, lo, n)
@@ -322,7 +342,13 @@ function bitonic_shmem(c, threads)
322342
return prod(threads) * sum(map(a -> sizeof(eltype(a)), c))
323343
end
324344

325-
function bitonic_sort!(c; by = identity, lt = isless) where {T}
345+
"""
346+
Run bitonic sort on `c`, which can be either a CuArray of values to `sort!` or a tuple
347+
of values and an index array for doing `sortperm!`. Cannot provide a stable
348+
`sort!` although `sortperm!` is properly stable. To reverse, set `rev=true`
349+
rather than `lt=!isless` (otherwise stability of sortperm breaks down).
350+
"""
351+
function bitonic_sort!(c; by = identity, lt = isless, rev=false) where {T}
326352
c_len = if typeof(c) <: Tuple
327353
length(c[1])
328354
else
@@ -341,12 +367,12 @@ function bitonic_sort!(c; by = identity, lt = isless) where {T}
341367
for j = 1:j_final
342368

343369
# use Int32 args for indexing --> ~10% faster kernels
344-
args1 = (c, map(Int32, (c_len, k, j, j_final))..., by, lt)
370+
args1 = (c, map(Int32, (c_len, k, j, j_final))..., by, lt, Val(rev))
345371
kernel1 = @cuda launch = false comparator_small_kernel(args1...)
346372
config1 = launch_configuration(kernel1.fun, shmem = threads -> bitonic_shmem(c, threads))
347373
threads1 = prevpow(2, config1.threads)
348374

349-
args2 = (c, map(Int32, (c_len, k, j))..., by, lt)
375+
args2 = (c, map(Int32, (c_len, k, j))..., by, lt, Val(rev))
350376
kernel2 = @cuda launch = false comparator_kernel(args2...)
351377
config2 = launch_configuration(kernel2.fun, shmem = threads -> bitonic_shmem(c, threads))
352378
threads2 = prevpow(2, config2.threads)
@@ -372,10 +398,3 @@ function bitonic_sort!(c; by = identity, lt = isless) where {T}
372398
end
373399
end
374400
end
375-
376-
#a = rand(Float32, 1_000_000)
377-
#c = CuArray(a)
378-
#I = CuArray(collect(1:length(c)))
379-
#bitonic_sort!((c, I))
380-
#synchronize()
381-
#@assert c[I] |> Array == sort(a)

0 commit comments

Comments
 (0)