Skip to content

Commit 92c7bc3

Browse files
committed
Fix for #2358
1 parent 450b39c commit 92c7bc3

File tree

4 files changed

+117
-16
lines changed

4 files changed

+117
-16
lines changed

src/Grids/level.jl

Lines changed: 20 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,19 @@
11
struct LevelGrid{
22
G <: AbstractExtrudedFiniteDifferenceGrid,
33
L <: Union{Int, PlusHalf{Int}},
4+
Q <: Quadratures.QuadratureStyle,
45
} <: AbstractGrid
56
full_grid::G
67
level::L
8+
quadrature_style::Q
79
end
810

9-
quadrature_style(levelgrid::LevelGrid) =
10-
quadrature_style(levelgrid.full_grid.horizontal_grid)
11+
quadrature_style(levelgrid::LevelGrid) = levelgrid.quadrature_style
1112

1213
level(
1314
grid::AbstractExtrudedFiniteDifferenceGrid,
1415
level::Union{Int, PlusHalf{Int}},
15-
) = LevelGrid(grid, level)
16+
) = LevelGrid(grid, level, quadrature_style(grid))
1617

1718
topology(levelgrid::LevelGrid) = topology(levelgrid.full_grid)
1819

@@ -21,22 +22,28 @@ topology(levelgrid::LevelGrid) = topology(levelgrid.full_grid)
2122
# need to extract the weights at a particular level.
2223
dss_weights(levelgrid::LevelGrid, _) = dss_weights(levelgrid.full_grid, nothing)
2324

24-
local_geometry_type(::Type{LevelGrid{G, L}}) where {G, L} =
25+
local_geometry_type(::Type{LevelGrid{G, L, Q}}) where {G, L, Q} =
2526
local_geometry_type(G)
26-
local_geometry_data(levelgrid::LevelGrid{<:Any, Int}, ::Nothing) = level(
27+
28+
local_geometry_data(levelgrid::LevelGrid{<:Any, Int, <:Any}, ::Nothing) = level(
2729
local_geometry_data(levelgrid.full_grid, CellCenter()),
2830
levelgrid.level,
2931
)
30-
local_geometry_data(levelgrid::LevelGrid{<:Any, PlusHalf{Int}}, ::Nothing) =
31-
level(
32-
local_geometry_data(levelgrid.full_grid, CellFace()),
33-
levelgrid.level + half,
34-
)
35-
global_geometry(levlgrid::LevelGrid) = global_geometry(levlgrid.full_grid)
32+
local_geometry_data(
33+
levelgrid::LevelGrid{<:Any, PlusHalf{Int}, <:Any},
34+
::Nothing,
35+
) = level(
36+
local_geometry_data(levelgrid.full_grid, CellFace()),
37+
levelgrid.level + half,
38+
)
39+
global_geometry(levelgrid::LevelGrid) = global_geometry(levelgrid.full_grid)
3640

3741
## GPU compatibility
38-
Adapt.adapt_structure(to, grid::LevelGrid) =
39-
LevelGrid(Adapt.adapt(to, grid.full_grid), grid.level)
42+
Adapt.adapt_structure(to, grid::LevelGrid) = LevelGrid(
43+
Adapt.adapt(to, grid.full_grid),
44+
grid.level,
45+
quadrature_style(grid),
46+
)
4047

4148
## aliases
4249
const LevelCubedSphereSpectralElementGrid2D =

src/InputOutput/readers.jl

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -522,7 +522,24 @@ function read_grid_new(reader, name)
522522
else
523523
level = attrs(group)["level_half"] + half
524524
end
525-
return Grids.LevelGrid(full_grid, level)
525+
526+
# Check if quadrature attributes exist in current group, otherwise use horizontal_grid
527+
if haskey(attrs(group), "quadrature_num_points") &&
528+
haskey(attrs(group), "quadrature_type")
529+
group = reader.file["grids/horizontal_grid"]
530+
npts = attrs(group)["quadrature_num_points"]
531+
quadrature_style =
532+
_scan_quadrature_style(attrs(group)["quadrature_type"], npts)
533+
else
534+
horizontal_group = reader.file["grids/horizontal_grid"]
535+
npts = attrs(horizontal_group)["quadrature_num_points"]
536+
quadrature_style = _scan_quadrature_style(
537+
attrs(horizontal_group)["quadrature_type"],
538+
npts,
539+
)
540+
end
541+
542+
return Grids.LevelGrid(full_grid, level, quadrature_style)
526543
else
527544
error("Unsupported grid type $type")
528545
end

test/Fields/unit_field.jl

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -689,6 +689,31 @@ end
689689
end
690690
end
691691

692+
@testset "Levels of nonlocal Fields and nonlocal Field broadcasts" begin
693+
FT = Float64
694+
gradh = Operators.Gradient()
695+
# Todo: Make this work over all spaces; currently broken for everything else.
696+
for space in (
697+
TU.CenterExtrudedFiniteDifferenceSpace(FT),
698+
TU.FaceExtrudedFiniteDifferenceSpace(FT),
699+
)
700+
TU.levelable(space) || continue
701+
field = fill((; x = FT(1)), space)
702+
703+
op_on_level_of_field = gradh.(Fields.Field(
704+
Spaces.level(Fields.field_values(field.x), 1),
705+
Spaces.level(space, TU.fc_index(1, space)),
706+
))
707+
708+
@test op_on_level_of_field ==
709+
(Spaces.level(gradh.(field.x), TU.fc_index(1, space)))
710+
711+
@test_broken op_on_level_of_field == Base.materialize(
712+
(Spaces.level(lazy.(gradh.(field.x)), TU.fc_index(1, space))),
713+
)
714+
end
715+
end
716+
692717
@testset "Columns of Fields and Field broadcasts" begin
693718
FT = Float64
694719
for space in TU.all_spaces(FT)
@@ -736,7 +761,7 @@ end
736761
is_cuda && space isa OneSlabIndexSpace
737762
@test slab_of_field == Base.materialize(
738763
Spaces.slab(lazy.(identity.(field)), indices...),
739-
) broken = is_cuda
764+
) broken = is_cuda && space isa OneSlabIndexSpace
740765
# TODO: Figure out why some of these tests are broken on GPUs.
741766
end
742767
end

test/InputOutput/unit_hybrid3dcubedsphere_topography.jl

Lines changed: 53 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,11 @@ using ClimaCore:
1313
InputOutput,
1414
Grids
1515

16+
@isdefined(TU) || include(
17+
joinpath(pkgdir(ClimaCore), "test", "TestUtilities", "TestUtilities.jl"),
18+
);
19+
import .TestUtilities as TU;
20+
1621
using ClimaComms
1722
const comms_ctx = ClimaComms.context(ClimaComms.CPUSingleThreaded())
1823
pid, nprocs = ClimaComms.init(comms_ctx)
@@ -54,7 +59,7 @@ end
5459
z_max / 8 .* (
5560
cosd.(Fields.coordinate_field(h_space).lat) .+
5661
cosd.(Fields.coordinate_field(h_space).long) .+ 1
57-
)
62+
),
5863
)
5964

6065
z_mesh = Meshes.IntervalMesh(z_domain, nelems = z_elem)
@@ -90,3 +95,50 @@ end
9095
end
9196
end
9297
end
98+
99+
100+
@testset "HDF5 restart test for a Named Tuple of Levels of a 3D hybrid cubed sphere for deep" begin
101+
# This I/O is used for the computation of the topographic drag
102+
FT = Float32
103+
104+
for space in (
105+
TU.CenterExtrudedFiniteDifferenceSpace(FT, context = comms_ctx),
106+
TU.FaceExtrudedFiniteDifferenceSpace(FT, context = comms_ctx),
107+
)
108+
TU.levelable(space) || continue
109+
field = fill((; x = FT(1)), space)
110+
111+
level_of_field = Fields.Field(
112+
Spaces.level(Fields.field_values(field.x), 1),
113+
Spaces.level(space, TU.fc_index(1, space)),
114+
)
115+
116+
fake_drag = fill(
117+
(;
118+
t11 = FT(0.0),
119+
t12 = FT(0.0),
120+
t21 = FT(0.0),
121+
t22 = FT(0.0),
122+
hmin = FT(0.0),
123+
hmax = FT(0.0),
124+
),
125+
axes(level_of_field),
126+
)
127+
128+
# write field vector to hdf5 file
129+
InputOutput.HDF5Writer(filename, comms_ctx) do writer
130+
InputOutput.write!(writer, fake_drag, "fake_drag")
131+
end
132+
133+
InputOutput.HDF5Reader(filename, comms_ctx) do reader
134+
restart_fake_drag = InputOutput.read_field(reader, "fake_drag") # read fieldvector from hdf5 file
135+
136+
# The underlying space is of a different instance, so we cannot use == to check for equivalence.
137+
# Instead, we make sure that the values and types are the same.
138+
@test typeof(restart_fake_drag) == typeof(fake_drag)
139+
@test maximum(
140+
abs.(parent(fake_drag.t21) .- parent(restart_fake_drag.t21)),
141+
) == 0.0f0
142+
end
143+
end
144+
end

0 commit comments

Comments
 (0)