Skip to content

Commit 6539b89

Browse files
authored
Merge pull request #2106 from CliMA/gb/fix_cuda_reductions
Fix correctness in cuda_mapreduce
2 parents b4a04d8 + 8cdf3f3 commit 6539b89

File tree

3 files changed

+70
-2
lines changed

3 files changed

+70
-2
lines changed

NEWS.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,8 @@ main
1515

1616
### ![][badge-🐛bugfix] Bug fixes
1717

18-
- Fixed writing/reading purely vertical spaces
18+
- Fixed writing/reading purely vertical spaces. PR [2102](https://github.com/CliMA/ClimaCore.jl/pull/2102)
19+
- Fixed correctness bug in reductions on GPUs. PR [2106](https://github.com/CliMA/ClimaCore.jl/pull/2106)
1920

2021
v0.14.20
2122
--------

ext/cuda/data_layouts_mapreduce.jl

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,35 @@ function mapreduce_cuda(
3131
weighted_jacobian = OnesArray(parent(data)),
3232
opargs...,
3333
)
34+
# This function implements the following parallel reduction algorithm:
35+
#
36+
# Each thread in each blocks processes multiple data points at the same time
37+
# (n_ops_on_load) each and we perform a block-wise reduction, with each
38+
# block writing to an array of (block-)shared memory. This array has the
39+
# same size as the block, ie, it is as long as many threads are available.
40+
# Processing multiple points means that we apply the reduction to the point
41+
# with index reduction[thread_index] = f(thread_index, thread_index +
42+
# OFFSET), with various OFFSETS that depend on `n_ops_on_load` and block
43+
# size.
44+
#
45+
# For the purpose of indexing, this is equivalent to having larger blocks
46+
# with size effective_blksize = blksize * (n_ops_on_load + 1).
47+
#
48+
#
49+
# After this operation, we have reduced all the data by a factor of
50+
# 1/n_ops_on_load and have results in various arrays `reduction` (one per
51+
# block)
52+
#
53+
# Once we have all the blocks reduced, we perform a tree reduction within
54+
# the block and "move" the reduced value to the first element of the array.
55+
# In this, one of the things to watch out for is that the last block might
56+
# not necessarily have all threads doing work, so we have to be careful to
57+
# not include data in `reduction` that did not have corresponding work.
58+
# Threads of index 1 will write that array into an output array.
59+
#
60+
# The output array has size nblocks, so we do another round of reduction,
61+
# but this time we put each Field in a different block.
62+
3463
S = eltype(data)
3564
pdata = parent(data)
3665
T = eltype(pdata)
@@ -112,7 +141,13 @@ function mapreduce_cuda_kernel!(
112141
end
113142
end
114143
sync_threads()
115-
_cuda_intrablock_reduce!(op, reduction, tidx, blksize)
144+
145+
# The last block might not have enough threads to fill `reduction`, so some
146+
# of its elements might still have the value at initialization.
147+
blksize_for_reduction =
148+
min(blksize, nitems - effective_blksize * (bidx - 1))
149+
150+
_cuda_intrablock_reduce!(op, reduction, tidx, blksize_for_reduction)
116151

117152
tidx == 1 && (reduce_cuda[bidx, fidx] = reduction[1])
118153
return nothing

test/DataLayouts/unit_mapreduce.jl

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,3 +162,35 @@ end
162162
# data = DataLayouts.IJKFVH{S}(ArrayType{FT}, zeros; Nij,Nk,Nv,Nh); test_mapreduce_2!(context, data_view(data)) # TODO: test
163163
# data = DataLayouts.IH1JH2{S}(ArrayType{FT}, zeros; Nij); test_mapreduce_2!(context, data_view(data)) # TODO: test
164164
end
165+
166+
@testset "mapreduce with space with some non-round blocks" begin
167+
# https://github.com/CliMA/ClimaCore.jl/issues/2097
168+
space = ClimaCore.CommonSpaces.RectangleXYSpace(;
169+
x_min = 0,
170+
x_max = 1,
171+
y_min = 0,
172+
y_max = 1,
173+
periodic_x = false,
174+
periodic_y = false,
175+
n_quad_points = 4,
176+
x_elem = 129,
177+
y_elem = 129,
178+
)
179+
@test minimum(ones(space)) == 1
180+
181+
if ClimaComms.context isa ClimaComms.SingletonCommsContext
182+
# Less than 256 threads
183+
space = ClimaCore.CommonSpaces.RectangleXYSpace(;
184+
x_min = 0,
185+
x_max = 1,
186+
y_min = 0,
187+
y_max = 1,
188+
periodic_x = false,
189+
periodic_y = false,
190+
n_quad_points = 2,
191+
x_elem = 2,
192+
y_elem = 2,
193+
)
194+
@test minimum(ones(space)) == 1
195+
end
196+
end

0 commit comments

Comments
 (0)