Skip to content

Commit c7fc613

Browse files
authored
Merge pull request #79 from TUM-PIK-ESM/bg/modify-timestep-interface
Refactor timestep! interface to improve flexibility
2 parents 636e0e7 + 51a69e8 commit c7fc613

File tree

9 files changed

+70
-37
lines changed

9 files changed

+70
-37
lines changed

src/Terrarium.jl

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -100,10 +100,6 @@ export InputSource, InputSources, FieldInputSource, FieldTimeSeriesInputSource
100100
export update_inputs!
101101
include("input_output/input_sources.jl")
102102

103-
# timestepping
104-
export timestep!, default_dt, is_adaptive
105-
include("timesteppers/abstract_timestepper.jl")
106-
107103
# process/model interface
108104
export get_grid, get_initializer, variables, processes, compute_auxiliary!, compute_tendencies!
109105
include("abstract_model.jl")
@@ -124,6 +120,10 @@ include("boundary_conditions.jl")
124120
export Forcings
125121
include("forcings.jl")
126122

123+
# timestepping
124+
export timestep!, default_dt, is_adaptive
125+
include("timesteppers/abstract_timestepper.jl")
126+
127127
# abstract model types
128128
include("models/abstract_types.jl")
129129

@@ -133,14 +133,14 @@ include("processes/processes.jl")
133133
# concrete model implementations
134134
include("models/models.jl")
135135

136-
# timestepper implementations
136+
# model integrator/simulation types and methods
137+
export ModelIntegrator, initialize, current_time, iteration
138+
include("timesteppers/model_integrator.jl")
139+
140+
# Concrete timestepper implementations
137141
export ForwardEuler
138142
include("timesteppers/forward_euler.jl")
139143
export Heun
140144
include("timesteppers/heun.jl")
141145

142-
# model integrator/simulation types and methods
143-
export ModelIntegrator, initialize, current_time, iteration
144-
include("timesteppers/model_integrator.jl")
145-
146146
end # module Terrarium

src/timesteppers/abstract_timestepper.jl

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,12 +32,20 @@ Return `true` if the given time stepper is adaptive, false otherwise.
3232
function is_adaptive end
3333

3434
"""
35-
timestep!(state, timestepper::AbstractTimeStepper, model::AbstractModel, inputs::InputSources, Δt)
35+
timestep!(integrator::ModelIntegrator, timestepper::AbstractTimeStepper, Δt)
3636
37-
Advance prognostic variables by one time step based on the current state, or by `Δt` units of time.
37+
Advance prognostic variables of the `integrator` model by one time step based on the current state, or by `Δt` units of time.
3838
"""
3939
function timestep! end
4040

41+
"""
42+
timestep!(state, model::AbstractModel, timestepper::AbstractTimeStepper, Δt)
43+
44+
Apply any necessary corrections or model-specific time stepping logic after applying `timestepper` to the prognostic state
45+
variables defined by `model`.
46+
"""
47+
timestep!(state, model::AbstractModel, timestepper::AbstractTimeStepper, Δt) = nothing
48+
4149
"""
4250
initialize(::AbstractTimeStepper, model, state) where {NF}
4351

src/timesteppers/forward_euler.jl

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,16 @@ is_adaptive(euler::ForwardEuler) = false
1616

1717
is_initialized(euler::ForwardEuler) = true
1818

19-
function timestep!(state, timestepper::ForwardEuler, model::AbstractModel, inputs::InputSources, Δt)
19+
function timestep!(integrator::ModelIntegrator, timestepper::ForwardEuler, Δt)
20+
# Compute auxiliaries and tendencies
21+
update_state!(integrator, compute_tendencies = true)
2022
# Euler step
21-
explicit_step!(state, get_grid(model), timestepper, Δt)
23+
explicit_step!(integrator.state, get_grid(integrator.model), timestepper, Δt)
24+
# Call timestep! on model
25+
timestep!(integrator.state, integrator.model, timestepper, Δt)
2226
# Apply closure relations
23-
closure!(state, model)
27+
closure!(integrator.state, integrator.model)
2428
# Update clock
25-
return tick!(state.clock, Δt)
29+
tick!(integrator.state.clock, Δt)
30+
return nothing
2631
end

src/timesteppers/heun.jl

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,26 +31,35 @@ function average_tendencies!(
3131
return
3232
end
3333

34-
function timestep!(state, timestepper::Heun, model::AbstractModel, inputs::InputSources, Δt = default_dt(timestepper))
34+
function timestep!(integrator::ModelIntegrator, timestepper::Heun, Δt = default_dt(timestepper))
3535
@assert is_initialized(timestepper)
3636

37+
(; model, state, inputs) = integrator
3738
grid = get_grid(model)
3839

40+
# Update current state
41+
update_state!(state, model, inputs, compute_tendencies = true)
42+
3943
# Copy current state to stage
4044
stage = timestepper.stage
4145
copyto!(stage, state)
4246

4347
# Compute stage
4448
explicit_step!(stage, grid, timestepper, Δt)
49+
# Call timestep! on model
50+
timestep!(stage, model, timestepper, Δt)
4551
# Apply closure relations
4652
closure!(stage, model)
4753
# Update clock
4854
tick!(stage.clock, Δt)
55+
# Recompute tendencies after timestep
4956
update_state!(stage, model, inputs, compute_tendencies = true)
5057

5158
# Final improved Euler step call that steps `state` forward but averages `state.tendencies`
5259
average_tendencies!(state, stage)
5360
explicit_step!(state, grid, timestepper, Δt)
61+
# Call timestep! on model
62+
timestep!(state, model, timestepper, Δt)
5463
# Apply closure relations
5564
closure!(state, model)
5665
# Update clock

src/timesteppers/model_integrator.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -123,8 +123,7 @@ variables.
123123
"""
124124
timestep!(integrator::ModelIntegrator; finalize = true) = timestep!(integrator, default_dt(timestepper(integrator)); finalize)
125125
function timestep!(integrator::ModelIntegrator, Δt; finalize = true)
126-
update_state!(integrator, compute_tendencies = true)
127-
timestep!(integrator.state, integrator.timestepper, integrator.model, integrator.inputs, convert_dt(Δt))
126+
timestep!(integrator, integrator.timestepper, convert_dt(Δt))
128127
if finalize
129128
compute_auxiliary!(integrator.state, integrator.model)
130129
end

test/differentiability/soil_energy_diff.jl

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,10 @@ function build_soil_energy_model(arch, ::Type{NF}) where {NF}
1414
return model
1515
end
1616

17-
function mean_soil_temperature_step!(state, timestepper, model, inputs, Δt)
18-
timestep!(state, timestepper, model, inputs, Δt)
19-
return mean(interior(state.temperature))
17+
function mean_soil_temperature_step!(integrator, timestepper, Δt)
18+
19+
timestep!(integrator, timestepper, Δt)
20+
return mean(interior(integrator.state.temperature))
2021
# TODO: Figure out why this is segfaulting in Enzyme
2122
# Answer: Average operator is not type inferrable, see:
2223
# https://github.com/CliMA/Oceananigans.jl/issues/4869
@@ -67,11 +68,9 @@ end
6768
@testset "Soil energy model: timestep!" begin
6869
model = build_soil_energy_model(CPU(), Float64)
6970
integrator = initialize(model, ForwardEuler())
70-
inputs = integrator.inputs
71-
state = integrator.state
72-
dstate = make_zero(state)
71+
dintegrator = make_zero(integrator)
7372
stepper = integrator.timestepper
7473
dstepper = make_zero(stepper)
75-
@time Enzyme.autodiff(set_runtime_activity(Reverse), mean_soil_temperature_step!, Active, Duplicated(state, dstate), Duplicated(stepper, dstepper), Const(model), Const(integrator.inputs), Const(integrator.timestepper.Δt))
76-
@test all(isfinite.(dstate.temperature))
74+
@time Enzyme.autodiff(set_runtime_activity(Reverse), mean_soil_temperature_step!, Active, Duplicated(integrator, dintegrator), Duplicated(stepper, dstepper), Const(integrator.timestepper.Δt))
75+
@test all(isfinite.(dintegrator.state.temperature))
7776
end

test/differentiability/soil_hydrology_diff.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -136,13 +136,13 @@ end
136136
hydraulic_properties = ConstantSoilHydraulics(Float64)
137137
model = build_soil_energy_hydrology_model(CPU(), Float64; hydraulic_properties)
138138
integrator = initialize(model, ForwardEuler())
139-
inputs = integrator.inputs
140-
state = integrator.state
141-
dstate = make_zero(state)
139+
dintegrator = make_zero(integrator)
140+
# set a seed for the temperature
141+
dintegrator.state.temperature .= 1.0
142142
stepper = integrator.timestepper
143143
dstepper = make_zero(stepper)
144144
Δt = 60.0
145-
@time Enzyme.autodiff(set_runtime_activity(Reverse), timestep!, Const, Duplicated(state, dstate), Duplicated(stepper, dstepper), Const(model), Const(inputs), Const(Δt))
146-
@test all(isfinite.(dstate.temperature))
147-
@test all(isfinite.(dstate.pressure_head))
145+
@time Enzyme.autodiff(set_runtime_activity(Reverse), timestep!, Const, Duplicated(integrator, dintegrator), Duplicated(stepper, dstepper), Const(Δt))
146+
@test all(isfinite.(dintegrator.state.temperature))
147+
@test all(isfinite.(dintegrator.state.pressure_head))
148148
end

test/inputs/input_forcing.jl

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -23,11 +23,6 @@ function Terrarium.compute_tendencies!(state, model::TestModel)
2323
return state.tendencies.x .= state.F
2424
end
2525

26-
function Terrarium.timestep!(state, model::TestModel, euler::ForwardEuler, Δt)
27-
Terrarium.compute_tendencies!(state, model)
28-
return @. state.x += Δt * state.tendencies.x
29-
end
30-
3126
@testset "Forcing input" begin
3227
grid = ColumnGrid(CPU(), DEFAULT_NF, ExponentialSpacing())
3328
model = TestModel(grid)

test/timestepping/heun.jl

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,3 +47,21 @@ end
4747
dt_heun = default_dt(integrator_heun.timestepper)
4848
@test integrator_heun.state.u[2] == (0.1 * dt_heun + (0.1 * dt_heun + 0.1) * dt_heun) / 2
4949
end
50+
51+
# Use timestep!(state, model, timestepper, Δt) to clip negative values in an super simple example sim
52+
@testset "ExpModel: clip negative values" begin
53+
grid = ColumnGrid(CPU(), Float64, UniformSpacing(N = 1))
54+
model = ExpModel(grid)
55+
56+
Terrarium.timestep!(state, model::ExpModel, timestepper::ForwardEuler, Δt) = begin
57+
state.u[2] = max(state.u[2], 0.0)
58+
end
59+
60+
initializers = (u = -20, v = -5.0)
61+
integrator = initialize(model, ForwardEuler(); initializers)
62+
63+
# Test that timestep! clips negative values
64+
timestep!(integrator)
65+
66+
@test integrator.state.u[2] >= 0.0
67+
end

0 commit comments

Comments
 (0)