Commit e3a2eb1

Merge pull request #26 from JuliaGPU/unsafe_indices
Added unsafe_indices to kernels, with Local/Group indices changes where needed
2 parents c01e7c2 + b1280f5 commit e3a2eb1
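
Every kernel in this PR gains KernelAbstractions' `unsafe_indices=true` option, and kernels that relied on `@index(Global, Linear)` now derive the global index from `@index(Group, Linear)` and `@index(Local, Linear)` with an explicit bounds guard. As I read the option, it removes the implicit masking of out-of-range work-items that normally backs `@index(Global, ...)`: every thread in every launched block runs the kernel body, so barriers like `@synchronize()` are never skipped divergently, and bounds handling becomes the kernel's responsibility. A minimal sketch of the pattern, mirroring the src/foreachindex.jl change below (kernel name hypothetical):

    using KernelAbstractions

    # Hypothetical kernel showing the unsafe_indices pattern used throughout
    # this PR: compute the global linear index by hand, then guard against the
    # padded tail of the grid (launchers round thread counts up to whole blocks).
    @kernel cpu=false inbounds=true unsafe_indices=true function _example_fill!(out)
        N = @groupsize()[1]
        iblock = @index(Group, Linear)
        ithread = @index(Local, Linear)
        i = ithread + (iblock - 0x1) * N    # manual stand-in for @index(Global, Linear)

        if i <= length(out)                 # explicit bounds guard
            out[i] = i
        end
    end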

File tree: 9 files changed, +136 -108 lines


Project.toml
Lines changed: 2 additions & 2 deletions

@@ -1,7 +1,7 @@
 name = "AcceleratedKernels"
 uuid = "6a4ca0a5-0e36-4168-a932-d9be78d558f1"
 authors = ["Andrei-Leonard Nicusan <leonard@evophase.co.uk> and contributors"]
-version = "0.3.1"
+version = "0.3.2"

 [deps]
 ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197"
@@ -23,7 +23,7 @@ AcceleratedKernelsoneAPIExt = "oneAPI"
 [compat]
 ArgCheck = "2"
 GPUArrays = "10, 11"
-KernelAbstractions = "0.9"
+KernelAbstractions = "0.9.34"
 Markdown = "1"
 Metal = "1"
 OhMyThreads = "0.7"
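
The compat entry is the load-bearing change in this file: `unsafe_indices` is a recent `@kernel` option, and raising the bound to `KernelAbstractions = "0.9.34"` presumably pins the first release that provides it. A hypothetical runtime check for downstream code that relies on the option (`pkgversion` exists since Julia 1.9; the [compat] bound already enforces this at resolve time, so this is only illustrative):

    import KernelAbstractions

    # Hypothetical sanity check, redundant with the [compat] entry above.
    @assert pkgversion(KernelAbstractions) >= v"0.9.34"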

prototype/reduce_nd_test.jl
Lines changed: 2 additions & 3 deletions

@@ -24,7 +24,6 @@ end
 # Make array with highly unequal per-axis sizes
 s = MtlArray(rand(Int32(1):Int32(100), 10, 1_000_000))
 AK.reduce(+, s, init=zero(eltype(s)))
-ret

 # Correctness
 @assert sum_base(s, dims=1) == sum_ak(s, dims=1)
@@ -34,11 +33,11 @@ ret
 # Benchmarks
 println("\nReduction over small axis - AK vs Base")
 display(@benchmark sum_ak($s, dims=1))
-display(@benchmark sum_base($s, dims=1))
+# display(@benchmark sum_base($s, dims=1))

 println("\nReduction over long axis - AK vs Base")
 display(@benchmark sum_ak($s, dims=2))
-display(@benchmark sum_base($s, dims=2))
+# display(@benchmark sum_base($s, dims=2))

src/accumulate/accumulate_1d.jl
Lines changed: 11 additions & 7 deletions

@@ -12,10 +12,11 @@ const ACC_FLAG_P::UInt8 = 1 # Only current block's prefix available
 end


-@kernel cpu=false inbounds=true function _accumulate_block!(op, v, init, neutral,
-                                                            inclusive,
-                                                            flags, prefixes) # one per block
-
+@kernel cpu=false inbounds=true unsafe_indices=true function _accumulate_block!(
+    op, v, init, neutral,
+    inclusive,
+    flags, prefixes, # one per block
+)
     # NOTE: shmem_size MUST be greater than 2 * block_size
     # NOTE: block_size MUST be a power of 2
     len = length(v)
@@ -147,7 +148,9 @@ end
 end


-@kernel cpu=false inbounds=true function _accumulate_previous!(op, v, flags, @Const(prefixes))
+@kernel cpu=false inbounds=true unsafe_indices=true function _accumulate_previous!(
+    op, v, flags, @Const(prefixes),
+)

     len = length(v)
     block_size = @groupsize()[1]
@@ -200,8 +203,9 @@ end
 end


-@kernel cpu=false inbounds=true function _accumulate_previous_coupled_preblocks!(op, v, prefixes)
-
+@kernel cpu=false inbounds=true unsafe_indices=true function _accumulate_previous_coupled_preblocks!(
+    op, v, prefixes,
+)
     # No decoupled lookback
     len = length(v)
     block_size = @groupsize()[1]

src/accumulate/accumulate_nd.jl
Lines changed: 6 additions & 4 deletions

@@ -1,5 +1,6 @@
-@kernel inbounds=true cpu=false function _accumulate_nd_by_thread!(v, op, init, dims, inclusive)
-
+@kernel inbounds=true cpu=false unsafe_indices=true function _accumulate_nd_by_thread!(
+    v, op, init, dims, inclusive,
+)
     # One thread per outer dimension element, when there are more outer elements than in the
     # reduced dim e.g. accumulate(+, rand(3, 1000), dims=1) => only 3 elements in the accumulated
     # dim
@@ -57,8 +58,9 @@
 end


-@kernel inbounds=true cpu=false function _accumulate_nd_by_block!(v, op, init, neutral, dims, inclusive)
-
+@kernel inbounds=true cpu=false unsafe_indices=true function _accumulate_nd_by_block!(
+    v, op, init, neutral, dims, inclusive,
+)
     # NOTE: shmem_size MUST be greater than 2 * block_size
     # NOTE: block_size MUST be a power of 2

src/foreachindex.jl
Lines changed: 13 additions & 4 deletions

@@ -1,6 +1,14 @@
-@kernel cpu=false inbounds=true function _forindices_global!(f, indices)
-    i = @index(Global, Linear)
-    f(indices[i])
+@kernel inbounds=true cpu=false unsafe_indices=true function _forindices_global!(f, indices)
+
+    # Calculate global index
+    N = @groupsize()[1]
+    iblock = @index(Group, Linear)
+    ithread = @index(Local, Linear)
+    i = ithread + (iblock - 0x1) * N
+
+    if i <= length(indices)
+        f(indices[i])
+    end
 end


@@ -13,7 +21,8 @@ function _forindices_gpu(
 )
     # GPU implementation
     @argcheck block_size > 0
-    _forindices_global!(backend, block_size)(f, indices, ndrange=length(indices))
+    blocks = (length(indices) + block_size - 1) ÷ block_size
+    _forindices_global!(backend, block_size)(f, indices, ndrange=(block_size * blocks,))
     nothing
 end
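
Because `unsafe_indices=true` drops the automatic trimming of the ndrange, the launcher now rounds the grid up to a whole number of blocks and the kernel filters the padded tail itself, which is what the new `i <= length(indices)` guard above is for. A worked example of the ceiling division (values hypothetical):

    # With 1000 indices and 256-thread blocks:
    block_size = 256
    n = 1000
    blocks = (n + block_size - 1) ÷ block_size   # == cld(n, block_size) == 4
    ndrange = block_size * blocks                # 1024 threads launched
    # Threads 1001:1024 fail the `i <= length(indices)` check and do nothing.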

src/reduce/mapreduce_1d.jl
Lines changed: 1 addition & 1 deletion

@@ -1,4 +1,4 @@
-@kernel inbounds=true cpu=false function _mapreduce_block!(@Const(src), dst, f, op, neutral)
+@kernel inbounds=true cpu=false unsafe_indices=true function _mapreduce_block!(@Const(src), dst, f, op, neutral)

     @uniform N = @groupsize()[1]
     sdata = @localmem eltype(dst) (N,)

src/reduce/mapreduce_nd.jl
Lines changed: 91 additions & 80 deletions

@@ -1,5 +1,11 @@
-@kernel inbounds=true cpu=false function _mapreduce_nd_by_thread!(@Const(src), dst, f, op, init, dims)
-
+@kernel inbounds=true cpu=false unsafe_indices=true function _mapreduce_nd_by_thread!(
+    @Const(src),
+    dst,
+    f,
+    op,
+    init,
+    dims,
+)
     # One thread per output element, when there are more outer elements than in the reduced dim
     # e.g. reduce(+, rand(3, 1000), dims=1) => only 3 elements in the reduced dim
     src_sizes = size(src)
@@ -64,8 +70,15 @@
 end


-@kernel inbounds=true cpu=false function _mapreduce_nd_by_block!(@Const(src), dst, f, op, init, neutral, dims)
-
+@kernel inbounds=true cpu=false unsafe_indices=true function _mapreduce_nd_by_block!(
+    @Const(src),
+    dst,
+    f,
+    op,
+    init,
+    neutral,
+    dims,
+)
     # One block per output element, when there are more elements in the reduced dim than in outer
     # e.g. reduce(+, rand(3, 1000), dims=2) => only 3 elements in outer dimensions
     src_sizes = size(src)
@@ -90,86 +103,84 @@ end
     iblock = @index(Group, Linear) - 0x1
     ithread = @index(Local, Linear) - 0x1

-    # Each block handles one output element
-    if iblock < output_size
-
-        # # Sometimes slightly faster method using additional memory with
-        # # output_idx = @private typeof(iblock) (ndims,)
-        # tmp = iblock
-        # KernelAbstractions.Extras.@unroll for i in ndims:-1:1
-        #     output_idx[i] = tmp ÷ dst_strides[i]
-        #     tmp = tmp % dst_strides[i]
-        # end
-        # # Compute the base index in src (excluding the reduced axis)
-        # input_base_idx = 0
-        # KernelAbstractions.Extras.@unroll for i in 1:ndims
-        #     i == dims && continue
-        #     input_base_idx += output_idx[i] * src_strides[i]
-        # end
-
-        # Compute the base index in src (excluding the reduced axis)
-        input_base_idx = typeof(ithread)(0)
-        tmp = iblock
-        KernelAbstractions.Extras.@unroll for i in ndims:-1i16:1i16
-            if i != dims
-                input_base_idx += (tmp ÷ dst_strides[i]) * src_strides[i]
-            end
-            tmp = tmp % dst_strides[i]
+    # Each block handles one output element - thus, iblock ∈ [0, output_size)
+
+    # # Sometimes slightly faster method using additional memory with
+    # # output_idx = @private typeof(iblock) (ndims,)
+    # tmp = iblock
+    # KernelAbstractions.Extras.@unroll for i in ndims:-1:1
+    #     output_idx[i] = tmp ÷ dst_strides[i]
+    #     tmp = tmp % dst_strides[i]
+    # end
+    # # Compute the base index in src (excluding the reduced axis)
+    # input_base_idx = 0
+    # KernelAbstractions.Extras.@unroll for i in 1:ndims
+    #     i == dims && continue
+    #     input_base_idx += output_idx[i] * src_strides[i]
+    # end
+
+    # Compute the base index in src (excluding the reduced axis)
+    input_base_idx = typeof(ithread)(0)
+    tmp = iblock
+    KernelAbstractions.Extras.@unroll for i in ndims:-1i16:1i16
+        if i != dims
+            input_base_idx += (tmp ÷ dst_strides[i]) * src_strides[i]
        end
+        tmp = tmp % dst_strides[i]
+    end

-        # We have a block of threads to process the whole reduced dimension. First do pre-reduction
-        # in strides of N
-        partial = neutral
-        i = ithread
-        while i < reduce_size
-            src_idx = input_base_idx + i * src_strides[dims]
-            partial = op(partial, f(src[src_idx + 0x1]))
-            i += N
-        end
+    # We have a block of threads to process the whole reduced dimension. First do pre-reduction
+    # in strides of N
+    partial = neutral
+    i = ithread
+    while i < reduce_size
+        src_idx = input_base_idx + i * src_strides[dims]
+        partial = op(partial, f(src[src_idx + 0x1]))
+        i += N
+    end

-        # Store partial result in shared memory; now we are down to a single block to reduce within
-        sdata[ithread + 0x1] = partial
-        @synchronize()
+    # Store partial result in shared memory; now we are down to a single block to reduce within
+    sdata[ithread + 0x1] = partial
+    @synchronize()

-        if N >= 512u16
-            ithread < 256u16 && (sdata[ithread + 0x1] = op(sdata[ithread + 0x1], sdata[ithread + 256u16 + 0x1]))
-            @synchronize()
-        end
-        if N >= 256u16
-            ithread < 128u16 && (sdata[ithread + 0x1] = op(sdata[ithread + 0x1], sdata[ithread + 128u16 + 0x1]))
-            @synchronize()
-        end
-        if N >= 128u16
-            ithread < 64u16 && (sdata[ithread + 0x1] = op(sdata[ithread + 0x1], sdata[ithread + 64u16 + 0x1]))
-            @synchronize()
-        end
-        if N >= 64u16
-            ithread < 32u16 && (sdata[ithread + 0x1] = op(sdata[ithread + 0x1], sdata[ithread + 32u16 + 0x1]))
-            @synchronize()
-        end
-        if N >= 32u16
-            ithread < 16u16 && (sdata[ithread + 0x1] = op(sdata[ithread + 0x1], sdata[ithread + 16u16 + 0x1]))
-            @synchronize()
-        end
-        if N >= 16u16
-            ithread < 8u16 && (sdata[ithread + 0x1] = op(sdata[ithread + 0x1], sdata[ithread + 8u16 + 0x1]))
-            @synchronize()
-        end
-        if N >= 8u16
-            ithread < 4u16 && (sdata[ithread + 0x1] = op(sdata[ithread + 0x1], sdata[ithread + 4u16 + 0x1]))
-            @synchronize()
-        end
-        if N >= 4u16
-            ithread < 2u16 && (sdata[ithread + 0x1] = op(sdata[ithread + 0x1], sdata[ithread + 2u16 + 0x1]))
-            @synchronize()
-        end
-        if N >= 2u16
-            ithread < 1u16 && (sdata[ithread + 0x1] = op(sdata[ithread + 0x1], sdata[ithread + 1u16 + 0x1]))
-            @synchronize()
-        end
-        if ithread == 0x0
-            dst[iblock + 0x1] = op(init, sdata[0x1])
-        end
+    if N >= 512u16
+        ithread < 256u16 && (sdata[ithread + 0x1] = op(sdata[ithread + 0x1], sdata[ithread + 256u16 + 0x1]))
+        @synchronize()
+    end
+    if N >= 256u16
+        ithread < 128u16 && (sdata[ithread + 0x1] = op(sdata[ithread + 0x1], sdata[ithread + 128u16 + 0x1]))
+        @synchronize()
+    end
+    if N >= 128u16
+        ithread < 64u16 && (sdata[ithread + 0x1] = op(sdata[ithread + 0x1], sdata[ithread + 64u16 + 0x1]))
+        @synchronize()
+    end
+    if N >= 64u16
+        ithread < 32u16 && (sdata[ithread + 0x1] = op(sdata[ithread + 0x1], sdata[ithread + 32u16 + 0x1]))
+        @synchronize()
+    end
+    if N >= 32u16
+        ithread < 16u16 && (sdata[ithread + 0x1] = op(sdata[ithread + 0x1], sdata[ithread + 16u16 + 0x1]))
+        @synchronize()
+    end
+    if N >= 16u16
+        ithread < 8u16 && (sdata[ithread + 0x1] = op(sdata[ithread + 0x1], sdata[ithread + 8u16 + 0x1]))
+        @synchronize()
+    end
+    if N >= 8u16
+        ithread < 4u16 && (sdata[ithread + 0x1] = op(sdata[ithread + 0x1], sdata[ithread + 4u16 + 0x1]))
+        @synchronize()
+    end
+    if N >= 4u16
+        ithread < 2u16 && (sdata[ithread + 0x1] = op(sdata[ithread + 0x1], sdata[ithread + 2u16 + 0x1]))
+        @synchronize()
+    end
+    if N >= 2u16
+        ithread < 1u16 && (sdata[ithread + 0x1] = op(sdata[ithread + 0x1], sdata[ithread + 1u16 + 0x1]))
+        @synchronize()
+    end
+    if ithread == 0x0
+        dst[iblock + 0x1] = op(init, sdata[0x1])
     end
 end
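
Most of this hunk is mechanical: the `if iblock < output_size` guard is deleted and its body dedented one level, as the updated comment (`iblock ∈ [0, output_size)`) records. The guard can go because the kernel is launched with exactly one block per output element, so every block maps to a valid output; it also keeps the many `@synchronize()` calls out of conditional code, which matters once `unsafe_indices=true` stops masking threads on the kernel's behalf. A sketch of the assumed launch contract (the launcher is not part of this diff):

    # Assumed launch shape for _mapreduce_nd_by_block!: one block per output
    # element, so iblock ∈ [0, output_size) holds by construction and no
    # in-kernel guard is needed.
    block_size = 256
    dst_size = (3, 1)                       # e.g. reduce(+, rand(3, 1000), dims=2)
    output_size = prod(dst_size)            # 3 output elements -> 3 blocks
    ndrange = (block_size * output_size,)   # (768,) threads in total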

src/sort/merge_sort.jl
Lines changed: 4 additions & 3 deletions

@@ -1,4 +1,4 @@
-@kernel inbounds=true function _merge_sort_block!(vec, comp)
+@kernel inbounds=true cpu=false unsafe_indices=true function _merge_sort_block!(vec, comp)

     @uniform N = @groupsize()[1]
     s_buf = @localmem eltype(vec) (N * 0x2,)
@@ -75,8 +75,9 @@
 end


-@kernel inbounds=true function _merge_sort_global!(@Const(vec_in), vec_out, comp, half_size_group)
-
+@kernel inbounds=true cpu=false unsafe_indices=true function _merge_sort_global!(
+    @Const(vec_in), vec_out, comp, half_size_group,
+)
     len = length(vec_in)
     N = @groupsize()[1]

src/sort/merge_sort_by_key.jl
Lines changed: 6 additions & 4 deletions

@@ -1,4 +1,4 @@
-@kernel inbounds=true function _merge_sort_by_key_block!(keys, values, comp)
+@kernel inbounds=true cpu=false unsafe_indices=true function _merge_sort_by_key_block!(keys, values, comp)

     @uniform N = @groupsize()[1]
     s_keys = @localmem eltype(keys) (N * 0x2,)
@@ -97,9 +97,11 @@
 end


-@kernel inbounds=true function _merge_sort_by_key_global!(@Const(keys_in), keys_out,
-                                                          @Const(values_in), values_out,
-                                                          comp, half_size_group)
+@kernel inbounds=true cpu=false unsafe_indices=true function _merge_sort_by_key_global!(
+    @Const(keys_in), keys_out,
+    @Const(values_in), values_out,
+    comp, half_size_group,
+)

     len = length(keys_in)
     N = @groupsize()[1]
