Skip to content

Commit 888d3db

Browse files
authored
Merge pull request #1397 from CliMA/js/checkpoints
delete previous checkpoint when saving a new one
2 parents 81e4ffe + 45d4acf commit 888d3db

File tree

11 files changed

+98
-14
lines changed

11 files changed

+98
-14
lines changed

NEWS.md

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,21 @@ ClimaCoupler.jl Release Notes
66

77
### ClimaCoupler features
88

9-
#### Remove `dt_save_state_to_disk` and `dt_save_to_sol` options
9+
#### Remove intermediate checkpoints PR[#1397](https://github.com/CliMA/ClimaCoupler.jl/pull/1397)
10+
11+
Throughout the simulation, the previous checkpoint is now deleted whenever a new
12+
one is saved. The most recent checkpoint will always be available, so restarting
13+
is still supported. The field `prev_checkpoint_t` in the CoupledSimulation object
14+
is used to remove intermediate checkpoints.
15+
16+
#### Remove `dt_save_state_to_disk` and `dt_save_to_sol` options PR[#1394](https://github.com/CliMA/ClimaCoupler.jl/pull/1394)
1017

1118
`dt_save_state_to_disk` was unused and is removed from all configs in this
1219
commit. Note that ClimaAtmos does have an option with this name, but we
1320
pass `checkpoint_dt` to it. `dt_save_to_sol` is also removed as an option,
1421
in favor of using our more robust checkpointing infrastructure via `checkpoint_dt`.
1522

16-
#### Misc. interface cleanup
23+
#### Misc. interface cleanup PR[#1341](https://github.com/CliMA/ClimaCoupler.jl/pull/1341)
1724

1825
Including:
1926
- Remove `ρ_sfc` from surface model caches

experiments/ClimaEarth/setup_run.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -455,13 +455,15 @@ function CoupledSimulation(config_dict::AbstractDict)
455455
specifics, the callbacks, the directory paths, and diagnostics for AMIP simulations.
456456
=#
457457

458+
prev_checkpoint_t = Ref(-1) # no checkpoint taken yet
458459
cs = CoupledSimulation{FT}(
459460
Ref(start_date),
460461
coupler_fields,
461462
conservation_checks,
462463
[tspan[1], tspan[2]],
463464
Δt_cpl,
464465
Ref(tspan[1]),
466+
prev_checkpoint_t,
465467
model_sims,
466468
callbacks,
467469
dir_paths,

experiments/ClimaEarth/test/component_model_tests/climaland_tests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,7 @@ end
8383
tspan,
8484
dt,
8585
tspan[1],
86+
Ref(-1), # prev_checkpoint_t
8687
model_sims,
8788
(;), # callbacks
8889
(;), # dirs

experiments/ClimaEarth/test/debug_plots_tests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@ plot_field_names(sim::Interfacer.SurfaceStub) = (:stub_field,)
7676
(Int(0), Int(1)), # tspan
7777
Int(200), # Δt_cpl
7878
Ref(Int(0)), # t
79+
Ref(-1), # prev_checkpoint_t
7980
model_sims, # model_sims
8081
(;), # callbacks
8182
(;), # dirs

experiments/ClimaEarth/test/restart.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ four_steps["dt"] = "180secs"
4646
four_steps["dt_cpl"] = "180secs"
4747
four_steps["t_end"] = "720secs"
4848
four_steps["dt_rad"] = "180secs"
49-
four_steps["checkpoint_dt"] = "720secs"
49+
four_steps["checkpoint_dt"] = "360secs"
5050
four_steps["coupler_output_dir"] = tmpdir
5151
four_steps["job_id"] = "four_steps"
5252

src/Checkpointer.jl

Lines changed: 59 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -33,14 +33,24 @@ This is a template function that should be implemented for each component model.
3333
get_model_cache(sim::Interfacer.ComponentModelSimulation) = nothing
3434

3535
"""
36-
checkpoint_model_state(sim::Interfacer.ComponentModelSimulation, comms_ctx::ClimaComms.AbstractCommsContext, t::Int; output_dir = "output")
36+
checkpoint_model_state(
37+
sim::Interfacer.ComponentModelSimulation,
38+
comms_ctx::ClimaComms.AbstractCommsContext,
39+
t::Int,
40+
prev_checkpoint_t::Int;
41+
output_dir = "output")
3742
3843
Checkpoint the model state of a simulation to a HDF5 file at a given time, t (in seconds).
44+
45+
If a previous checkpoint exists, it is removed. This is to avoid accumulating
46+
many checkpoint files in the output directory. A value of -1 for `prev_checkpoint_t`
47+
is used to indicate that there is no previous checkpoint to remove.
3948
"""
4049
function checkpoint_model_state(
4150
sim::Interfacer.ComponentModelSimulation,
4251
comms_ctx::ClimaComms.AbstractCommsContext,
43-
t::Int;
52+
t::Int,
53+
prev_checkpoint_t::Int;
4454
output_dir = "output",
4555
)
4656
Y = get_model_prog_state(sim)
@@ -52,23 +62,36 @@ function checkpoint_model_state(
5262
CC.InputOutput.HDF5.write_attribute(checkpoint_writer.file, "time", t)
5363
CC.InputOutput.write!(checkpoint_writer, Y, "model_state")
5464
Base.close(checkpoint_writer)
55-
return nothing
5665

66+
# Remove previous checkpoint if it exists
67+
prev_checkpoint_file = joinpath(output_dir, "checkpoint_$(nameof(sim))_$(prev_checkpoint_t).hdf5")
68+
remove_checkpoint(prev_checkpoint_file, prev_checkpoint_t, comms_ctx)
69+
return nothing
5770
end
5871

5972
"""
60-
checkpoint_model_cache(sim::Interfacer.ComponentModelSimulation, comms_ctx::ClimaComms.AbstractCommsContext, t::Int; output_dir = "output")
73+
checkpoint_model_cache(
74+
sim::Interfacer.ComponentModelSimulation,
75+
comms_ctx::ClimaComms.AbstractCommsContext,
76+
t::Int,
77+
prev_checkpoint_t::Int;
78+
output_dir = "output")
6179
6280
Checkpoint the model cache to N JLD2 files at a given time, t (in seconds),
6381
where N is the number of MPI ranks.
6482
6583
Objects are saved to JLD2 files because caches are generally not ClimaCore
6684
objects (and ClimaCore.InputOutput can only save `Field`s or `FieldVector`s).
85+
86+
If a previous checkpoint exists, it is removed. This is to avoid accumulating
87+
many checkpoint files in the output directory. A value of -1 for `prev_checkpoint_t`
88+
is used to indicate that there is no previous checkpoint to remove.
6789
"""
6890
function checkpoint_model_cache(
6991
sim::Interfacer.ComponentModelSimulation,
7092
comms_ctx::ClimaComms.AbstractCommsContext,
71-
t::Int;
93+
t::Int,
94+
prev_checkpoint_t::Int;
7295
output_dir = "output",
7396
)
7497
# Move p to CPU (because we cannot save CUArrays)
@@ -79,6 +102,10 @@ function checkpoint_model_cache(
79102
pid = ClimaComms.mypid(comms_ctx)
80103
output_file = joinpath(output_dir, "checkpoint_cache_$(pid)_$(nameof(sim))_$t.jld2")
81104
JLD2.jldsave(output_file, cache = p)
105+
106+
# Remove previous checkpoint if it exists
107+
prev_checkpoint_file = joinpath(output_dir, "checkpoint_cache_$(pid)_$(nameof(sim))_$(prev_checkpoint_t).jld2")
108+
remove_checkpoint(prev_checkpoint_file, prev_checkpoint_t, comms_ctx)
82109
return nothing
83110
end
84111

@@ -104,21 +131,31 @@ function checkpoint_sims(cs::Interfacer.CoupledSimulation)
104131
day = floor(Int, time / (60 * 60 * 24))
105132
sec = floor(Int, time % (60 * 60 * 24))
106133
output_dir = cs.dirs.checkpoints
134+
prev_checkpoint_t = cs.prev_checkpoint_t[]
107135
comms_ctx = ClimaComms.context(cs)
136+
108137
for sim in cs.model_sims
109138
if !isnothing(Checkpointer.get_model_prog_state(sim))
110-
Checkpointer.checkpoint_model_state(sim, comms_ctx, time; output_dir)
139+
Checkpointer.checkpoint_model_state(sim, comms_ctx, time, prev_checkpoint_t; output_dir)
111140
end
112141
if !isnothing(Checkpointer.get_model_cache(sim))
113-
Checkpointer.checkpoint_model_cache(sim, comms_ctx, time; output_dir)
142+
Checkpointer.checkpoint_model_cache(sim, comms_ctx, time, prev_checkpoint_t; output_dir)
114143
end
115144
end
145+
116146
# Checkpoint the Coupler fields
117147
pid = ClimaComms.mypid(comms_ctx)
118148
@info "Saving coupler fields to JLD2 on day $day second $sec"
119149
output_file = joinpath(output_dir, "checkpoint_coupler_fields_$(pid)_$time.jld2")
120150
# Adapt to Array move fields to the CPU
121151
JLD2.jldsave(output_file, coupler_fields = CC.Adapt.adapt(Array, cs.fields))
152+
153+
# Remove previous Coupler fields checkpoint if it exists
154+
prev_checkpoint_file = joinpath(output_dir, "checkpoint_coupler_fields_$(pid)_$(prev_checkpoint_t).jld2")
155+
remove_checkpoint(prev_checkpoint_file, prev_checkpoint_t, comms_ctx)
156+
157+
# Update previous checkpoint time stored in the coupled simulation
158+
cs.prev_checkpoint_t[] = time
122159
end
123160

124161
"""
@@ -211,4 +248,19 @@ function t_start_from_checkpoint(checkpoint_dir)
211248
return parse(Int, match(restart_file_rx, latest_restart)[2])
212249
end
213250

251+
"""
252+
remove_checkpoint(prev_checkpoint_file, prev_checkpoint_t, comms_ctx)
253+
254+
Delete the provided checkpoint file on the root process and print a helpful
255+
info message. This can be used to remove intermediate checkpoints, to prevent
256+
saving excessively large amounts of output.
257+
"""
258+
function remove_checkpoint(prev_checkpoint_file, prev_checkpoint_t, comms_ctx)
259+
if ClimaComms.iamroot(comms_ctx) && prev_checkpoint_t != -1 && isfile(prev_checkpoint_file)
260+
@info "Removing previous checkpoint file: $prev_checkpoint_file"
261+
rm(prev_checkpoint_file)
262+
end
263+
return nothing
264+
end
265+
214266
end # module

src/Interfacer.jl

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,13 +46,28 @@ abstract type AbstractSimulation{FT} end
4646
CoupledSimulation
4747
Stores information needed to run a simulation with the coupler.
4848
"""
49-
struct CoupledSimulation{FT <: Real, D, FV, E, TS, DTI, TT, NTMS <: NamedTuple, CALLBACKS, NTP <: NamedTuple, TP, DH}
49+
struct CoupledSimulation{
50+
FT <: Real,
51+
D,
52+
FV,
53+
E,
54+
TS,
55+
DTI,
56+
TT,
57+
CTT,
58+
NTMS <: NamedTuple,
59+
CALLBACKS,
60+
NTP <: NamedTuple,
61+
TP,
62+
DH,
63+
}
5064
start_date::D
5165
fields::FV
5266
conservation_checks::E
5367
tspan::TS
5468
Δt_cpl::DTI
5569
t::TT
70+
prev_checkpoint_t::CTT
5671
model_sims::NTMS
5772
callbacks::CALLBACKS
5873
dirs::NTP
@@ -63,7 +78,7 @@ end
6378
CoupledSimulation{FT}(args...) where {FT} = CoupledSimulation{FT, typeof.(args)...}(args...)
6479

6580
function Base.show(io::IO, sim::CoupledSimulation)
66-
device_type = nameof(typeof(ClimaComms.device(sim.comms_ctx)))
81+
device_type = nameof(typeof(ClimaComms.device(sim)))
6782
return print(
6883
io,
6984
"Coupled Simulation\n",

test/checkpointer_tests.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,10 @@ end
2424
comms_ctx = ClimaComms.context(ClimaComms.CPUSingleThreaded())
2525
boundary_space = CC.CommonSpaces.CubedSphereSpace(FT; comms_ctx, radius = FT(6371e3), n_quad_points = 4, h_elem = 4)
2626
t = 1
27+
prev_checkpoint_t = -1
2728
# old sim run
2829
sim = DummySimulation(CC.Fields.FieldVector(T = ones(boundary_space)))
29-
Checkpointer.checkpoint_model_state(sim, comms_ctx, t, output_dir = "test_checkpoint")
30+
Checkpointer.checkpoint_model_state(sim, comms_ctx, t, prev_checkpoint_t, output_dir = "test_checkpoint")
3031

3132
# new sim run
3233
sim_new = DummySimulation(CC.Fields.FieldVector(T = zeros(boundary_space)))

test/conservation_checker_tests.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ for FT in (Float32, Float64)
6464
(Int(0), Int(1000)), # tspan
6565
Int(200), # Δt_cpl
6666
Ref(Int(0)), # t
67+
Ref(-1), # prev_checkpoint_t
6768
model_sims, # model_sims
6869
(;), # callbacks
6970
(;), # dirs
@@ -152,6 +153,7 @@ for FT in (Float32, Float64)
152153
(Int(0), Int(1000)), # tspan
153154
Int(200), # Δt_cpl
154155
Ref(Int(0)), # t
156+
Ref(-1), # prev_checkpoint_t
155157
model_sims, # model_sims
156158
(;), # callbacks
157159
(;), # dirs

test/field_exchanger_tests.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,7 @@ for FT in (Float32, Float64)
141141
(Int(0), Int(1000)), # tspan
142142
Int(200), # Δt_cpl
143143
Ref(Int(0)), # t
144+
Ref(-1), # prev_checkpoint_t
144145
(;
145146
ice_sim = DummyStub((; area_fraction = ice_d)),
146147
ocean_sim = Interfacer.SurfaceStub((; area_fraction = ocean_d)),
@@ -382,7 +383,8 @@ for FT in (Float32, Float64)
382383
nothing, # conservation_checks
383384
nothing, # tspan
384385
nothing, # dt
385-
nothing, # t_start
386+
nothing, # t
387+
Ref(-1), # prev_checkpoint_t
386388
model_sims,
387389
(;), # callbacks
388390
(;), # dirs

0 commit comments

Comments
 (0)