Skip to content

Commit 3006168

Browse files
charleskawczynskiCharlie Kawczynski
andauthored
Fix adapt for vert topo and Topology2D (#2187)
Co-authored-by: Charlie Kawczynski <[email protected]>
1 parent 47cbff5 commit 3006168

File tree

3 files changed

+87
-20
lines changed

3 files changed

+87
-20
lines changed

src/Topologies/interval.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@ struct IntervalTopology{
1616
boundaries::B
1717
end
1818

19+
Adapt.@adapt_structure IntervalTopology
20+
1921
## gpu
2022
struct DeviceIntervalTopology{B} <: AbstractIntervalTopology
2123
boundaries::B

src/Topologies/topology2d.jl

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,45 @@ mutable struct Topology2D{
124124
ghost_face_neighbor_loc::Vector{Int}
125125
end
126126

127+
function Adapt.adapt_structure(to, topo::Topology2D)
128+
return Topology2D(
129+
Adapt.adapt(to, topo.context),
130+
Adapt.adapt(to, topo.mesh),
131+
Adapt.adapt(to, topo.elemorder),
132+
Adapt.adapt(to, topo.orderindex),
133+
topo.elempid,
134+
topo.local_elem_gidx,
135+
topo.neighbor_pids,
136+
topo.send_elem_lidx,
137+
topo.send_elem_lengths,
138+
topo.recv_elem_gidx,
139+
topo.recv_elem_lengths,
140+
Adapt.adapt(to, topo.interior_faces),
141+
Adapt.adapt(to, topo.ghost_faces),
142+
Adapt.adapt(to, topo.local_vertices),
143+
Adapt.adapt(to, topo.local_vertex_offset),
144+
Adapt.adapt(to, topo.ghost_vertices),
145+
Adapt.adapt(to, topo.ghost_vertex_offset),
146+
Adapt.adapt(to, topo.local_neighbor_elem),
147+
Adapt.adapt(to, topo.local_neighbor_elem_offset),
148+
topo.ghost_neighbor_elem,
149+
topo.ghost_neighbor_elem_offset,
150+
Adapt.adapt(to, topo.boundaries),
151+
topo.internal_elems,
152+
topo.perimeter_elems,
153+
topo.nglobalvertices,
154+
topo.nglobalfaces,
155+
topo.ghost_vertex_gcidx,
156+
topo.ghost_face_gcidx,
157+
topo.comm_vertex_lengths,
158+
topo.comm_face_lengths,
159+
topo.ghost_vertex_neighbor_loc,
160+
topo.ghost_vertex_comm_idx_offset,
161+
Adapt.adapt(to, topo.repr_ghost_vertex),
162+
topo.ghost_face_neighbor_loc,
163+
)
164+
end
165+
127166
ClimaComms.device(topology::Topology2D) = ClimaComms.device(topology.context)
128167
ClimaComms.array_type(topology::Topology2D) =
129168
ClimaComms.array_type(topology.context.device)

test/Fields/unit_field.jl

Lines changed: 46 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -706,7 +706,24 @@ using ClimaCore.CommonSpaces
706706
using ClimaCore.Grids
707707
using Adapt
708708

709-
function test_adapt(cpu_space_in)
709+
function test_adapt_types(space_fn; broken_space_type_match = false)
710+
@static if ClimaComms.device() isa ClimaComms.CUDADevice
711+
cpu_space = space_fn(ClimaComms.CPUSingleThreaded())
712+
gpu_space = space_fn(ClimaComms.CUDADevice())
713+
FT = Spaces.undertype(cpu_space)
714+
f_cpu = Fields.Field(FT, cpu_space)
715+
f_gpu = Fields.Field(FT, gpu_space)
716+
f_cpu_from_gpu =
717+
ClimaCore.to_device(ClimaComms.CPUSingleThreaded(), f_gpu)
718+
@test typeof(Fields.field_values(f_cpu_from_gpu)) ==
719+
typeof(Fields.field_values(f_cpu))
720+
@test typeof(axes(f_cpu_from_gpu)) == typeof(axes(f_cpu)) broken =
721+
broken_space_type_match
722+
end
723+
end
724+
725+
function test_adapt(space_fn)
726+
cpu_space_in = space_fn(ClimaComms.CPUSingleThreaded())
710727
test_adapt_space(cpu_space_in)
711728
cpu_f_in = Fields.Field(Float64, cpu_space_in)
712729
cpu_f_out = Adapt.adapt(Array, cpu_f_in)
@@ -730,7 +747,10 @@ function test_adapt(cpu_space_in)
730747
# cpu -> gpu
731748
gpu_f_out = ClimaCore.to_device(ClimaComms.CUDADevice(), cpu_f_in)
732749
@test parent(Fields.field_values(gpu_f_out)) isa CUDA.CuArray
733-
# gpu -> gpu
750+
@test ClimaComms.device(gpu_f_out) isa ClimaComms.CUDADevice
751+
@test ClimaComms.array_type(gpu_f_out) == CUDA.CuArray
752+
753+
# gpu -> cpu
734754
cpu_f_out =
735755
ClimaCore.to_device(ClimaComms.CPUSingleThreaded(), gpu_f_out)
736756
@test parent(Fields.field_values(cpu_f_out)) isa Array
@@ -772,8 +792,8 @@ function test_adapt_space(cpu_space_in)
772792
end
773793

774794
@testset "Test Adapt" begin
775-
space = ExtrudedCubedSphereSpace(;
776-
device = ClimaComms.CPUSingleThreaded(),
795+
ecs_space_fn(dev) = ExtrudedCubedSphereSpace(;
796+
device = dev,
777797
z_elem = 10,
778798
z_min = 0,
779799
z_max = 1,
@@ -782,27 +802,30 @@ end
782802
n_quad_points = 4,
783803
staggering = Grids.CellCenter(),
784804
)
785-
test_adapt(space)
805+
test_adapt(ecs_space_fn)
806+
test_adapt_types(ecs_space_fn; broken_space_type_match = true)
786807

787-
space = CubedSphereSpace(;
788-
device = ClimaComms.CPUSingleThreaded(),
808+
cs_space_fn(dev) = CubedSphereSpace(;
809+
device = dev,
789810
radius = 10,
790811
n_quad_points = 4,
791812
h_elem = 10,
792813
)
793-
test_adapt(space)
814+
test_adapt(cs_space_fn)
815+
test_adapt_types(cs_space_fn; broken_space_type_match = true)
794816

795-
space = ColumnSpace(;
796-
device = ClimaComms.CPUSingleThreaded(),
817+
column_space_fn(dev) = ColumnSpace(;
818+
device = dev,
797819
z_elem = 10,
798820
z_min = 0,
799821
z_max = 10,
800822
staggering = CellCenter(),
801823
)
802-
test_adapt(space)
824+
test_adapt(column_space_fn)
825+
test_adapt_types(column_space_fn)
803826

804-
space = Box3DSpace(;
805-
device = ClimaComms.CPUSingleThreaded(),
827+
box_space_fn(dev) = Box3DSpace(;
828+
device = dev,
806829
z_elem = 10,
807830
x_min = 0,
808831
x_max = 1,
@@ -817,10 +840,11 @@ end
817840
y_elem = 4,
818841
staggering = CellCenter(),
819842
)
820-
test_adapt(space)
843+
test_adapt(box_space_fn)
844+
test_adapt_types(box_space_fn; broken_space_type_match = true)
821845

822-
space = SliceXZSpace(;
823-
device = ClimaComms.CPUSingleThreaded(),
846+
slice_space_fn(dev) = SliceXZSpace(;
847+
device = dev,
824848
z_elem = 10,
825849
x_min = 0,
826850
x_max = 1,
@@ -831,10 +855,11 @@ end
831855
x_elem = 4,
832856
staggering = CellCenter(),
833857
)
834-
test_adapt(space)
858+
test_adapt(slice_space_fn)
859+
# test_adapt_types(slice_space_fn) # not yet supported on gpus
835860

836-
space = RectangleXYSpace(;
837-
device = ClimaComms.CPUSingleThreaded(),
861+
rect_space_fn(dev) = RectangleXYSpace(;
862+
device = dev,
838863
x_min = 0,
839864
x_max = 1,
840865
y_min = 0,
@@ -845,7 +870,8 @@ end
845870
x_elem = 3,
846871
y_elem = 4,
847872
)
848-
test_adapt(space)
873+
test_adapt(rect_space_fn)
874+
test_adapt_types(rect_space_fn; broken_space_type_match = true)
849875

850876
# FieldVector
851877
cspace = ExtrudedCubedSphereSpace(;

0 commit comments

Comments
 (0)