Skip to content

Commit 5a3f2a5

Browse files
Witt-Djshipton
andauthored
Dwitt/tr bdf2 (#643)
Co-authored-by: jshipton <j.shipton@exeter.ac.uk>
1 parent a64be62 commit 5a3f2a5

File tree

6 files changed

+771
-32
lines changed

6 files changed

+771
-32
lines changed

examples/compressible_euler/skamarock_klemp_nonhydrostatic.py

Lines changed: 39 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -18,15 +18,15 @@
1818
import itertools
1919
from firedrake import (
2020
as_vector, SpatialCoordinate, PeriodicIntervalMesh, ExtrudedMesh, exp, sin,
21-
Function, pi, COMM_WORLD
21+
Function, pi, COMM_WORLD, sqrt
2222
)
2323
import numpy as np
2424
from gusto import (
25-
Domain, IO, OutputParameters, SemiImplicitQuasiNewton, SSPRK3, DGUpwind,
26-
logger, SUPGOptions, Perturbation, CompressibleParameters,
25+
Domain, IO, OutputParameters, TRBDF2QuasiNewton, SemiImplicitQuasiNewton, SSPRK3,
26+
DGUpwind, logger, SUPGOptions, Perturbation, CompressibleParameters,
2727
CompressibleEulerEquations, HydrostaticCompressibleEulerEquations,
2828
compressible_hydrostatic_balance, RungeKuttaFormulation, CompressibleSolver,
29-
SubcyclingOptions, hydrostatic_parameters
29+
hydrostatic_parameters, SubcyclingOptions,
3030
)
3131

3232
skamarock_klemp_nonhydrostatic_defaults = {
@@ -36,7 +36,8 @@
3636
'tmax': 3000.,
3737
'dumpfreq': 250,
3838
'dirname': 'skamarock_klemp_nonhydrostatic',
39-
'hydrostatic': False
39+
'hydrostatic': False,
40+
'timestepper': 'SIQN'
4041
}
4142

4243

@@ -47,7 +48,8 @@ def skamarock_klemp_nonhydrostatic(
4748
tmax=skamarock_klemp_nonhydrostatic_defaults['tmax'],
4849
dumpfreq=skamarock_klemp_nonhydrostatic_defaults['dumpfreq'],
4950
dirname=skamarock_klemp_nonhydrostatic_defaults['dirname'],
50-
hydrostatic=skamarock_klemp_nonhydrostatic_defaults['hydrostatic']
51+
hydrostatic=skamarock_klemp_nonhydrostatic_defaults['hydrostatic'],
52+
timestepper=skamarock_klemp_nonhydrostatic_defaults['timestepper']
5153
):
5254

5355
# ------------------------------------------------------------------------ #
@@ -98,7 +100,7 @@ def skamarock_klemp_nonhydrostatic(
98100
output = OutputParameters(
99101
dirname=dirname, dumpfreq=dumpfreq, pddumpfreq=dumpfreq,
100102
dump_vtus=False, dump_nc=True,
101-
point_data=[('theta_perturbation', points)],
103+
point_data=[('theta_perturbation', points)]
102104
)
103105
else:
104106
logger.warning(
@@ -115,7 +117,10 @@ def skamarock_klemp_nonhydrostatic(
115117

116118
# Transport schemes
117119
theta_opts = SUPGOptions()
118-
subcycling_options = SubcyclingOptions(subcycle_by_courant=0.25)
120+
if timestepper == 'SIQN':
121+
subcycling_options = SubcyclingOptions(subcycle_by_courant=0.25)
122+
else:
123+
subcycling_options = None
119124
transported_fields = [
120125
SSPRK3(domain, "u", subcycling_options=subcycling_options),
121126
SSPRK3(
@@ -135,6 +140,9 @@ def skamarock_klemp_nonhydrostatic(
135140

136141
# Linear solver
137142
if hydrostatic:
143+
if timestepper == 'TR-BDF2':
144+
raise ValueError('Hydrostatic equations not implmented for TR-BDF2')
145+
138146
linear_solver = CompressibleSolver(
139147
eqns, solver_parameters=hydrostatic_parameters,
140148
overwrite_solver_parameters=True
@@ -143,11 +151,23 @@ def skamarock_klemp_nonhydrostatic(
143151
linear_solver = CompressibleSolver(eqns)
144152

145153
# Time stepper
146-
stepper = SemiImplicitQuasiNewton(
147-
eqns, io, transported_fields, transport_methods,
148-
linear_solver=linear_solver
149-
)
154+
if timestepper == 'TR-BDF2':
155+
gamma = (1-sqrt(2)/2)
156+
gamma2 = (1 - 2*float(gamma))/(2 - 2*float(gamma))
157+
tr_solver = CompressibleSolver(eqns, alpha=gamma)
158+
bdf_solver = CompressibleSolver(eqns, alpha=gamma2)
159+
stepper = TRBDF2QuasiNewton(
160+
eqns, io, transported_fields, transport_methods,
161+
gamma=gamma,
162+
tr_solver=tr_solver,
163+
bdf_solver=bdf_solver
164+
)
150165

166+
elif timestepper == 'SIQN':
167+
stepper = SemiImplicitQuasiNewton(
168+
eqns, io, transported_fields, transport_methods,
169+
linear_solver=linear_solver
170+
)
151171
# ------------------------------------------------------------------------ #
152172
# Initial conditions
153173
# ------------------------------------------------------------------------ #
@@ -248,6 +268,13 @@ def skamarock_klemp_nonhydrostatic(
248268
action="store_true",
249269
default=skamarock_klemp_nonhydrostatic_defaults['hydrostatic']
250270
)
271+
parser.add_argument(
272+
'timestepper',
273+
help='Which time stepper to use, takes SIQN or TR-BDF2',
274+
type=str,
275+
choices=['SIQN', 'TR-BDF2'],
276+
default=skamarock_klemp_nonhydrostatic_defaults['timestepper']
277+
)
251278
args, unknown = parser.parse_known_args()
252279

253280
skamarock_klemp_nonhydrostatic(**vars(args))

examples/compressible_euler/test_compressible_euler_examples.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,26 @@ def test_skamarock_klemp_nonhydrostatic_parallel():
8383
test_skamarock_klemp_nonhydrostatic()
8484

8585

86+
def test_skamarock_klemp_nonhydrostatic_TR_BDF2():
87+
from skamarock_klemp_nonhydrostatic import skamarock_klemp_nonhydrostatic
88+
test_name = 'skamarock_klemp_nonhydrostatic'
89+
skamarock_klemp_nonhydrostatic(
90+
ncolumns=30,
91+
nlayers=5,
92+
dt=6.0,
93+
tmax=60.0,
94+
dumpfreq=10,
95+
dirname=make_dirname(test_name),
96+
hydrostatic=False,
97+
timestepper='TR-BDF2'
98+
)
99+
100+
101+
@pytest.mark.parallel(nprocs=2)
102+
def test_skamarock_klemp_nonhydrostatic_TR_BDF2_parallel():
103+
test_skamarock_klemp_nonhydrostatic_TR_BDF2()
104+
105+
86106
# Hydrostatic equations not currently working
87107
@pytest.mark.xfail
88108
def test_hyd_switch_skamarock_klemp_nonhydrostatic():

gusto/timestepping/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
from gusto.timestepping.timestepper import * # noqa
22
from gusto.timestepping.split_timestepper import * # noqa
3-
from gusto.timestepping.semi_implicit_quasi_newton import * # noqa
3+
from gusto.timestepping.semi_implicit_quasi_newton import * # noqa
4+
from gusto.timestepping.tr_bdf2_quasi_newton import * # noqa

gusto/timestepping/semi_implicit_quasi_newton.py

Lines changed: 22 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from gusto.timestepping.timestepper import BaseTimestepper
2323

2424

25-
__all__ = ["SemiImplicitQuasiNewton"]
25+
__all__ = ["SemiImplicitQuasiNewton", "Forcing"]
2626

2727

2828
class SemiImplicitQuasiNewton(BaseTimestepper):
@@ -532,7 +532,6 @@ def timestep(self):
532532

533533
xrhs.assign(0.) # xrhs is the residual which goes in the linear solve
534534
xrhs_phys.assign(x_after_fast(self.field_name) - xp(self.field_name))
535-
536535
for inner in range(self.num_inner):
537536

538537
# Implicit forcing ---------------------------------------------
@@ -556,7 +555,6 @@ def timestep(self):
556555

557556
# Update xnp1 values for active tracers not included in the linear solve
558557
self.copy_active_tracers(x_after_fast, xnp1)
559-
560558
self._apply_bcs()
561559

562560
for name, scheme in self.auxiliary_schemes:
@@ -611,7 +609,7 @@ class Forcing(object):
611609
semi-implicit time discretisation.
612610
"""
613611

614-
def __init__(self, equation, implicit_terms, alpha):
612+
def __init__(self, equation, implicit_terms, alpha, dt=None):
615613
"""
616614
Args:
617615
equation (:class:`PrognosticEquationSet`): the prognostic equations
@@ -621,10 +619,14 @@ def __init__(self, equation, implicit_terms, alpha):
621619
alpha (:class:`Function`): semi-implicit off-centering factor. An
622620
alpha of 0 corresponds to fully explicit, while a factor of 1
623621
corresponds to fully implicit.
622+
dt (:float): timestep over which to apply forcing, defaults to None
623+
in which case it is taken from the equation class.
624624
"""
625625

626626
self.field_name = equation.field_name
627-
dt = equation.domain.dt
627+
628+
if dt is None:
629+
dt = equation.domain.dt
628630

629631
W = equation.function_space
630632
self.x0 = Function(W)
@@ -645,16 +647,14 @@ def __init__(self, equation, implicit_terms, alpha):
645647
replace_subject(trials),
646648
map_if_false=drop)
647649

648-
# the explicit forms are multiplied by (1-alpha) and moved to the rhs
649-
one_minus_alpha = Function(alpha.function_space(), val=1-alpha)
650-
L_explicit = -one_minus_alpha*dt*residual.label_map(
650+
L_explicit = -(1 - alpha)*dt*residual.label_map(
651651
lambda t:
652652
any(t.has_label(time_derivative, hydrostatic, *implicit_terms,
653653
return_tuple=True)),
654654
drop,
655655
replace_subject(self.x0))
656656

657-
# the implicit forms are multiplied by alpha and moved to the rhs
657+
# the implicit forms are multiplied by implicit scaling and moved to the rhs
658658
L_implicit = -alpha*dt*residual.label_map(
659659
lambda t:
660660
any(t.has_label(
@@ -682,10 +682,11 @@ def __init__(self, equation, implicit_terms, alpha):
682682
drop)
683683

684684
# now we can set up the explicit and implicit problems
685-
explicit_forcing_problem = LinearVariationalProblem(
686-
a.form, L_explicit.form, self.xF, bcs=bcs,
687-
constant_jacobian=True
688-
)
685+
if alpha != 1.0:
686+
explicit_forcing_problem = LinearVariationalProblem(
687+
a.form, L_explicit.form, self.xF, bcs=bcs,
688+
constant_jacobian=True
689+
)
689690

690691
implicit_forcing_problem = LinearVariationalProblem(
691692
a.form, L_implicit.form, self.xF, bcs=bcs,
@@ -695,19 +696,21 @@ def __init__(self, equation, implicit_terms, alpha):
695696
self.solver_parameters = mass_parameters(W, equation.domain.spaces)
696697

697698
self.solvers = {}
698-
self.solvers["explicit"] = LinearVariationalSolver(
699-
explicit_forcing_problem,
700-
solver_parameters=self.solver_parameters,
701-
options_prefix="ExplicitForcingSolver"
702-
)
699+
if alpha != 1.0:
700+
self.solvers["explicit"] = LinearVariationalSolver(
701+
explicit_forcing_problem,
702+
solver_parameters=self.solver_parameters,
703+
options_prefix="ExplicitForcingSolver"
704+
)
703705
self.solvers["implicit"] = LinearVariationalSolver(
704706
implicit_forcing_problem,
705707
solver_parameters=self.solver_parameters,
706708
options_prefix="ImplicitForcingSolver"
707709
)
708710

709711
if logger.isEnabledFor(DEBUG):
710-
self.solvers["explicit"].snes.ksp.setMonitor(logging_ksp_monitor_true_residual)
712+
if alpha != 1.0:
713+
self.solvers["explicit"].snes.ksp.setMonitor(logging_ksp_monitor_true_residual)
711714
self.solvers["implicit"].snes.ksp.setMonitor(logging_ksp_monitor_true_residual)
712715

713716
def apply(self, x_in, x_nl, x_out, label):

0 commit comments

Comments
 (0)