Skip to content

Commit a64be62

Browse files
authored
Generate linear solver using linearisation of equations (#649)
1 parent 59869d0 commit a64be62

File tree

8 files changed

+543
-221
lines changed

8 files changed

+543
-221
lines changed

examples/shallow_water/williamson_5.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -73,13 +73,14 @@ def williamson_5(
7373
rsq = min_value(R0**2, (lamda - lamda_c)**2 + (phi - phi_c)**2)
7474
r = sqrt(rsq)
7575
tpexpr = mountain_height * (1 - r/R0)
76-
eqns = ShallowWaterEquations(domain, parameters, fexpr=fexpr,
77-
topog_expr=tpexpr)
76+
eqns = ShallowWaterEquations(
77+
domain, parameters, fexpr=fexpr, topog_expr=tpexpr
78+
)
7879

7980
# I/O
8081
output = OutputParameters(
81-
dirname=dirname, dumplist_latlon=['D'], dumpfreq=dumpfreq,
82-
dump_vtus=True, dump_nc=False, dumplist=['D', 'topography']
82+
dirname=dirname, dumpfreq=dumpfreq,
83+
dump_vtus=True, dump_nc=True, dumplist=['D', 'topography']
8384
)
8485
diagnostic_fields = [Sum('D', 'topography'), RelativeVorticity(),
8586
MeridionalComponent('u'), ZonalComponent('u')]
@@ -101,7 +102,7 @@ def williamson_5(
101102

102103
# Time stepper
103104
stepper = SemiImplicitQuasiNewton(
104-
eqns, io, transported_fields, transport_methods
105+
eqns, io, transported_fields, transport_methods, tau_values={'D': 1.0}
105106
)
106107

107108
# ------------------------------------------------------------------------ #
17.8 KB
Loading

gusto/equations/boussinesq_equations.py

Lines changed: 43 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,16 @@
11
"""Defines the Boussinesq equations."""
22

33
from firedrake import inner, dx, div, cross, split, as_vector
4-
from firedrake.fml import subject
4+
from firedrake.fml import subject, all_terms
55
from gusto.core.labels import (
6-
time_derivative, transport, prognostic, linearisation,
6+
prognostic, linearisation,
77
pressure_gradient, coriolis, divergence, gravity, incompressible
88
)
99
from gusto.equations.common_forms import (
1010
advection_form, vector_invariant_form,
1111
kinetic_energy_form, advection_equation_circulation_form,
12-
linear_advection_form
12+
linear_advection_form, linear_circulation_form,
13+
linear_vector_invariant_form
1314
)
1415
from gusto.equations.prognostic_equations import PrognosticEquationSet
1516

@@ -41,7 +42,7 @@ class BoussinesqEquations(PrognosticEquationSet):
4142
def __init__(self, domain, parameters,
4243
compressible=True,
4344
space_names=None,
44-
linearisation_map='default',
45+
linearisation_map=all_terms,
4546
u_transport_option="vector_invariant_form",
4647
no_normal_flow_bc_ids=None,
4748
active_tracers=None):
@@ -59,9 +60,8 @@ def __init__(self, domain, parameters,
5960
in which case the spaces are taken from the de Rham complex.
6061
linearisation_map (func, optional): a function specifying which
6162
terms in the equation set to linearise. If None is specified
62-
then no terms are linearised. Defaults to the string 'default',
63-
in which case the linearisation includes time derivatives and
64-
scalar transport terms.
63+
then no terms are linearised. Defaults to the FML `all_terms`
64+
function.
6565
u_transport_option (str, optional): specifies the transport term
6666
used for the velocity equation. Supported options are:
6767
'vector_invariant_form', 'vector_advection_form' and
@@ -89,16 +89,6 @@ def __init__(self, domain, parameters,
8989
if active_tracers is None:
9090
active_tracers = []
9191

92-
if linearisation_map == 'default':
93-
# Default linearisation is time derivatives, scalar transport,
94-
# pressure gradient, gravity and divergence terms
95-
# Don't include active tracers
96-
linearisation_map = lambda t: \
97-
t.get(prognostic) in ['u', 'p', 'b'] \
98-
and (any(t.has_label(time_derivative, pressure_gradient,
99-
divergence, gravity))
100-
or (t.get(prognostic) not in ['u', 'p'] and t.has_label(transport)))
101-
10292
super().__init__(field_names, domain, space_names,
10393
linearisation_map=linearisation_map,
10494
no_normal_flow_bc_ids=no_normal_flow_bc_ids,
@@ -109,8 +99,8 @@ def __init__(self, domain, parameters,
10999

110100
w, phi, gamma = self.tests[0:3]
111101
u, p, b = split(self.X)
112-
u_trial, p_trial, _ = split(self.trials)
113-
_, p_bar, b_bar = split(self.X_ref)
102+
u_trial, p_trial, b_trial = split(self.trials)[0:3]
103+
u_bar, p_bar, b_bar = split(self.X_ref)[0:3]
114104

115105
# -------------------------------------------------------------------- #
116106
# Time Derivative Terms
@@ -122,28 +112,51 @@ def __init__(self, domain, parameters,
122112
# -------------------------------------------------------------------- #
123113
# Velocity transport term -- depends on formulation
124114
if u_transport_option == "vector_invariant_form":
125-
u_adv = prognostic(vector_invariant_form(domain, w, u, u), 'u')
126-
elif u_transport_option == "vector_advection_form":
127-
u_adv = prognostic(advection_form(w, u, u), 'u')
115+
u_adv = prognostic(vector_invariant_form(self.domain, w, u, u), 'u')
116+
# Manually add linearisation, as linearisation cannot handle the
117+
# perp function on the plane / vertical slice
118+
if self.linearisation_map(u_adv.terms[0]):
119+
linear_u_adv = linear_vector_invariant_form(self.domain, w, u_trial, u_bar)
120+
u_adv = linearisation(u_adv, linear_u_adv)
121+
128122
elif u_transport_option == "circulation_form":
123+
# This is different to vector invariant form as the K.E. form
124+
# doesn't have a variable marked as "transporting velocity"
129125
ke_form = prognostic(kinetic_energy_form(w, u, u), 'u')
130-
u_adv = prognostic(advection_equation_circulation_form(domain, w, u, u), 'u') + ke_form
126+
circ_form = prognostic(advection_equation_circulation_form(self.domain, w, u, u), 'u')
127+
# Manually add linearisation, as linearisation cannot handle the
128+
# perp function on the plane / vertical slice
129+
if self.linearisation_map(circ_form.terms[0]):
130+
linear_circ_form = linear_circulation_form(self.domain, w, u_trial, u_bar)
131+
circ_form = linearisation(circ_form, linear_circ_form)
132+
u_adv = circ_form + ke_form
133+
134+
elif u_transport_option == "vector_advection_form":
135+
u_adv = prognostic(advection_form(w, u, u), 'u')
136+
131137
else:
132-
raise ValueError("Invalid u_transport_option: %s" % u_transport_option)
138+
raise ValueError("Invalid u_transport_option: %s" % self.u_transport_option)
133139

134140
# Buoyancy transport
135141
b_adv = prognostic(advection_form(gamma, b, u), 'b')
142+
143+
# TODO #651: we should remove this hand-coded linearisation
144+
# currently REXI can't handle generated transport linearisations
136145
if self.linearisation_map(b_adv.terms[0]):
137-
linear_b_adv = linear_advection_form(gamma, b_bar, u_trial)
146+
linear_b_adv = linear_advection_form(gamma, b_trial, u_trial, b_bar, u_bar)
138147
b_adv = linearisation(b_adv, linear_b_adv)
139148

140149
if compressible:
141150
# Pressure transport
142151
p_adv = prognostic(advection_form(phi, p, u), 'p')
152+
153+
# TODO #651: we should remove this hand-coded linearisation
154+
# currently REXI can't handle generated transport linearisations
143155
if self.linearisation_map(p_adv.terms[0]):
144-
linear_p_adv = linear_advection_form(phi, p_bar, u_trial)
156+
linear_p_adv = linear_advection_form(phi, p_trial, u_trial, p_bar, u_bar)
145157
p_adv = linearisation(p_adv, linear_p_adv)
146158
adv_form = subject(u_adv + p_adv + b_adv, self.X)
159+
147160
else:
148161
adv_form = subject(u_adv + b_adv, self.X)
149162

@@ -193,11 +206,11 @@ def __init__(self, domain, parameters,
193206
# Extra Terms (Coriolis)
194207
# -------------------------------------------------------------------- #
195208
if self.parameters.Omega is not None:
196-
# TODO: add linearisation
197209
Omega = as_vector((0, 0, self.parameters.Omega))
198210
coriolis_form = coriolis(subject(prognostic(
199211
inner(w, cross(2*Omega, u))*dx, 'u'), self.X))
200212
residual += coriolis_form
213+
201214
# -------------------------------------------------------------------- #
202215
# Linearise equations
203216
# -------------------------------------------------------------------- #
@@ -230,7 +243,7 @@ class LinearBoussinesqEquations(BoussinesqEquations):
230243
def __init__(self, domain, parameters,
231244
compressible=True,
232245
space_names=None,
233-
linearisation_map='default',
246+
linearisation_map=all_terms,
234247
u_transport_option="vector_invariant_form",
235248
no_normal_flow_bc_ids=None,
236249
active_tracers=None):
@@ -249,8 +262,8 @@ def __init__(self, domain, parameters,
249262
linearisation_map (func, optional): a function specifying which
250263
terms in the equation set to linearise. If None is specified
251264
then no terms are linearised. Defaults to the string 'default',
252-
in which case the linearisation includes time derivatives and
253-
scalar transport terms.
265+
in which case the linearisation drops terms for any active
266+
tracers.
254267
u_transport_option (str, optional): specifies the transport term
255268
used for the velocity equation. Supported options are:
256269
'vector_invariant_form', 'vector_advection_form' and
@@ -267,13 +280,6 @@ def __init__(self, domain, parameters,
267280
NotImplementedError: active tracers are not implemented.
268281
"""
269282

270-
if linearisation_map == 'default':
271-
# Default linearisation is time derivatives, pressure gradient,
272-
# Coriolis and transport term from depth equation
273-
linearisation_map = lambda t: \
274-
(any(t.has_label(time_derivative, pressure_gradient, coriolis,
275-
gravity, divergence, incompressible))
276-
or (t.get(prognostic) in ['p', 'b'] and t.has_label(transport)))
277283
super().__init__(domain=domain,
278284
parameters=parameters,
279285
compressible=compressible,

gusto/equations/common_forms.py

Lines changed: 69 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,9 @@
1212

1313
__all__ = ["advection_form", "advection_form_1d", "continuity_form",
1414
"continuity_form_1d", "vector_invariant_form",
15+
"linear_vector_invariant_form",
1516
"kinetic_energy_form", "advection_equation_circulation_form",
17+
"linear_circulation_form",
1618
"diffusion_form", "diffusion_form_1d",
1719
"linear_advection_form", "linear_continuity_form",
1820
"split_continuity_form", "tracer_conservative_form", "split_hv_advective_form"]
@@ -60,21 +62,25 @@ def advection_form_1d(test, q, ubar):
6062
return transport(form, TransportEquationType.advective)
6163

6264

63-
def linear_advection_form(test, qbar, ubar):
65+
def linear_advection_form(test, q, u, qbar, ubar):
6466
"""
6567
The form corresponding to the linearised advective transport operator.
6668
6769
Args:
6870
test (:class:`TestFunction`): the test function.
69-
qbar (:class:`ufl.Expr`): the variable to be transported.
70-
ubar (:class:`ufl.Expr`): the transporting velocity.
71+
q (:class:`ufl.Expr`): the perturbation variable to be transported.
72+
u (:class:`ufl.Expr`): the perturbation transporting velocity.
73+
qbar (:class:`ufl.Expr`): the mean variable to be transported.
74+
ubar (:class:`ufl.Expr`): the mean transporting velocity.
7175
7276
Returns:
7377
:class:`LabelledForm`: a labelled transport form.
7478
"""
7579

76-
L = test*dot(ubar, grad(qbar))*dx
77-
form = transporting_velocity(L, ubar)
80+
form = (
81+
transporting_velocity(test*dot(ubar, grad(q))*dx, ubar)
82+
+ transporting_velocity(test*dot(u, grad(qbar))*dx, u)
83+
)
7884

7985
return transport(form, TransportEquationType.advective)
8086

@@ -121,21 +127,25 @@ def continuity_form_1d(test, q, ubar):
121127
return transport(form, TransportEquationType.conservative)
122128

123129

124-
def linear_continuity_form(test, qbar, ubar):
130+
def linear_continuity_form(test, q, u, qbar, ubar):
125131
"""
126132
The form corresponding to the linearised continuity transport operator.
127133
128134
Args:
129135
test (:class:`TestFunction`): the test function.
130-
qbar (:class:`ufl.Expr`): the variable to be transported.
131-
ubar (:class:`ufl.Expr`): the transporting velocity.
136+
q (:class:`ufl.Expr`): the perturbation variable to be transported.
137+
u (:class:`ufl.Expr`): the perturbation transporting velocity.
138+
qbar (:class:`ufl.Expr`): the mean variable to be transported.
139+
ubar (:class:`ufl.Expr`): the mean transporting velocity.
132140
133141
Returns:
134142
:class:`LabelledForm`: a labelled transport form.
135143
"""
136144

137-
L = test*div(qbar*ubar)*dx
138-
form = transporting_velocity(L, ubar)
145+
form = (
146+
transporting_velocity(test*div(q*ubar)*dx, ubar)
147+
+ transporting_velocity(test*div(qbar*u)*dx, u)
148+
)
139149

140150
return transport(form, TransportEquationType.conservative)
141151

@@ -172,6 +182,34 @@ def vector_invariant_form(domain, test, q, ubar):
172182
return transport(form, TransportEquationType.vector_invariant)
173183

174184

185+
def linear_vector_invariant_form(domain, test, q, ubar):
186+
u"""
187+
The linear form corresponding to the vector invariant transport operator.
188+
189+
The vector invariant transport operator is: (∇×q)×u + (1/2)∇(u.q)
190+
and its linearised form is:
191+
(∇×q')×u_bar + (∇×q_bar)×u' + (1/2)∇(u_bar.q') + (1/2)∇(u'.q_bar)
192+
193+
Args:
194+
domain (:class:`Domain`): the model's domain object, containing the
195+
mesh and the compatible function spaces.
196+
test (:class:`TestFunction`): the test function.
197+
q (:class:`ufl.Expr`): the variable to be transported.
198+
ubar (:class:`ufl.Expr`): the transporting velocity.
199+
200+
Returns:
201+
class:`LabelledForm`: a labelled transport form.
202+
"""
203+
204+
L = linear_circulation_form(domain, test, q, ubar).form
205+
206+
# Add K.E. term
207+
L -= div(test)*inner(q, ubar)*dx
208+
form = transporting_velocity(L, ubar)
209+
210+
return transport(form, TransportEquationType.vector_invariant)
211+
212+
175213
def kinetic_energy_form(test, q, ubar):
176214
u"""
177215
The form corresponding to the kinetic energy term.
@@ -210,6 +248,7 @@ def advection_equation_circulation_form(domain, test, q, ubar):
210248
term.
211249
212250
Args:
251+
domain (:class:`Domain`): the model's domain object.
213252
test (:class:`TestFunction`): the test function.
214253
q (:class:`ufl.Expr`): the variable to be transported.
215254
ubar (:class:`ufl.Expr`): the transporting velocity.
@@ -230,6 +269,26 @@ def advection_equation_circulation_form(domain, test, q, ubar):
230269
return transport(form, TransportEquationType.circulation)
231270

232271

272+
def linear_circulation_form(domain, test, q, ubar):
273+
"""
274+
The linear circulation term in the transport of a vector-valued field.
275+
276+
Args:
277+
test (:class:`TestFunction`): the test function.
278+
q (:class:`ufl.Expr`): the variable to be transported.
279+
ubar (:class:`ufl.Expr`): the transporting velocity.
280+
281+
Returns:
282+
class:`LabelledForm`: a labelled transport form.
283+
"""
284+
285+
form = (
286+
advection_equation_circulation_form(domain, test, q, ubar)
287+
+ advection_equation_circulation_form(domain, test, ubar, q)
288+
)
289+
return form
290+
291+
233292
def diffusion_form(test, q, kappa):
234293
u"""
235294
The diffusion form, -∇.(κ∇q) for diffusivity κ and variable q.

0 commit comments

Comments
 (0)