@@ -706,7 +706,24 @@ using ClimaCore.CommonSpaces
706
706
using ClimaCore. Grids
707
707
using Adapt
708
708
709
- function test_adapt (cpu_space_in)
709
+ function test_adapt_types (space_fn; broken_space_type_match = false )
710
+ @static if ClimaComms. device () isa ClimaComms. CUDADevice
711
+ cpu_space = space_fn (ClimaComms. CPUSingleThreaded ())
712
+ gpu_space = space_fn (ClimaComms. CUDADevice ())
713
+ FT = Spaces. undertype (cpu_space)
714
+ f_cpu = Fields. Field (FT, cpu_space)
715
+ f_gpu = Fields. Field (FT, gpu_space)
716
+ f_cpu_from_gpu =
717
+ ClimaCore. to_device (ClimaComms. CPUSingleThreaded (), f_gpu)
718
+ @test typeof (Fields. field_values (f_cpu_from_gpu)) ==
719
+ typeof (Fields. field_values (f_cpu))
720
+ @test typeof (axes (f_cpu_from_gpu)) == typeof (axes (f_cpu)) broken =
721
+ broken_space_type_match
722
+ end
723
+ end
724
+
725
+ function test_adapt (space_fn)
726
+ cpu_space_in = space_fn (ClimaComms. CPUSingleThreaded ())
710
727
test_adapt_space (cpu_space_in)
711
728
cpu_f_in = Fields. Field (Float64, cpu_space_in)
712
729
cpu_f_out = Adapt. adapt (Array, cpu_f_in)
@@ -730,7 +747,10 @@ function test_adapt(cpu_space_in)
730
747
# cpu -> gpu
731
748
gpu_f_out = ClimaCore. to_device (ClimaComms. CUDADevice (), cpu_f_in)
732
749
@test parent (Fields. field_values (gpu_f_out)) isa CUDA. CuArray
733
- # gpu -> gpu
750
+ @test ClimaComms. device (gpu_f_out) isa ClimaComms. CUDADevice
751
+ @test ClimaComms. array_type (gpu_f_out) == CUDA. CuArray
752
+
753
+ # gpu -> cpu
734
754
cpu_f_out =
735
755
ClimaCore. to_device (ClimaComms. CPUSingleThreaded (), gpu_f_out)
736
756
@test parent (Fields. field_values (cpu_f_out)) isa Array
@@ -772,8 +792,8 @@ function test_adapt_space(cpu_space_in)
772
792
end
773
793
774
794
@testset " Test Adapt" begin
775
- space = ExtrudedCubedSphereSpace (;
776
- device = ClimaComms . CPUSingleThreaded () ,
795
+ ecs_space_fn (dev) = ExtrudedCubedSphereSpace (;
796
+ device = dev ,
777
797
z_elem = 10 ,
778
798
z_min = 0 ,
779
799
z_max = 1 ,
@@ -782,27 +802,30 @@ end
782
802
n_quad_points = 4 ,
783
803
staggering = Grids. CellCenter (),
784
804
)
785
- test_adapt (space)
805
+ test_adapt (ecs_space_fn)
806
+ test_adapt_types (ecs_space_fn; broken_space_type_match = true )
786
807
787
- space = CubedSphereSpace (;
788
- device = ClimaComms . CPUSingleThreaded () ,
808
+ cs_space_fn (dev) = CubedSphereSpace (;
809
+ device = dev ,
789
810
radius = 10 ,
790
811
n_quad_points = 4 ,
791
812
h_elem = 10 ,
792
813
)
793
- test_adapt (space)
814
+ test_adapt (cs_space_fn)
815
+ test_adapt_types (cs_space_fn; broken_space_type_match = true )
794
816
795
- space = ColumnSpace (;
796
- device = ClimaComms . CPUSingleThreaded () ,
817
+ column_space_fn (dev) = ColumnSpace (;
818
+ device = dev ,
797
819
z_elem = 10 ,
798
820
z_min = 0 ,
799
821
z_max = 10 ,
800
822
staggering = CellCenter (),
801
823
)
802
- test_adapt (space)
824
+ test_adapt (column_space_fn)
825
+ test_adapt_types (column_space_fn)
803
826
804
- space = Box3DSpace (;
805
- device = ClimaComms . CPUSingleThreaded () ,
827
+ box_space_fn (dev) = Box3DSpace (;
828
+ device = dev ,
806
829
z_elem = 10 ,
807
830
x_min = 0 ,
808
831
x_max = 1 ,
@@ -817,10 +840,11 @@ end
817
840
y_elem = 4 ,
818
841
staggering = CellCenter (),
819
842
)
820
- test_adapt (space)
843
+ test_adapt (box_space_fn)
844
+ test_adapt_types (box_space_fn; broken_space_type_match = true )
821
845
822
- space = SliceXZSpace (;
823
- device = ClimaComms . CPUSingleThreaded () ,
846
+ slice_space_fn (dev) = SliceXZSpace (;
847
+ device = dev ,
824
848
z_elem = 10 ,
825
849
x_min = 0 ,
826
850
x_max = 1 ,
@@ -831,10 +855,11 @@ end
831
855
x_elem = 4 ,
832
856
staggering = CellCenter (),
833
857
)
834
- test_adapt (space)
858
+ test_adapt (slice_space_fn)
859
+ # test_adapt_types(slice_space_fn) # not yet supported on gpus
835
860
836
- space = RectangleXYSpace (;
837
- device = ClimaComms . CPUSingleThreaded () ,
861
+ rect_space_fn (dev) = RectangleXYSpace (;
862
+ device = dev ,
838
863
x_min = 0 ,
839
864
x_max = 1 ,
840
865
y_min = 0 ,
845
870
x_elem = 3 ,
846
871
y_elem = 4 ,
847
872
)
848
- test_adapt (space)
873
+ test_adapt (rect_space_fn)
874
+ test_adapt_types (rect_space_fn; broken_space_type_match = true )
849
875
850
876
# FieldVector
851
877
cspace = ExtrudedCubedSphereSpace (;
0 commit comments