Skip to content

Commit 06c2594

Browse files
authored
Merge pull request #47 from christiangnrd/err
Fix for `accumulate` by block
2 parents 8dd1ce9 + 0abe1b3 commit 06c2594

File tree

3 files changed

+14
-3
lines changed

3 files changed

+14
-3
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "AcceleratedKernels"
22
uuid = "6a4ca0a5-0e36-4168-a932-d9be78d558f1"
33
authors = ["Andrei-Leonard Nicusan <[email protected]> and contributors"]
4-
version = "0.4.1"
4+
version = "0.4.2"
55

66
[deps]
77
ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197"

src/accumulate/accumulate_nd.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -227,7 +227,7 @@ end
227227
# We have a block of threads to accumulate along the dims axis; do it in chunks of
228228
# block_size and keep track of previous chunks' running prefix
229229
ichunk = typeof(iblock)(0)
230-
num_chunks = (length_dims + block_size - 0x1) ÷ block_size
230+
num_chunks = (length_dims + (0x2 * block_size) - 0x1) ÷ (0x2 * block_size)
231231
total = neutral
232232

233233
if ithread == 0x0
@@ -326,7 +326,7 @@ end
326326

327327
# ...and accumulate the last value too
328328
if bi == 0x2 * block_size - 0x1
329-
if iblock < num_chunks - 0x1
329+
if ichunk < num_chunks - 0x1
330330
temp[bi + bank_offset_b + 0x1] = op(t2, v[
331331
input_base_idx +
332332
((ichunk + 0x1) * block_size * 0x2 - 0x1) * vstrides[dims] +

test/accumulate.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -192,6 +192,17 @@ end
192192
# Test that undefined kwargs are not accepted
193193
@test_throws MethodError AK.accumulate(+, v; init=10, dims=2, inclusive=false, bad=:kwarg)
194194

195+
# Test all options with bigger matrices
196+
for D in [(1_000_000,3), (3,1_000_000)], dims in [1,2]
197+
@testset let D = D, dims = dims
198+
vh = ones(Float32, D)
199+
v = array_from_host(vh)
200+
s = AK.accumulate(+, v; init=0, dims)
201+
sh = Array(s)
202+
@test sh == accumulate(+, vh; init=0, dims)
203+
end
204+
end
205+
195206
# Testing different settings
196207
AK.accumulate(
197208
(x, y) -> x + 1,

0 commit comments

Comments
 (0)