Skip to content

Commit 0905cc8

Browse files
authored
Use AcceleratedKernels for sorting (#688)
1 parent e700433 commit 0905cc8

File tree

4 files changed

+7
-300
lines changed

4 files changed

+7
-300
lines changed

Project.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
name = "AMDGPU"
22
uuid = "21141c5a-9bdb-4563-92ae-f87d6854732e"
33
authors = ["Julian P Samaroo <[email protected]>", "Valentin Churavy <[email protected]>", "Anton Smirnov <[email protected]>"]
4-
version = "1.0.4"
4+
version = "1.0.5"
55

66
[deps]
77
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
8+
AcceleratedKernels = "6a4ca0a5-0e36-4168-a932-d9be78d558f1"
89
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
910
Atomix = "a9b6321e-bd34-4604-b9c9-b65b8de01458"
1011
CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82"
@@ -34,6 +35,7 @@ UnsafeAtomicsLLVM = "d80eeb9a-aca5-4d75-85e5-170c8b632249"
3435

3536
[compat]
3637
AbstractFFTs = "1.0"
38+
AcceleratedKernels = "0.1.0"
3739
Adapt = "4"
3840
Atomix = "0.1"
3941
CEnum = "0.4, 0.5"

src/AMDGPU.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ using LLVM, LLVM.Interop
99
using Preferences
1010
using Printf
1111

12+
import AcceleratedKernels as AK
1213
import UnsafeAtomics
1314
import UnsafeAtomicsLLVM
1415
import Atomix

src/kernels/sorting.jl

Lines changed: 3 additions & 297 deletions
Original file line numberDiff line numberDiff line change
@@ -1,297 +1,3 @@
1-
# Ported from CUDA.jl.
2-
# Originally developed by @xaellison (Alex Ellison).
3-
4-
Base.sort!(x::AnyROCArray; kwargs...) = bitonic_sort!(x; kwargs...)
5-
6-
function Base.sortperm!(
7-
ix::AnyROCArray, x::AnyROCArray;
8-
initialized::Bool = false, kwargs...,
9-
)
10-
axes(ix) == axes(x) || throw(ArgumentError(
11-
"Index array must have the same size as the source array, instead: " *
12-
"$(size(ix)) vs $(size(x))."))
13-
14-
initialized || (ix .= LinearIndices(x);)
15-
bitonic_sort!((x, ix); kwargs...)
16-
return ix
17-
end
18-
19-
function Base.sortperm(x::AnyROCArray; kwargs...)
20-
sortperm!(ROCArray(1:length(x)), x; initialized=true, kwargs...)
21-
end
22-
23-
# TODO dims
24-
function bitonic_sort!(X; lt = isless, by = identity, rev::Bool = false)
25-
_shmem(x::Tuple, groupsize) = prod(groupsize) * sum(sizeof.(eltype.(x)))
26-
_shmem(x::AbstractArray, groupsize) = prod(groupsize) * sizeof(eltype(x))
27-
28-
len_x = typeof(X) <: Tuple ? length(X[1]) : length(X)
29-
I = len_x ≤ typemax(Int32) ? Int32 : Int64
30-
threads = min(256, prevpow(2, len_x))
31-
32-
# Compile kernels.
33-
ker_1 = @roc launch=false cmp_small_kern!(
34-
X, I(len_x), one(I), one(I), one(I), by, lt, Val(rev))
35-
ker_2 = @roc launch=false cmp_ker!(
36-
X, I(len_x), one(I), one(I), by, lt, Val(rev))
37-
38-
# Cutoff for when to use `ker_1` vs `ker_2`.
39-
log_threads = Int(log2(threads))
40-
41-
k₀ = ceil(Int, log2(len_x))
42-
for k in k₀:-1:1
43-
j_end = k₀ - k + 1
44-
for j in 1:j_end
45-
if k₀ - k - j + 2 ≤ log_threads
46-
pseudo_block_len = 1 << abs(j_end + 1 - j)
47-
n_pseudo_blocks = nextpow(2, len_x) ÷ pseudo_block_len
48-
pseudo_blocks_per_block = threads ÷ pseudo_block_len
49-
50-
gridsize = max(1, n_pseudo_blocks ÷ pseudo_blocks_per_block)
51-
groupsize = (pseudo_block_len, threads ÷ pseudo_block_len)
52-
ker_1(
53-
X, I(len_x), I(k), I(j), I(j_end), by, lt, Val(rev);
54-
gridsize, groupsize, shmem=_shmem(X, groupsize))
55-
else
56-
gridsize = cld(len_x, threads)
57-
ker_2(
58-
X, I(len_x), I(k), I(j), by, lt, Val(rev);
59-
gridsize, groupsize=threads)
60-
end
61-
end
62-
end
63-
return X
64-
end
65-
66-
function cmp_ker!(x, lenₓ::I, k::I, j::I, by, lt, rev) where I
67-
idx::I = workgroupDim().x * (workgroupIdx().x - 0x1) + workitemIdx().x - 0x1
68-
lo, n, dir = get_range(lenₓ, idx, k, j)
69-
70-
if !(lo < 0x0 || n < 0x0) && !(idx ≥ lenₓ)
71-
m = gp2lt(n)
72-
if lo ≤ idx < lo + n - m
73-
i1, i2 = idx, idx + m
74-
cmp!(x, i1, i2, dir, by, lt, rev)
75-
end
76-
end
77-
return
78-
end
79-
80-
function cmp_small_kern!(x, lenₓ::I, k::I, j₀::I, jₑ::I, by, lt, rev) where I
81-
bidx::I = (workgroupIdx().x - 0x1) * workgroupDim().y + workitemIdx().y - 0x1
82-
_lo, _n, dir = block_range(lenₓ, bidx, k, j₀)
83-
84-
idx = _lo + I(workitemIdx().x) - 0x1
85-
in_range = workitemIdx().x ≤ _n && _lo ≥ 0x0
86-
swap = init_shmem(x, idx, in_range)
87-
88-
lo, n = _lo, _n
89-
for j in j₀:jₑ
90-
if in_range && !(lo < 0x0 || n < 0x0)
91-
m = gp2lt(n)
92-
if lo ≤ idx < lo + n - m
93-
i1, i2 = idx - _lo, idx - _lo + m
94-
cmp_small!(swap, i1, i2, dir, by, lt, rev)
95-
end
96-
end
97-
lo, n = bisect_range(idx, lo, n)
98-
sync_workgroup()
99-
end
100-
finalize_shmem!(x, swap, idx, in_range)
101-
return
102-
end
103-
104-
function bisect_range(idx::I, lo::I, n::I) where I
105-
n ≤ 0x1 && return -one(I), -one(I)
106-
107-
m = gp2lt(n)
108-
if idx < lo + m
109-
n = m
110-
else
111-
lo += m
112-
n -= m
113-
end
114-
lo, n
115-
end
116-
117-
function cmp!(
118-
x::AbstractArray, i1::I, i2::I, dir::Bool, by, lt, rev,
119-
) where I
120-
i1, i2 = i1 + one(I), i2 + one(I)
121-
@inbounds if dir != _lt_fn(by(x[i1]), by(x[i2]), lt, rev)
122-
x[i1], x[i2] = x[i2], x[i1]
123-
end
124-
end
125-
126-
function cmp!(
127-
X::Tuple, i1::I, i2::I, dir::Bool, by, lt, rev,
128-
) where I
129-
i1, i2 = i1 + one(I), i2 + one(I)
130-
x, ix = X
131-
cmp_res = _lt_fn(
132-
(by(x[ix[i1]]), ix[i1]),
133-
(by(x[ix[i2]]), ix[i2]), lt, rev)
134-
@inbounds if dir != cmp_res
135-
ix[i1], ix[i2] = ix[i2], ix[i1]
136-
end
137-
end
138-
139-
function cmp_small!(
140-
swap::AbstractArray, i1::I, i2::I, dir::Bool, by, lt, rev,
141-
) where I
142-
i1, i2 = i1 + one(I), i2 + one(I)
143-
@inbounds if dir != _lt_fn(by(swap[i1]), by(swap[i2]), lt, rev)
144-
swap[i1], swap[i2] = swap[i2], swap[i1]
145-
end
146-
end
147-
148-
function cmp_small!(
149-
swap::Tuple, i1::I, i2::I, dir::Bool, by, lt, rev,
150-
) where I
151-
i1, i2 = i1 + one(I), i2 + one(I)
152-
x, ix = swap
153-
cmp_res = _lt_fn(
154-
(by(x[i1]), ix[i1]),
155-
(by(x[i2]), ix[i2]), lt, rev)
156-
@inbounds if dir != cmp_res
157-
x[i1], x[i2] = x[i2], x[i1]
158-
ix[i1], ix[i2] = ix[i2], ix[i1]
159-
end
160-
end
161-
162-
@inline function _lt_fn(a::T, b::T, lt, rev::Val{R}) where {T, R}
163-
if R
164-
lt(b, a)
165-
else
166-
lt(a, b)
167-
end
168-
end
169-
170-
@inline function _lt_fn(a::Tuple{T, J}, b::Tuple{T, J}, lt, rev::Val{R}) where {T, J, R}
171-
if R
172-
if a[1] == b[1]
173-
return a[2] < b[2] # Compare indices.
174-
else
175-
return lt(b[1], a[1])
176-
end
177-
else
178-
return lt(a, b)
179-
end
180-
end
181-
182-
function init_shmem(x::AbstractArray{T}, idx, in_range::Bool, offset=0) where T
183-
swap = @ROCDynamicLocalArray(
184-
T, (workgroupDim().x, workgroupDim().y), false, offset)
185-
if in_range
186-
@inbounds swap[workitemIdx().x, workitemIdx().y] = x[idx + 0x1]
187-
end
188-
sync_workgroup()
189-
@inbounds @view(swap[:, workitemIdx().y])
190-
end
191-
192-
function init_shmem(
193-
X::Tuple{AbstractArray{T}, AbstractArray{J}}, idx, in_range::Bool,
194-
) where {T, J}
195-
x, ix = X
196-
idx_swap = init_shmem(ix, idx, in_range)
197-
offset = (workgroupDim().x * workgroupDim().y) * sizeof(J)
198-
swap = init_shmem(x, idx_swap[workitemIdx().x] - 0x1, in_range, offset)
199-
swap, idx_swap
200-
end
201-
202-
"""
203-
Copy `swap` back into global memory `x`.
204-
"""
205-
function finalize_shmem!(
206-
x::AbstractArray, swap::AbstractArray, idx, in_range::Bool,
207-
)
208-
if in_range
209-
@inbounds x[idx + 0x1] = swap[workitemIdx().x]
210-
end
211-
end
212-
213-
function finalize_shmem!(X::Tuple, swap::Tuple, idx, in_range::Bool)
214-
x, ix = X
215-
x_swap, idx_swap = swap
216-
finalize_shmem!(ix, idx_swap, idx, in_range)
217-
end
218-
219-
function get_range_part1(n::I, index::I, k::I) where I
220-
lo = zero(I)
221-
dir = true
222-
for iter in one(I):(k - one(I))
223-
if n ≤ one(I)
224-
return -one(I), -one(I), false
225-
end
226-
227-
if index < lo + n ÷ 0x2
228-
n = n ÷ 0x2
229-
dir = !dir
230-
else
231-
lo = lo + n ÷ 0x2
232-
n = n - n ÷ 0x2
233-
end
234-
end
235-
lo, n, dir
236-
end
237-
238-
function get_range_part2(lo::I, n::I, index::I, j::I) where I
239-
for iter in one(I):(j - one(I))
240-
lo, n = bisect_range(index, lo, n)
241-
end
242-
lo, n
243-
end
244-
245-
# Determine parameters for swapping.
246-
function get_range(n, idx, k, j)
247-
lo, n, dir = get_range_part1(n, idx, k)
248-
lo, n = get_range_part2(lo, n, idx, j)
249-
lo, n, dir
250-
end
251-
252-
function block_range(n::I, bidx::I, k::I, j::I) where I
253-
lo = zero(I)
254-
dir = true
255-
tmp = bidx * I(2)
256-
257-
# Part 1.
258-
for i in one(I):(k - one(I))
259-
tmp ÷= I(2)
260-
n ≤ one(I) && return -one(I), -one(I), false
261-
262-
if tmp % I(2) == zero(I)
263-
n ÷= I(2)
264-
dir = !dir
265-
else
266-
lo += n ÷ I(2)
267-
n -= n ÷ I(2)
268-
end
269-
end
270-
271-
# Part 2.
272-
for i in one(I):(j - one(I))
273-
tmp ÷= I(2)
274-
n ≤ one(I) && return -one(I), -one(I), false
275-
276-
m = gp2lt(n)
277-
if tmp % I(2) == zero(I)
278-
n = m
279-
else
280-
lo += m
281-
n -= m
282-
end
283-
end
284-
285-
(zero(I) ≤ n ≤ one(I)) && return -one(I), -one(I), false
286-
lo, n, dir
287-
end
288-
289-
@inline function gp2lt(x::I)::I where I
290-
x -= one(I)
291-
x |= x >> 0x1
292-
x |= x >> 0x2
293-
x |= x >> 0x4
294-
x |= x >> 0x8
295-
x |= x >> 0x102
296-
x ⊻ (x >> 0x1)
297-
end
1+
# In-place sort for ROC arrays: delegates to `AK.sort!`, forwarding every
# keyword argument unchanged, and returns the (now sorted) input array `x`.
function Base.sort!(x::AnyROCArray; kwargs...)
    AK.sort!(x; kwargs...)
    return x
end
2+
# In-place `sortperm!` for ROC arrays: delegates to `AK.sortperm!` with the
# index array `ix` and the source array `x`, forwarding every keyword argument
# unchanged, and returns `ix`.
function Base.sortperm!(ix::AnyROCArray, x::AnyROCArray; kwargs...)
    AK.sortperm!(ix, x; kwargs...)
    return ix
end
3+
# Out-of-place `sortperm` for ROC arrays: builds a `ROCArray` holding the
# indices `1:length(x)` and passes it, together with `x` and all keyword
# arguments, to `sortperm!`, returning whatever that call returns.
function Base.sortperm(x::AnyROCArray; kwargs...)
    ix = ROCArray(1:length(x))
    return sortperm!(ix, x; kwargs...)
end

test/core_tests.jl

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -84,10 +84,8 @@ end
8484
end
8585

8686
include("codegen/codegen.jl")
87-
8887
include("rocarray/base.jl")
8988
include("rocarray/broadcast.jl")
90-
9189
include("tls.jl")
9290

9391
end

0 commit comments

Comments
 (0)