Skip to content

Commit 8cdf3f3

Browse files
committed
Fix correctness in cuda_mapreduce
`cuda_mapreduce` was not working correctly with certain spaces. Why was this happening? I added a comment to describe the algorithm in the commit. In a nutshell, the algorithm was not taking into account the fact that the final block is not completely filled with points to process. Therefore, the reduction included some elements that did not contain real points (but the value 0).
1 parent b4a04d8 commit 8cdf3f3

File tree

3 files changed

+70
-2
lines changed

3 files changed

+70
-2
lines changed

NEWS.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,8 @@ main
1515

1616
### ![][badge-🐛bugfix] Bug fixes
1717

18-
- Fixed writing/reading purely vertical spaces
18+
- Fixed writing/reading purely vertical spaces. PR [2102](https://github.com/CliMA/ClimaCore.jl/pull/2102)
19+
- Fixed correctness bug in reductions on GPUs. PR [2106](https://github.com/CliMA/ClimaCore.jl/pull/2106)
1920

2021
v0.14.20
2122
--------

ext/cuda/data_layouts_mapreduce.jl

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,35 @@ function mapreduce_cuda(
3131
weighted_jacobian = OnesArray(parent(data)),
3232
opargs...,
3333
)
34+
# This function implements the following parallel reduction algorithm:
35+
#
36+
# Each thread in each blocks processes multiple data points at the same time
37+
# (n_ops_on_load) each and we perform a block-wise reduction, with each
38+
# block writing to an array of (block-)shared memory. This array has the
39+
# same size as the block, ie, it is as long as many threads are available.
40+
# Processing multiple points means that we apply the reduction to the point
41+
# with index reduction[thread_index] = f(thread_index, thread_index +
42+
# OFFSET), with various OFFSETS that depend on `n_ops_on_load` and block
43+
# size.
44+
#
45+
# For the purpose of indexing, this is equivalent to having larger blocks
46+
# with size effective_blksize = blksize * (n_ops_on_load + 1).
47+
#
48+
#
49+
# After this operation, we have reduced all the data by a factor of
50+
# 1/n_ops_on_load and have results in various arrays `reduction` (one per
51+
# block)
52+
#
53+
# Once we have all the blocks reduced, we perform a tree reduction within
54+
# the block and "move" the reduced value to the first element of the array.
55+
# In this, one of the things to watch out for is that the last block might
56+
# not necessarily have all threads doing work, so we have to be careful to
57+
# not include data in `reduction` that did not have corresponding work.
58+
# Threads of index 1 will write that array into an output array.
59+
#
60+
# The output array has size nblocks, so we do another round of reduction,
61+
# but this time we put each Field in a different block.
62+
3463
S = eltype(data)
3564
pdata = parent(data)
3665
T = eltype(pdata)
@@ -112,7 +141,13 @@ function mapreduce_cuda_kernel!(
112141
end
113142
end
114143
sync_threads()
115-
_cuda_intrablock_reduce!(op, reduction, tidx, blksize)
144+
145+
# The last block might not have enough threads to fill `reduction`, so some
146+
# of its elements might still have the value at initialization.
147+
blksize_for_reduction =
148+
min(blksize, nitems - effective_blksize * (bidx - 1))
149+
150+
_cuda_intrablock_reduce!(op, reduction, tidx, blksize_for_reduction)
116151

117152
tidx == 1 && (reduce_cuda[bidx, fidx] = reduction[1])
118153
return nothing

test/DataLayouts/unit_mapreduce.jl

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,3 +162,35 @@ end
162162
# data = DataLayouts.IJKFVH{S}(ArrayType{FT}, zeros; Nij,Nk,Nv,Nh); test_mapreduce_2!(context, data_view(data)) # TODO: test
163163
# data = DataLayouts.IH1JH2{S}(ArrayType{FT}, zeros; Nij); test_mapreduce_2!(context, data_view(data)) # TODO: test
164164
end
165+
166+
@testset "mapreduce with space with some non-round blocks" begin
167+
# https://github.com/CliMA/ClimaCore.jl/issues/2097
168+
space = ClimaCore.CommonSpaces.RectangleXYSpace(;
169+
x_min = 0,
170+
x_max = 1,
171+
y_min = 0,
172+
y_max = 1,
173+
periodic_x = false,
174+
periodic_y = false,
175+
n_quad_points = 4,
176+
x_elem = 129,
177+
y_elem = 129,
178+
)
179+
@test minimum(ones(space)) == 1
180+
181+
if ClimaComms.context isa ClimaComms.SingletonCommsContext
182+
# Less than 256 threads
183+
space = ClimaCore.CommonSpaces.RectangleXYSpace(;
184+
x_min = 0,
185+
x_max = 1,
186+
y_min = 0,
187+
y_max = 1,
188+
periodic_x = false,
189+
periodic_y = false,
190+
n_quad_points = 2,
191+
x_elem = 2,
192+
y_elem = 2,
193+
)
194+
@test minimum(ones(space)) == 1
195+
end
196+
end

0 commit comments

Comments
 (0)