Skip to content

Commit 6838f81

Browse files
authored
Merge pull request #2169 from CliMA/gb/distributed
Fix distributed remapping bug
2 parents 3006168 + 5d07152 commit 6838f81

File tree

4 files changed

+40
-141
lines changed

4 files changed

+40
-141
lines changed

.github/workflows/JuliaFormatter.yml

Lines changed: 5 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -7,34 +7,9 @@ on:
77

88
jobs:
99
format:
10-
runs-on: ubuntu-24.04
11-
timeout-minutes: 30
10+
runs-on: ubuntu-latest
1211
steps:
13-
- name: Cancel Previous Runs
14-
uses: styfle/[email protected]
15-
with:
16-
access_token: ${{ github.token }}
17-
18-
- uses: actions/checkout@v4
19-
20-
- uses: dorny/[email protected]
21-
id: filter
22-
with:
23-
filters: |
24-
julia_file_change:
25-
- added|modified: '**.jl'
26-
27-
- uses: julia-actions/setup-julia@v2
28-
if: steps.filter.outputs.julia_file_change == 'true'
29-
with:
30-
version: '1.10'
31-
32-
- name: Apply JuliaFormatter
33-
if: steps.filter.outputs.julia_file_change == 'true'
34-
run: |
35-
julia --color=yes --project=.dev .dev/climaformat.jl --verbose .
36-
37-
- name: Check formatting diff
38-
if: steps.filter.outputs.julia_file_change == 'true'
39-
run: |
40-
git diff --color=always --exit-code
12+
- uses: julia-actions/julia-format@v3
13+
with:
14+
version: '1' # Set `version` to '1.0.54' if you need to use JuliaFormatter.jl v1.0.54 (default: '1')
15+
suggestion-label: 'format-suggest' # leave this unset or empty to show suggestions for all PRs

NEWS.md

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,11 @@ main
55
-------
66

77
- Prior to this version, `CommonSpaces` could not be created with
8-
`ClimaComms.MPICommContext`. This is now fixed with PR
8+
`ClimaComms.MPICommsContext`. This is now fixed with PR
99
[2176](https://github.com/CliMA/ClimaCore.jl/pull/2176).
10-
10+
- Fixed bug in distributed remapping with CUDA. Sometimes, `ClimaCore` would not
11+
properly fill the output arrays with the correct values. This is now fixed. PR
12+
[2169](https://github.com/CliMA/ClimaCore.jl/pull/2169)
1113

1214
v0.14.24
1315
-------

src/Remapping/distributed_remapping.jl

Lines changed: 29 additions & 96 deletions
Original file line numberDiff line numberDiff line change
@@ -744,29 +744,6 @@ function _reset_interpolated_values!(remapper::Remapper)
744744
fill!(remapper._interpolated_values, 0)
745745
end
746746

747-
"""
748-
_collect_and_return_interpolated_values!(remapper::Remapper,
749-
num_fields::Int)
750-
751-
Perform an MPI call to aggregate the interpolated points from all the MPI processes and save
752-
the result in the local state of the `remapper`. Only the root process will return the
753-
interpolated data.
754-
755-
`_collect_and_return_interpolated_values!` is type-unstable and allocates new return arrays.
756-
757-
`num_fields` is the number of fields that have been interpolated in this batch.
758-
"""
759-
function _collect_and_return_interpolated_values!(
760-
remapper::Remapper,
761-
num_fields::Int,
762-
)
763-
return ClimaComms.reduce(
764-
remapper.comms_ctx,
765-
remapper._interpolated_values[remapper.colons..., 1:num_fields],
766-
+,
767-
)
768-
end
769-
770747
function _collect_interpolated_values!(
771748
dest,
772749
remapper::Remapper,
@@ -777,38 +754,26 @@ function _collect_interpolated_values!(
777754
if only_one_field
778755
ClimaComms.reduce!(
779756
remapper.comms_ctx,
780-
remapper._interpolated_values[remapper.colons..., begin],
757+
view(remapper._interpolated_values, remapper.colons..., 1),
781758
dest,
782759
+,
783760
)
784-
return nothing
761+
else
762+
num_fields = 1 + index_field_end - index_field_begin
763+
ClimaComms.reduce!(
764+
remapper.comms_ctx,
765+
view(
766+
remapper._interpolated_values,
767+
remapper.colons...,
768+
1:num_fields,
769+
),
770+
view(dest, remapper.colons..., index_field_begin:index_field_end),
771+
+,
772+
)
785773
end
786-
787-
num_fields = 1 + index_field_end - index_field_begin
788-
789-
ClimaComms.reduce!(
790-
remapper.comms_ctx,
791-
view(remapper._interpolated_values, remapper.colons..., 1:num_fields),
792-
view(dest, remapper.colons..., index_field_begin:index_field_end),
793-
+,
794-
)
795-
796774
return nothing
797775
end
798776

799-
"""
800-
batched_ranges(num_fields, buffer_length)
801-
802-
Partition the indices from 1 to num_fields in such a way that no range is larger than
803-
buffer_length.
804-
"""
805-
function batched_ranges(num_fields, buffer_length)
806-
return [
807-
(i * buffer_length + 1):(min((i + 1) * buffer_length, num_fields)) for
808-
i in 0:(div((num_fields - 1), buffer_length))
809-
]
810-
end
811-
812777
"""
813778
interpolate(remapper::Remapper, fields)
814779
interpolate!(dest, remapper::Remapper, fields)
@@ -860,58 +825,21 @@ int12 = interpolate(remapper, [field1, field2])
860825
```
861826
"""
862827
function interpolate(remapper::Remapper, fields)
863-
828+
ArrayType = ClimaComms.array_type(remapper.space)
829+
FT = Spaces.undertype(remapper.space)
864830
only_one_field = fields isa Fields.Field
865-
if only_one_field
866-
fields = [fields]
867-
end
868831

869-
for field in fields
870-
axes(field) == remapper.space ||
871-
error("Field is defined on a different space than remapper")
872-
end
832+
interpolated_values_dim..., _buffer_length =
833+
size(remapper._interpolated_values)
873834

874-
isa_vertical_space = remapper.space isa Spaces.FiniteDifferenceSpace
875-
876-
index_field_begin, index_field_end =
877-
1, min(length(fields), remapper.buffer_length)
878-
879-
# Partition the indices in such a way that nothing is larger than
880-
# buffer_length
881-
index_ranges = batched_ranges(length(fields), remapper.buffer_length)
835+
allocate_extra = only_one_field ? () : (length(fields),)
836+
dest = ArrayType(zeros(FT, interpolated_values_dim..., allocate_extra...))
882837

883-
cat_fn = (l...) -> cat(l..., dims = length(remapper.colons) + 1)
884-
885-
interpolated_values = mapreduce(cat_fn, index_ranges) do range
886-
num_fields = length(range)
887-
888-
# Reset interpolated_values. This is needed because we collect distributed results
889-
# with a + reduction.
890-
_reset_interpolated_values!(remapper)
891-
# Perform the interpolations (horizontal and vertical)
892-
_set_interpolated_values!(
893-
remapper,
894-
view(fields, index_field_begin:index_field_end),
895-
)
896-
897-
if !isa_vertical_space
898-
# For spaces with an horizontal component, reshape the output so that it is a nice grid.
899-
_apply_mpi_bitmask!(remapper, num_fields)
900-
else
901-
# For purely vertical spaces, just move to _interpolated_values
902-
remapper._interpolated_values .= remapper._local_interpolated_values
903-
end
904-
905-
# Finally, we have to send all the _interpolated_values to root and sum them up to
906-
# obtain the final answer. Only the root will contain something useful.
907-
return _collect_and_return_interpolated_values!(remapper, num_fields)
908-
end
909-
910-
# Non-root processes
911-
isnothing(interpolated_values) && return nothing
912-
913-
return only_one_field ? interpolated_values[remapper.colons..., begin] :
914-
interpolated_values
838+
# interpolate! has an MPI call, so it is important to return after it is
839+
# called, not before!
840+
interpolate!(dest, remapper, fields)
841+
ClimaComms.iamroot(remapper.comms_ctx) || return nothing
842+
return dest
915843
end
916844

917845
# dest has to be allowed to be nothing because interpolation happens only on the root
@@ -927,6 +855,11 @@ function interpolate!(
927855
end
928856
isa_vertical_space = remapper.space isa Spaces.FiniteDifferenceSpace
929857

858+
for field in fields
859+
axes(field) == remapper.space ||
860+
error("Field is defined on a different space than remapper")
861+
end
862+
930863
if !isnothing(dest)
931864
# !isnothing(dest) means that this is the root process, in this case, the size have
932865
# to match (ignoring the buffer_length)

test/Remapping/distributed_remapping.jl

Lines changed: 2 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -31,14 +31,6 @@ atexit() do
3131
global_logger(prev_logger)
3232
end
3333

34-
@testset "Utils" begin
35-
# batched_ranges(num_fields, buffer_length)
36-
@test Remapping.batched_ranges(1, 1) == [1:1]
37-
@test Remapping.batched_ranges(1, 2) == [1:1]
38-
@test Remapping.batched_ranges(2, 2) == [1:2]
39-
@test Remapping.batched_ranges(3, 2) == [1:2, 3:3]
40-
end
41-
4234
with_mpi = context isa ClimaComms.MPICommsContext
4335

4436
@testset "2D extruded" begin
@@ -161,10 +153,7 @@ end
161153

162154
quad = Quadratures.GLL{4}()
163155
horzmesh = Meshes.RectilinearMesh(horzdomain, 10, 10)
164-
horztopology = Topologies.Topology2D(
165-
ClimaComms.SingletonCommsContext(device),
166-
horzmesh,
167-
)
156+
horztopology = Topologies.Topology2D(context, horzmesh)
168157
horzspace = Spaces.SpectralElementSpace2D(horztopology, quad)
169158

170159
hv_center_space =
@@ -330,7 +319,7 @@ end
330319
quad = Quadratures.GLL{4}()
331320
horzmesh = Meshes.RectilinearMesh(horzdomain, 10, 10)
332321
horztopology = Topologies.Topology2D(
333-
ClimaComms.SingletonCommsContext(device),
322+
context,
334323
horzmesh,
335324
Topologies.spacefillingcurve(horzmesh),
336325
)

0 commit comments

Comments
 (0)