Skip to content

Commit 3e814ca

Browse files
authored
Merge pull request #15 from JuliaGPU/metal-accumulate-prefix
Added new `ScanPrefix` accumulate algorithm
2 parents 0d275a8 + a8e33c5 commit 3e814ca

File tree

10 files changed

+216
-22
lines changed

10 files changed

+216
-22
lines changed

Project.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,11 @@ Polyester = "f517fe37-dbe3-4b94-8317-1923a5111588"
1414
Unrolled = "9602ed7d-8fef-5bc8-8597-8f21381861e8"
1515

1616
[weakdeps]
17+
Metal = "dde4c033-4e86-420c-a63e-0dd931031962"
1718
oneAPI = "8f75cd03-7ff8-4ecb-9b8f-daf728133b1b"
1819

1920
[extensions]
21+
PlatformDependentMetalExt = "Metal"
2022
PlatformDependentoneAPIExt = "oneAPI"
2123

2224
[compat]
@@ -25,6 +27,7 @@ DocStringExtensions = "0.9"
2527
GPUArraysCore = "0.1, 0.2"
2628
KernelAbstractions = "0.9"
2729
Markdown = "1"
30+
Metal = "1.4.2"
2831
OhMyThreads = "0.7"
2932
Polyester = "0.7"
3033
Unrolled = "0.1"

README.md

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -139,8 +139,6 @@ Julia v1.11
139139

140140
[Metal](https://github.com/JuliaGPU/Metal.jl)
141141

142-
[Known Issue with `accumulate` Only](https://github.com/JuliaGPU/AcceleratedKernels.jl/issues/10)
143-
144142
</td>
145143
<td>
146144

ext/PlatformDependentMetalExt.jl

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
module PlatformDependentMetalExt
2+
3+
4+
using Metal
5+
import AcceleratedKernels as AK
6+
7+
8+
# On Metal use the ScanPrefixes accumulation algorithm by default as the DecoupledLookback algorithm
9+
# cannot be supported due to Metal's weaker memory consistency guarantees.
10+
function AK.accumulate!(
11+
op, v::AbstractArray, backend::MetalBackend;
12+
init,
13+
inclusive::Bool=true,
14+
15+
# Algorithm choice
16+
alg::AK.AccumulateAlgorithm=AK.ScanPrefixes(),
17+
18+
# GPU settings
19+
block_size::Int=1024,
20+
temp::Union{Nothing, AbstractArray}=nothing,
21+
temp_flags::Union{Nothing, AbstractArray}=nothing,
22+
)
23+
AK._accumulate_impl!(
24+
op, v, backend,
25+
init=init, inclusive=inclusive,
26+
alg=alg,
27+
block_size=block_size, temp=temp, temp_flags=temp_flags,
28+
)
29+
end
30+
31+
32+
end # module PlatformDependentMetalExt

prototype/accumulate_benchmark.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ Random.seed!(0)
77

88

99
function akacc(v)
10-
va = AK.accumulate(+, v, init=zero(eltype(v)), block_size=512)
10+
va = AK.accumulate(+, v, init=zero(eltype(v)), block_size=1024)
1111
Metal.synchronize()
1212
va
1313
end

prototype/accumulate_test_metal.jl

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
2+
using Random
3+
using BenchmarkTools
4+
using Profile
5+
using PProf
6+
7+
using KernelAbstractions
8+
using Metal
9+
10+
import AcceleratedKernels as AK
11+
12+
13+
Random.seed!(0)
14+
15+
16+
v = Metal.ones(Int32, 100)
17+
18+
v2 = AK.accumulate!(+, copy(v), init=zero(eltype(v)), block_size=1024)
19+
20+
@assert Array(v2) == cumsum(Array(v))
21+
22+
v2

src/accumulate/accumulate.jl

Lines changed: 35 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,10 @@
1+
# Available accumulation algorithms
2+
abstract type AccumulateAlgorithm end
3+
struct DecoupledLookback <: AccumulateAlgorithm end
4+
struct ScanPrefixes <: AccumulateAlgorithm end
5+
6+
7+
# Implementations, then interfaces
18
include("accumulate_1d.jl")
29

310

@@ -7,6 +14,10 @@ include("accumulate_1d.jl")
714
init,
815
inclusive::Bool=true,
916
17+
# Algorithm choice
18+
alg::AccumulateAlgorithm=DecoupledLookback(),
19+
20+
# GPU settings
1021
block_size::Int=256,
1122
temp::Union{Nothing, AbstractArray}=nothing,
1223
temp_flags::Union{Nothing, AbstractArray}=nothing,
@@ -22,13 +33,20 @@ element is included in the accumulation (or not).
2233
The `block_size` should be a power of 2 and greater than 0. The temporaries `temp` and `temp_flags`
2334
should both have at least
2435
`(length(v) + 2 * block_size - 1) ÷ (2 * block_size)` elements; `eltype(v) === eltype(temp)`; the
25-
elements in `temp_flags` can be any integers, but `Int8` is used by default to reduce memory usage.
36+
elements in `temp_flags` can be any integers, but `Int8` is used by default to reduce memory usage.
37+
38+
The `alg` can be one of the following:
39+
- `DecoupledLookback()`: the default algorithm, using opportunistic lookback to reuse earlier
40+
blocks' results; requires device-level memory consistency guarantees, which Apple Metal does not
41+
provide.
42+
- `ScanPrefixes()`: a simpler algorithm that scans the prefixes of each block, with no lookback;
43+
`temp_flags` is not used in this case.
2644
2745
# Platform-Specific Notes
28-
Currently, Apple Metal GPUs do not have strong enough memory consistency guarantees to support the
29-
industry-standard "decoupled lookback" algorithm for prefix sums - which means it currently may,
30-
for very large arrays, produce incorrect results ~0.38% of the time. We are currently working on an
31-
alternative algorithm without lookback ([issue](https://github.com/JuliaGPU/AcceleratedKernels.jl/issues/10)).
46+
On Metal, the `alg=ScanPrefixes()` algorithm is used by default, as Apple Metal GPUs do not have
47+
strong enough memory consistency guarantees for the `DecoupledLookback()` algorithm, which on Metal
48+
produces incorrect results about 0.38% of the time. Also, `block_size=1024` is used here by
49+
default to reduce the number of coupled lookbacks.
3250
3351
The CPU implementation currently defers to the single-threaded Base.accumulate!; we are waiting on a
3452
multithreaded implementation in OhMyThreads.jl ([issue](https://github.com/JuliaFolds2/OhMyThreads.jl/issues/129)).
@@ -41,20 +59,28 @@ using oneAPI
4159
4260
v = oneAPI.ones(Int32, 100_000)
4361
AK.accumulate!(+, v, init=0)
62+
63+
# Use a different algorithm
64+
AK.accumulate!(+, v, init=0, alg=AK.ScanPrefixes())
4465
```
4566
"""
4667
function accumulate!(
4768
op, v::AbstractArray, backend::Backend=get_backend(v);
4869
init,
4970
inclusive::Bool=true,
5071

72+
# Algorithm choice
73+
alg::AccumulateAlgorithm=DecoupledLookback(),
74+
75+
# GPU settings
5176
block_size::Int=256,
5277
temp::Union{Nothing, AbstractArray}=nothing,
5378
temp_flags::Union{Nothing, AbstractArray}=nothing,
5479
)
5580
_accumulate_impl!(
5681
op, v, backend,
5782
init=init, inclusive=inclusive,
83+
alg=alg,
5884
block_size=block_size, temp=temp, temp_flags=temp_flags,
5985
)
6086
end
@@ -65,13 +91,16 @@ function _accumulate_impl!(
6591
init,
6692
inclusive::Bool=true,
6793

94+
alg::AccumulateAlgorithm=DecoupledLookback(),
95+
96+
# GPU settings
6897
block_size::Int=256,
6998
temp::Union{Nothing, AbstractArray}=nothing,
7099
temp_flags::Union{Nothing, AbstractArray}=nothing,
71100
)
72101
if backend isa GPU
73102
accumulate_1d!(
74-
op, v, backend,
103+
op, v, backend, alg,
75104
init=init, inclusive=inclusive,
76105
block_size=block_size, temp=temp, temp_flags=temp_flags,
77106
)

src/accumulate/accumulate_1d.jl

Lines changed: 108 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -126,8 +126,16 @@ end
126126

127127
# Write this block's final prefix to global array and set flag to "block prefix computed"
128128
if bi == 0x2 * block_size - 0x1
129-
prefixes[iblock + 0x1] = temp[bi + bank_offset_b + 0x1]
130-
flags[iblock + 0x1] = ACC_FLAG_P
129+
130+
# Known at compile-time; used in the first pass of the ScanPrefixes algorithm
131+
if !isnothing(prefixes)
132+
prefixes[iblock + 0x1] = temp[bi + bank_offset_b + 0x1]
133+
end
134+
135+
# Known at compile-time; used only in the DecoupledLookback algorithm
136+
if !isnothing(flags)
137+
flags[iblock + 0x1] = ACC_FLAG_P
138+
end
131139
end
132140

133141
if block_offset + ai < len
@@ -192,8 +200,52 @@ end
192200
end
193201

194202

203+
@kernel cpu=false inbounds=true function _accumulate_previous_coupled_preblocks!(op, v, prefixes)
204+
205+
# No decoupled lookback
206+
len = length(v)
207+
block_size = @groupsize()[1]
208+
209+
# NOTE: for many index calculations in this library, computation using zero-indexing leads to
210+
# fewer operations (also code is transpiled to CUDA / ROCm / oneAPI / Metal code which do zero
211+
# indexing). Internal calculations will be done using zero indexing except when actually
212+
# accessing memory. As with C, the lower bound is inclusive, the upper bound exclusive.
213+
214+
# Group (block) and local (thread) indices
215+
iblock = @index(Group, Linear) - 0x1 + 0x1 # Skipping first block
216+
ithread = @index(Local, Linear) - 0x1
217+
block_offset = iblock * block_size * 0x2 # Processing two elements per thread
218+
219+
# Start from the pre-accumulated prefix of the block immediately before this one
220+
running_prefix = prefixes[iblock - 0x1 + 0x1]
221+
222+
# The prefixes were pre-accumulated, which means (for block_size=N):
223+
# - If there were 2N or fewer prefixes (so at most 4*N*N elements in v to begin with), the
224+
# prefixes were fully accumulated and we can use them directly.
225+
# - If there were more than 2N prefixes, each chunk of 2N prefixes was accumulated, but not
226+
# along the chunks. We need to accumulate the prefixes of the previous chunks into
227+
# running_prefix.
228+
num_preblocks = (iblock - 0x1) ÷ (block_size * 0x2)
229+
for i in 0x1:num_preblocks
230+
running_prefix = op(running_prefix, prefixes[i * block_size * 0x2])
231+
end
232+
233+
# Now we have aggregate prefix of all previous blocks, add it to all our elements
234+
ai = ithread
235+
if block_offset + ai < len
236+
v[block_offset + ai + 0x1] = op(running_prefix, v[block_offset + ai + 0x1])
237+
end
238+
239+
bi = ithread + block_size
240+
if block_offset + bi < len
241+
v[block_offset + bi + 0x1] = op(running_prefix, v[block_offset + bi + 0x1])
242+
end
243+
end
244+
245+
246+
# DecoupledLookback algorithm
195247
function accumulate_1d!(
196-
op, v::AbstractArray, backend::GPU;
248+
op, v::AbstractArray, backend::GPU, ::DecoupledLookback;
197249
init,
198250
inclusive::Bool=true,
199251

@@ -242,3 +294,56 @@ function accumulate_1d!(
242294

243295
return v
244296
end
297+
298+
299+
# ScanPrefixes algorithm
300+
function accumulate_1d!(
301+
op, v::AbstractArray, backend::GPU, ::ScanPrefixes;
302+
init,
303+
inclusive::Bool=true,
304+
305+
block_size::Int=256,
306+
temp::Union{Nothing, AbstractArray}=nothing,
307+
temp_flags::Union{Nothing, AbstractArray}=nothing,
308+
)
309+
# Correctness checks
310+
@argcheck block_size > 0
311+
@argcheck ispow2(block_size)
312+
313+
# Nothing to accumulate
314+
if length(v) == 0
315+
return v
316+
end
317+
318+
# Each thread will process two elements
319+
elems_per_block = block_size * 2
320+
num_blocks = (length(v) + elems_per_block - 1) ÷ elems_per_block
321+
322+
if isnothing(temp)
323+
prefixes = similar(v, eltype(v), num_blocks)
324+
else
325+
@argcheck eltype(temp) === eltype(v)
326+
@argcheck length(temp) >= num_blocks
327+
prefixes = temp
328+
end
329+
330+
kernel1! = _accumulate_block!(backend, block_size)
331+
kernel1!(op, v, init, inclusive, nothing, prefixes,
332+
ndrange=num_blocks * block_size)
333+
334+
if num_blocks > 1
335+
336+
# Accumulate prefixes of all blocks
337+
num_blocks_prefixes = (length(prefixes) + elems_per_block - 1) ÷ elems_per_block
338+
kernel1!(op, prefixes, init, true, nothing, nothing,
339+
ndrange=num_blocks_prefixes * block_size)
340+
341+
# Prefixes are pre-accumulated (completely accumulated if num_blocks_prefixes == 1, or
342+
# partially, which we will account for in the coupled lookback)
343+
kernel2! = _accumulate_previous_coupled_preblocks!(backend, block_size)
344+
kernel2!(op, v, prefixes,
345+
ndrange=(num_blocks - 1) * block_size)
346+
end
347+
348+
return v
349+
end

src/reduce/mapreduce_nd.jl

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,6 @@
2323
iblock = @index(Group, Linear) - 0x1
2424
ithread = @index(Local, Linear) - 0x1
2525

26-
tid = ithread + iblock * N
27-
2826
# Each thread handles one output element
2927
tid = ithread + iblock * N
3028
if tid < output_size

src/reduce/reduce_nd.jl

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,6 @@
2323
iblock = @index(Group, Linear) - 0x1
2424
ithread = @index(Local, Linear) - 0x1
2525

26-
tid = ithread + iblock * N
27-
2826
# Each thread handles one output element
2927
tid = ithread + iblock * N
3028
if tid < output_size

test/runtests.jl

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,34 +11,34 @@ import Pkg
1111
if "--CUDA" in ARGS
1212
Pkg.add("CUDA")
1313
using CUDA
14-
display(CUDA.versioninfo())
14+
CUDA.versioninfo()
1515
const backend = CUDABackend()
1616
elseif "--oneAPI" in ARGS
1717
Pkg.add("oneAPI")
1818
using oneAPI
19-
display(oneAPI.versioninfo())
19+
oneAPI.versioninfo()
2020
const backend = oneAPIBackend()
2121
elseif "--AMDGPU" in ARGS
2222
Pkg.add("AMDGPU")
2323
using AMDGPU
24-
display(AMDGPU.versioninfo())
24+
AMDGPU.versioninfo()
2525
const backend = ROCBackend()
2626
elseif "--Metal" in ARGS
2727
Pkg.add("Metal")
2828
using Metal
29-
display(Metal.versioninfo())
29+
Metal.versioninfo()
3030
const backend = MetalBackend()
3131
elseif "--OpenCL" in ARGS
3232
Pkg.add(name="OpenCL", rev="master")
3333
Pkg.add("pocl_jll")
3434
using pocl_jll
3535
using OpenCL
36-
display(OpenCL.versioninfo())
36+
OpenCL.versioninfo()
3737
const backend = OpenCLBackend()
3838
elseif !@isdefined(backend)
3939
# Otherwise do CPU tests
4040
using InteractiveUtils
41-
display(InteractiveUtils.versioninfo())
41+
InteractiveUtils.versioninfo()
4242
const backend = CPU()
4343
end
4444

@@ -1059,6 +1059,15 @@ end
10591059
@test all(Array(y) .== accumulate(+, Array(x)))
10601060
end
10611061

1062+
# Stress-testing small block sizes -> many blocks
1063+
for _ in 1:100
1064+
num_elems = rand(1:100_000)
1065+
x = array_from_host(rand(1:1000, num_elems), Int32)
1066+
y = copy(x)
1067+
AK.accumulate!(+, y; init=0, block_size=16)
1068+
@test all(Array(y) .== accumulate(+, Array(x)))
1069+
end
1070+
10621071
# Testing different settings
10631072
AK.accumulate!(+, array_from_host(ones(Int32, 1000)), init=0, inclusive=false,
10641073
block_size=128,

0 commit comments

Comments
 (0)