Skip to content

Commit 3b31fee

Browse files
Adapt the checkpointer to a SeaIceModel (#90)
* try it out * add a testset * add the code credit * Update sea_ice_momentum_equations.jl * Fix auxiliary fields reference in prognostic_fields * Update Oceananigans compat * Update test_checkpointing.jl * pass * add the set to clock --------- Co-authored-by: Navid C. Constantinou <[email protected]>
1 parent 4529d3b commit 3b31fee

File tree

10 files changed

+160
-8
lines changed

10 files changed

+160
-8
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ SeawaterPolynomials = "d496a93d-167e-4197-9f49-d3af4ff8fe40"
1616
Adapt = "3, 4"
1717
KernelAbstractions = "0.9"
1818
JLD2 = "0.6.2"
19-
Oceananigans = "0.99, 0.100, 0.101, 0.102, 0.103, 0.104"
19+
Oceananigans = "0.100, 0.101, 0.102, 0.103, 0.104"
2020
RootSolvers = "0.3, 0.4"
2121
Roots = "2"
2222
SeawaterPolynomials = "0.3.4"

docs/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,4 +9,4 @@ Oceananigans = "9e8cae18-63c1-5223-a75c-80ca9d6e9a09"
99
[compat]
1010
CairoMakie = "0.11.0, 0.12.0"
1111
Documenter = "1"
12-
Oceananigans = "0.95.3 - 0.99"
12+
Oceananigans = "0.100, 0.101, 0.102"

src/Rheologies/Rheologies.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,13 @@ sea ice stresses.
2727
"""
2828
Auxiliaries(rheology, grid::AbstractGrid) = Auxiliaries(NamedTuple(), nothing)
2929

30+
import Oceananigans: prognostic_fields
31+
3032
# Nothing rheology
3133
initialize_rheology!(model, rheology) = nothing
34+
3235
compute_stresses!(dynamics, fields, grid, rheology, Δt) = nothing
36+
prognostic_fields(mom, rheology) = NamedTuple()
3337

3438
# Nothing rheology or viscous rheology
3539
@inline compute_substep_Δtᶠᶜᶜ(i, j, grid, Δt, rheology, substeps, fields) = Δt / substeps

src/Rheologies/elasto_visco_plastic_rheology.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,11 @@ function Auxiliaries(r::ElastoViscoPlasticRheology, grid::AbstractGrid)
141141
return Auxiliaries(fields, kernels)
142142
end
143143

144+
prognostic_fields(mom, ::ElastoViscoPlasticRheology) =
145+
(σ₁₁ = mom.auxiliaries.fields.σ₁₁,
146+
σ₂₂ = mom.auxiliaries.fields.σ₂₂,
147+
σ₁₂ = mom.auxiliaries.fields.σ₁₂)
148+
144149
# Extend the `adapt_structure` function for the ElastoViscoPlasticRheology
145150
Adapt.adapt_structure(to, r::ElastoViscoPlasticRheology) =
146151
ElastoViscoPlasticRheology(Adapt.adapt(to, r.ice_compressive_strength),

src/SeaIceDynamics/SeaIceDynamics.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,8 @@ using ClimaSeaIce.Rheologies: ∂ⱼ_σ₁ⱼ,
2525
compute_substep_Δtᶜᶠᶜ,
2626
sum_of_forcing_u,
2727
sum_of_forcing_v
28-
29-
import Oceananigans: fields
28+
29+
import Oceananigans: fields, prognostic_fields
3030

3131
## A Framework to solve for the ice momentum equation, in the form:
3232
##

src/SeaIceDynamics/sea_ice_momentum_equations.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,3 +80,4 @@ function SeaIceMomentumEquation(grid;
8080
end
8181

8282
fields(mom::SeaIceMomentumEquation) = mom.auxiliaries.fields
83+
prognostic_fields(mom::SeaIceMomentumEquation) = prognostic_fields(mom, mom.rheology)

src/SeaIceThermodynamics/slab_sea_ice_thermodynamics.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
import Oceananigans: fields
1+
import Oceananigans: fields, prognostic_fields
22

33
struct ProportionalEvolution end
44

@@ -32,6 +32,7 @@ function Base.show(io::IO, therm::SSIT)
3232
end
3333

3434
fields(therm::SSIT) = (; Tu = therm.top_surface_temperature)
35+
prognostic_fields(therm::SSIT) = (; Gʰ = therm.thermodynamic_tendency)
3536

3637
"""
3738
SlabSeaIceThermodynamics(grid; kw...)

src/sea_ice_model.jl

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,8 @@ function set!(model::SIM; h=nothing, ℵ=nothing)
152152
return nothing
153153
end
154154

155+
set!(model::SIM, new_clock::Clock) = set!(model.clock, new_clock)
156+
155157
Base.summary(model::SIM) = "SeaIceModel"
156158
prettytime(model::SIM) = prettytime(model.clock.time)
157159
iteration(model::SIM) = model.clock.iteration
@@ -186,8 +188,11 @@ fields(model::SIM) = merge((; h = model.ice_thickness,
186188
fields(model.ice_thermodynamics),
187189
fields(model.dynamics))
188190

191+
prognostic_fields(::Nothing) = NamedTuple()
192+
189193
# TODO: make this correct
190194
prognostic_fields(model::SIM) = merge((; h = model.ice_thickness,
191195
= model.ice_concentration),
192-
model.tracers,
193-
model.velocities)
196+
model.velocities,
197+
prognostic_fields(model.dynamics),
198+
prognostic_fields(model.ice_thermodynamics))

test/runtests.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,4 +25,5 @@ end
2525

2626
include("test_sea_ice_advection.jl")
2727
include("test_time_stepping.jl")
28-
include("test_distributed_sea_ice.jl")
28+
include("test_distributed_sea_ice.jl")
29+
include("test_checkpointing.jl")

test/test_checkpointing.jl

Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,135 @@
1+
using ClimaSeaIce
2+
using ClimaSeaIce.SeaIceDynamics
3+
using ClimaSeaIce.SeaIceThermodynamics
4+
using Test
5+
6+
using Oceananigans.Fields: @allowscalar
7+
using Oceananigans: prognostic_fields
8+
9+
# # Same test as in Oceananigans
10+
function test_model_equality(test_model, true_model)
11+
@allowscalar begin
12+
test_model_fields = prognostic_fields(test_model)
13+
true_model_fields = prognostic_fields(true_model)
14+
field_names = keys(test_model_fields)
15+
16+
for name in field_names
17+
@test all(test_model_fields[name].data .≈ true_model_fields[name].data)
18+
end
19+
end
20+
21+
return nothing
22+
end
23+
24+
# Test copied from Oceananigans.jl, code credit:
25+
# https://github.com/CliMA/Oceananigans.jl/blob/main/test/test_checkpointer.jl#L94-L174
26+
function run_checkpointer_tests(true_model, test_model, Δt)
27+
true_simulation = Simulation(true_model, Δt=Δt, stop_iteration=5)
28+
29+
checkpointer = Checkpointer(true_model, schedule=IterationInterval(5), overwrite_existing=true)
30+
push!(true_simulation.output_writers, checkpointer)
31+
32+
run!(true_simulation) # for 5 iterations
33+
34+
checkpointed_model = deepcopy(true_simulation.model)
35+
36+
true_simulation.stop_iteration = 9
37+
run!(true_simulation) # for 4 more iterations
38+
39+
#####
40+
##### Test `set!(model, checkpoint_file)`
41+
#####
42+
43+
set!(test_model, "checkpoint_iteration5.jld2")
44+
45+
@test test_model.clock.iteration == checkpointed_model.clock.iteration
46+
@test test_model.clock.time == checkpointed_model.clock.time
47+
test_model_equality(test_model, checkpointed_model)
48+
49+
# This only applies to QuasiAdamsBashforthTimeStepper:
50+
@test test_model.clock.last_Δt == checkpointed_model.clock.last_Δt
51+
52+
#####
53+
##### Test pickup from explicit checkpoint path
54+
#####
55+
56+
test_simulation = Simulation(test_model, Δt=Δt, stop_iteration=9)
57+
58+
# Pickup from explicit checkpoint path
59+
run!(test_simulation, pickup="checkpoint_iteration0.jld2")
60+
61+
@info "Testing model equality when running with pickup=checkpoint_iteration0.jld2."
62+
@test test_simulation.model.clock.iteration == true_simulation.model.clock.iteration
63+
@test test_simulation.model.clock.time == true_simulation.model.clock.time
64+
test_model_equality(test_model, true_model)
65+
66+
run!(test_simulation, pickup="checkpoint_iteration5.jld2")
67+
@info "Testing model equality when running with pickup=checkpoint_iteration5.jld2."
68+
69+
@test test_simulation.model.clock.iteration == true_simulation.model.clock.iteration
70+
@test test_simulation.model.clock.time == true_simulation.model.clock.time
71+
test_model_equality(test_model, true_model)
72+
73+
#####
74+
##### Test `run!(sim, pickup=true)
75+
#####
76+
77+
# Pickup using existing checkpointer
78+
test_simulation.output_writers[:checkpointer] =
79+
Checkpointer(test_model, schedule=IterationInterval(5), overwrite_existing=true)
80+
81+
run!(test_simulation, pickup=true)
82+
@info " Testing model equality when running with pickup=true."
83+
84+
@test test_simulation.model.clock.iteration == true_simulation.model.clock.iteration
85+
@test test_simulation.model.clock.time == true_simulation.model.clock.time
86+
test_model_equality(test_model, true_model)
87+
88+
run!(test_simulation, pickup=0)
89+
@info " Testing model equality when running with pickup=0."
90+
91+
@test test_simulation.model.clock.iteration == true_simulation.model.clock.iteration
92+
@test test_simulation.model.clock.time == true_simulation.model.clock.time
93+
test_model_equality(test_model, true_model)
94+
95+
run!(test_simulation, pickup=5)
96+
@info " Testing model equality when running with pickup=5."
97+
98+
@test test_simulation.model.clock.iteration == true_simulation.model.clock.iteration
99+
@test test_simulation.model.clock.time == true_simulation.model.clock.time
100+
test_model_equality(test_model, true_model)
101+
102+
rm("checkpoint_iteration0.jld2", force=true)
103+
rm("checkpoint_iteration5.jld2", force=true)
104+
105+
return nothing
106+
end
107+
108+
function test_sea_ice_checkpointer_output(arch)
109+
# Create and run "true model"
110+
Nx, Ny = 16, 16
111+
Lx, Ly = 100, 100
112+
Δt = 1
113+
114+
grid = RectilinearGrid(arch, size=(Nx, Ny), x=(0, Lx), y=(0, Ly), topology=(Bounded, Bounded, Flat))
115+
for ice_thermodynamics in (nothing, SlabSeaIceThermodynamics(grid))
116+
for dynamics in (nothing, SeaIceMomentumEquation(grid))
117+
118+
true_model = SeaIceModel(grid; dynamics, ice_thermodynamics)
119+
test_model = deepcopy(true_model)
120+
121+
for field in merge(true_model.velocities,
122+
(; h = true_model.ice_concentration,
123+
= true_model.ice_thickness))
124+
125+
set!(field, (x, y) -> rand() * 1e-5)
126+
end
127+
128+
run_checkpointer_tests(true_model, test_model, Δt)
129+
end
130+
end
131+
end
132+
133+
@testset "Checkpointing Tests" begin
134+
test_sea_ice_checkpointer_output(CPU())
135+
end

0 commit comments

Comments
 (0)