|
1 | | -# Ported from CUDA.jl. |
2 | | -# Originally developed by @xaellison (Alex Ellison). |
3 | | - |
# In-place sort for ROC arrays: forwards the array and every keyword argument
# to `bitonic_sort!` and returns that call's result.
function Base.sort!(x::AnyROCArray; kwargs...)
    return bitonic_sort!(x; kwargs...)
end
5 | | - |
# Write into `ix` the permutation produced by `bitonic_sort!((x, ix))`.
#
# - `ix` and `x` must have equal `axes`, otherwise an `ArgumentError` is thrown.
# - `initialized`: when `false` (default), `ix` is first overwritten with
#   `LinearIndices(x)`; pass `true` if the caller has already filled it.
# - Remaining keyword arguments are forwarded to `bitonic_sort!`.
#
# Returns `ix`.
function Base.sortperm!(
    ix::AnyROCArray, x::AnyROCArray;
    initialized::Bool = false, kwargs...,
)
    if axes(ix) != axes(x)
        msg = "Index array must have the same size as the source array, instead: " *
            "$(size(ix)) vs $(size(x))."
        throw(ArgumentError(msg))
    end

    if !initialized
        ix .= LinearIndices(x)
    end
    bitonic_sort!((x, ix); kwargs...)
    return ix
end
18 | | - |
# Allocate a fresh index array holding `1:length(x)` and hand it to `sortperm!`
# flagged as already initialized; keyword arguments are forwarded unchanged.
function Base.sortperm(x::AnyROCArray; kwargs...)
    ix = ROCArray(1:length(x))
    return sortperm!(ix, x; initialized=true, kwargs...)
end
22 | | - |
23 | | -# TODO dims |
"""
    bitonic_sort!(X; lt = isless, by = identity, rev::Bool = false)

Sort `X` in place by launching the `cmp_small_kern!` / `cmp_ker!` kernels over
a `k`/`j` pass schedule, and return `X`.

`X` is either a single array or a tuple whose first element provides the
length (`sortperm!` passes `(x, ix)`).

# Keywords
- `lt`: "less than" predicate, forwarded to the kernels.
- `by`: key function, forwarded to the kernels.
- `rev`: forwarded to the kernels as `Val(rev)`.

Inputs with fewer than two elements are returned untouched without launching
anything.
"""
function bitonic_sort!(X; lt = isless, by = identity, rev::Bool = false)
    # Bytes of dynamic shared memory requested for a launch with `groupsize`:
    # one element per workitem for each array involved.
    _shmem(x::Tuple, groupsize) = prod(groupsize) * sum(sizeof.(eltype.(x)))
    _shmem(x::AbstractArray, groupsize) = prod(groupsize) * sizeof(eltype(x))

    len_x = X isa Tuple ? length(X[1]) : length(X)
    # Nothing to do for 0 or 1 elements. The guard is also required for the
    # empty case: `prevpow(2, 0)` and `ceil(Int, log2(0))` below would throw.
    len_x ≤ 1 && return X

    # Integer type used for all index arithmetic inside the kernels.
    I = len_x ≤ typemax(Int32) ? Int32 : Int64
    # Power-of-two workgroup size, capped at 256.
    threads = min(256, prevpow(2, len_x))

    # Compile kernels.
    ker_1 = @roc launch=false cmp_small_kern!(
        X, I(len_x), one(I), one(I), one(I), by, lt, Val(rev))
    ker_2 = @roc launch=false cmp_ker!(
        X, I(len_x), one(I), one(I), by, lt, Val(rev))

    # Cutoff for when to use `ker_1` vs `ker_2`.
    log_threads = Int(log2(threads))

    k₀ = ceil(Int, log2(len_x))
    for k in k₀:-1:1
        j_end = k₀ - k + 1
        for j in 1:j_end
            if k₀ - k - j + 2 ≤ log_threads
                # The remaining passes `j:j_end` fit in one workgroup, so a
                # single `ker_1` launch is given the whole range `j..j_end`.
                pseudo_block_len = 1 << abs(j_end + 1 - j)
                n_pseudo_blocks = nextpow(2, len_x) ÷ pseudo_block_len
                pseudo_blocks_per_block = threads ÷ pseudo_block_len

                gridsize = max(1, n_pseudo_blocks ÷ pseudo_blocks_per_block)
                groupsize = (pseudo_block_len, threads ÷ pseudo_block_len)
                ker_1(
                    X, I(len_x), I(k), I(j), I(j_end), by, lt, Val(rev);
                    gridsize, groupsize, shmem=_shmem(X, groupsize))
                # NOTE(review): the `j` loop keeps iterating after this launch
                # (behavior kept as-is) — confirm whether a `break` is intended.
            else
                # One workitem per element, one pass per launch.
                gridsize = cld(len_x, threads)
                ker_2(
                    X, I(len_x), I(k), I(j), by, lt, Val(rev);
                    gridsize, groupsize=threads)
            end
        end
    end
    return X
end
65 | | - |
# Kernel: one workitem per element. Each workitem computes its zero-based
# global index, asks `get_range` for the range `(lo, n)` and direction it
# belongs to at pass `(k, j)`, and, if it sits in the leading part of that
# range, compare-exchanges with the element `gp2lt(n)` positions ahead.
function cmp_ker!(x, lenₓ::I, k::I, j::I, by, lt, rev) where I
    idx::I = workgroupDim().x * (workgroupIdx().x - 0x1) + workitemIdx().x - 0x1
    lo, n, dir = get_range(lenₓ, idx, k, j)

    # Negative `lo`/`n` is the "no work" marker; also drop out-of-array items.
    (lo < 0x0 || n < 0x0 || idx ≥ lenₓ) && return

    stride = gp2lt(n)
    if lo ≤ idx && idx < lo + n - stride
        cmp!(x, idx, idx + stride, dir, by, lt, rev)
    end
    return
end
79 | | - |
# Kernel: each `y` row of the workgroup handles one range returned by
# `block_range`. The range is staged into the buffer returned by `init_shmem`,
# the compare-exchange passes `j₀:jₑ` run on that buffer with a workgroup
# barrier after every pass, and `finalize_shmem!` writes the result back.
function cmp_small_kern!(x, lenₓ::I, k::I, j₀::I, jₑ::I, by, lt, rev) where I
    tid = workitemIdx().x
    bidx::I = (workgroupIdx().x - 0x1) * workgroupDim().y + workitemIdx().y - 0x1
    base_lo, base_n, dir = block_range(lenₓ, bidx, k, j₀)

    # Zero-based global index handled by this workitem.
    idx = base_lo + I(tid) - 0x1
    in_range = tid ≤ base_n && base_lo ≥ 0x0
    shared = init_shmem(x, idx, in_range)

    lo, n = base_lo, base_n
    for _ in j₀:jₑ
        if in_range && lo ≥ 0x0 && n ≥ 0x0
            stride = gp2lt(n)
            if lo ≤ idx && idx < lo + n - stride
                # Indices into the staged buffer are relative to `base_lo`.
                rel = idx - base_lo
                cmp_small!(shared, rel, rel + stride, dir, by, lt, rev)
            end
        end
        lo, n = bisect_range(idx, lo, n)
        # Every workitem must reach the barrier, so no early exits above.
        sync_workgroup()
    end
    finalize_shmem!(x, shared, idx, in_range)
    return
end
103 | | - |
# Split the range `(lo, n)` at `gp2lt(n)` and return the half containing `idx`
# as `(lo, n)`. A range of at most one element cannot be split and yields the
# marker `(-1, -1)`.
function bisect_range(idx::I, lo::I, n::I) where I
    if n ≤ 0x1
        return -one(I), -one(I)
    end

    head = gp2lt(n)
    return idx < lo + head ? (lo, head) : (lo + head, n - head)
end
116 | | - |
# Compare-exchange on a plain array. `i1`/`i2` are zero-based; the elements are
# swapped when the `_lt_fn` result on their `by` keys differs from `dir`.
function cmp!(
    x::AbstractArray, i1::I, i2::I, dir::Bool, by, lt, rev,
) where I
    p, q = i1 + one(I), i2 + one(I)
    @inbounds begin
        xp, xq = x[p], x[q]
        if dir != _lt_fn(by(xp), by(xq), lt, rev)
            x[p], x[q] = xq, xp
        end
    end
end
125 | | - |
# Compare-exchange for the `(x, ix)` pair used by `sortperm!`. Only the index
# array `ix` is permuted; `x` is read through `ix`. `i1`/`i2` are zero-based.
function cmp!(
    X::Tuple, i1::I, i2::I, dir::Bool, by, lt, rev,
) where I
    p, q = i1 + one(I), i2 + one(I)
    x, ix = X
    # These loads stay bounds-checked; only the swap below is `@inbounds`.
    ip, iq = ix[p], ix[q]
    cmp_res = _lt_fn((by(x[ip]), ip), (by(x[iq]), iq), lt, rev)
    @inbounds if dir != cmp_res
        ix[p], ix[q] = iq, ip
    end
end
138 | | - |
# Compare-exchange on the staged buffer `swap`. `i1`/`i2` are zero-based
# positions inside `swap`; elements are swapped when the `_lt_fn` result on
# their `by` keys differs from `dir`.
function cmp_small!(
    swap::AbstractArray, i1::I, i2::I, dir::Bool, by, lt, rev,
) where I
    p, q = i1 + one(I), i2 + one(I)
    @inbounds begin
        sp, sq = swap[p], swap[q]
        if dir != _lt_fn(by(sp), by(sq), lt, rev)
            swap[p], swap[q] = sq, sp
        end
    end
end
147 | | - |
# Compare-exchange on a pair of staged buffers `(values, indices)`. Values and
# indices are compared together as `(by(value), index)` and swapped together.
# `i1`/`i2` are zero-based positions inside the buffers.
function cmp_small!(
    swap::Tuple, i1::I, i2::I, dir::Bool, by, lt, rev,
) where I
    p, q = i1 + one(I), i2 + one(I)
    vals, inds = swap
    # These loads stay bounds-checked; only the swaps below are `@inbounds`.
    vp, vq = vals[p], vals[q]
    ip, iq = inds[p], inds[q]
    cmp_res = _lt_fn((by(vp), ip), (by(vq), iq), lt, rev)
    @inbounds if dir != cmp_res
        vals[p], vals[q] = vq, vp
        inds[p], inds[q] = iq, ip
    end
end
161 | | - |
# Apply `lt` to two same-typed values; when the `Val` parameter `R` is `true`
# the arguments are passed in swapped order (reverse ordering).
@inline function _lt_fn(a::T, b::T, lt, rev::Val{R}) where {T, R}
    return R ? lt(b, a) : lt(a, b)
end
169 | | - |
# Comparison for `(key, index)` pairs, as built by the tuple methods of `cmp!`
# and `cmp_small!`.
#
# - `R == true`: equal keys (by `==`) are ordered by ascending index, otherwise
#   the keys are compared with `lt` in swapped order.
# - `R == false`: `lt` is applied to the keys only, and keys that are
#   equivalent under `lt` are ordered by ascending index. Calling `lt` on the
#   whole tuples only works for predicates that happen to accept tuples; for
#   the default `isless` the result is the same lexicographic order.
@inline function _lt_fn(a::Tuple{T, J}, b::Tuple{T, J}, lt, rev::Val{R}) where {T, J, R}
    if R
        if a[1] == b[1]
            return a[2] < b[2] # Compare indices.
        else
            return lt(b[1], a[1])
        end
    else
        return lt(a[1], b[1]) || (!lt(b[1], a[1]) && a[2] < b[2])
    end
end
181 | | - |
# Stage one element of `x` per workitem into a dynamic local array of size
# `(workgroupDim().x, workgroupDim().y)` placed at byte `offset`, synchronize
# the workgroup, and return a view of this workitem's `y` column.
# `idx` is zero-based; workitems with `in_range == false` load nothing.
# NOTE(review): the literal `false` passed to `@ROCDynamicLocalArray` is kept
# as-is; its meaning is not visible here — confirm against the macro docs.
function init_shmem(x::AbstractArray{T}, idx, in_range::Bool, offset=0) where T
    col, row = workitemIdx().x, workitemIdx().y
    swap = @ROCDynamicLocalArray(
        T, (workgroupDim().x, workgroupDim().y), false, offset)
    in_range && @inbounds(swap[col, row] = x[idx + 0x1])
    sync_workgroup()
    return @inbounds @view(swap[:, row])
end
191 | | - |
# Stage the `(x, ix)` pair: first the index array, then — placed after the
# index buffer, at a byte offset of one `J` per workitem — the values of `x`
# looked up through the staged indices. Returns `(value_view, index_view)`.
function init_shmem(
    X::Tuple{AbstractArray{T}, AbstractArray{J}}, idx, in_range::Bool,
) where {T, J}
    vals, inds = X
    ind_view = init_shmem(inds, idx, in_range)
    n_items = workgroupDim().x * workgroupDim().y
    byte_offset = n_items * sizeof(J)
    # Staged indices are one-based; `init_shmem` expects a zero-based index.
    src = ind_view[workitemIdx().x] - 0x1
    val_view = init_shmem(vals, src, in_range, byte_offset)
    return val_view, ind_view
end
201 | | - |
"""
Write this workitem's entry of the staged buffer `swap` back to `x` at the
zero-based position `idx`. Workitems with `in_range == false` write nothing.
"""
function finalize_shmem!(
    x::AbstractArray, swap::AbstractArray, idx, in_range::Bool,
)
    in_range || return nothing
    @inbounds x[idx + 0x1] = swap[workitemIdx().x]
end
212 | | - |
# Tuple variant: only the index array (second element of `X`) is written back
# from its staged buffer (second element of `swap`); the values are not.
function finalize_shmem!(X::Tuple, swap::Tuple, idx, in_range::Bool)
    _, ix = X
    _, ix_staged = swap
    return finalize_shmem!(ix, ix_staged, idx, in_range)
end
218 | | - |
# Halve the range `[0, n)` `k - 1` times, each time keeping the half that
# contains `index`. Choosing the lower half flips `dir`. Returns
# `(lo, n, dir)`, or `(-1, -1, false)` if the range shrinks to at most one
# element before all halvings are done.
function get_range_part1(n::I, index::I, k::I) where I
    lo = zero(I)
    dir = true
    for _ in one(I):(k - one(I))
        n ≤ one(I) && return -one(I), -one(I), false

        half = n ÷ 0x2
        if index < lo + half
            n = half
            dir = !dir
        else
            lo = lo + half
            n = n - half
        end
    end
    return lo, n, dir
end
237 | | - |
# Apply `bisect_range` `j - 1` times to `(lo, n)`, following `index`, and
# return the resulting `(lo, n)`.
function get_range_part2(lo::I, n::I, index::I, j::I) where I
    remaining = j - one(I)
    for _ in one(I):remaining
        lo, n = bisect_range(index, lo, n)
    end
    return lo, n
end
244 | | - |
# Work out the swap parameters for element `idx` at pass `(k, j)`:
# `get_range_part1` yields the range and direction for `k`, then
# `get_range_part2` narrows the range for `j`. Returns `(lo, n, dir)`.
function get_range(n, idx, k, j)
    lo_k, n_k, dir = get_range_part1(n, idx, k)
    lo_j, n_j = get_range_part2(lo_k, n_k, idx, j)
    return lo_j, n_j, dir
end
251 | | - |
# Same narrowing as `get_range`, but driven by the bits of the block index
# `bidx` instead of an element index: each step consumes one bit (lowest
# first), an unset bit selecting the lower part of the range and a set bit the
# upper part. Part 1 (`k - 1` steps) halves with `n ÷ 2` and flips `dir` on the
# lower half; part 2 (`j - 1` steps) splits at `gp2lt(n)`.
# Returns `(lo, n, dir)`, or `(-1, -1, false)` when no usable range is left.
function block_range(n::I, bidx::I, k::I, j::I) where I
    two = I(2)
    lo = zero(I)
    dir = true
    # Pre-shifted so the division at the top of each step exposes the next bit.
    bits = bidx * two

    # Part 1.
    for _ in one(I):(k - one(I))
        bits ÷= two
        if n ≤ one(I)
            return -one(I), -one(I), false
        end

        half = n ÷ two
        if iseven(bits)
            n = half
            dir = !dir
        else
            lo += half
            n -= half
        end
    end

    # Part 2.
    for _ in one(I):(j - one(I))
        bits ÷= two
        if n ≤ one(I)
            return -one(I), -one(I), false
        end

        head = gp2lt(n)
        if iseven(bits)
            n = head
        else
            lo += head
            n -= head
        end
    end

    if zero(I) ≤ n && n ≤ one(I)
        return -one(I), -one(I), false
    end
    return lo, n, dir
end
288 | | - |
# Greatest power of two strictly less than `x` (for `x ≥ 2`); `gp2lt(1) == 0`.
#
# Works by smearing the highest set bit of `x - 1` into every lower bit and
# then keeping only the top bit. The shift sequence 1, 2, 4, 8, 16, 32 covers
# index types up to 64 bits; for a 32-bit non-negative value the last shift
# contributes nothing.
@inline function gp2lt(x::I)::I where I
    x -= one(I)
    x |= x >> 0x1
    x |= x >> 0x2
    x |= x >> 0x4
    x |= x >> 0x8
    x |= x >> 0x10
    x |= x >> 0x20
    x ⊻ (x >> 0x1)
end
# In-place sort for ROC arrays, delegated to `AK.sort!`; returns `x`.
function Base.sort!(x::AnyROCArray; kwargs...)
    AK.sort!(x; kwargs...)
    return x
end
# Compute the sorting permutation of `x` into `ix` via `AK.sortperm!`;
# returns `ix`.
function Base.sortperm!(ix::AnyROCArray, x::AnyROCArray; kwargs...)
    AK.sortperm!(ix, x; kwargs...)
    return ix
end
# Allocate an index array holding `1:length(x)` and fill it with the sorting
# permutation of `x` through `sortperm!`.
function Base.sortperm(x::AnyROCArray; kwargs...)
    ix = ROCArray(1:length(x))
    return sortperm!(ix, x; kwargs...)
end
0 commit comments