11"""Defines the Boussinesq equations."""
22
33from firedrake import inner , dx , div , cross , split , as_vector
4- from firedrake .fml import subject
4+ from firedrake .fml import subject , all_terms
55from gusto .core .labels import (
6- time_derivative , transport , prognostic , linearisation ,
6+ prognostic , linearisation ,
77 pressure_gradient , coriolis , divergence , gravity , incompressible
88)
99from gusto .equations .common_forms import (
1010 advection_form , vector_invariant_form ,
1111 kinetic_energy_form , advection_equation_circulation_form ,
12- linear_advection_form
12+ linear_advection_form , linear_circulation_form ,
13+ linear_vector_invariant_form
1314)
1415from gusto .equations .prognostic_equations import PrognosticEquationSet
1516
@@ -41,7 +42,7 @@ class BoussinesqEquations(PrognosticEquationSet):
4142 def __init__ (self , domain , parameters ,
4243 compressible = True ,
4344 space_names = None ,
44- linearisation_map = 'default' ,
45+ linearisation_map = all_terms ,
4546 u_transport_option = "vector_invariant_form" ,
4647 no_normal_flow_bc_ids = None ,
4748 active_tracers = None ):
@@ -59,9 +60,8 @@ def __init__(self, domain, parameters,
5960 in which case the spaces are taken from the de Rham complex.
6061 linearisation_map (func, optional): a function specifying which
6162 terms in the equation set to linearise. If None is specified
62- then no terms are linearised. Defaults to the string 'default',
63- in which case the linearisation includes time derivatives and
64- scalar transport terms.
63+ then no terms are linearised. Defaults to the FML `all_terms`
64+ function.
6565 u_transport_option (str, optional): specifies the transport term
6666 used for the velocity equation. Supported options are:
6767 'vector_invariant_form', 'vector_advection_form' and
@@ -89,16 +89,6 @@ def __init__(self, domain, parameters,
8989 if active_tracers is None :
9090 active_tracers = []
9191
92- if linearisation_map == 'default' :
93- # Default linearisation is time derivatives, scalar transport,
94- # pressure gradient, gravity and divergence terms
95- # Don't include active tracers
96- linearisation_map = lambda t : \
97- t .get (prognostic ) in ['u' , 'p' , 'b' ] \
98- and (any (t .has_label (time_derivative , pressure_gradient ,
99- divergence , gravity ))
100- or (t .get (prognostic ) not in ['u' , 'p' ] and t .has_label (transport )))
101-
10292 super ().__init__ (field_names , domain , space_names ,
10393 linearisation_map = linearisation_map ,
10494 no_normal_flow_bc_ids = no_normal_flow_bc_ids ,
@@ -109,8 +99,8 @@ def __init__(self, domain, parameters,
10999
110100 w , phi , gamma = self .tests [0 :3 ]
111101 u , p , b = split (self .X )
112- u_trial , p_trial , _ = split (self .trials )
113- _ , p_bar , b_bar = split (self .X_ref )
102+ u_trial , p_trial , b_trial = split (self .trials )[ 0 : 3 ]
103+ u_bar , p_bar , b_bar = split (self .X_ref )[ 0 : 3 ]
114104
115105 # -------------------------------------------------------------------- #
116106 # Time Derivative Terms
@@ -122,28 +112,51 @@ def __init__(self, domain, parameters,
122112 # -------------------------------------------------------------------- #
123113 # Velocity transport term -- depends on formulation
124114 if u_transport_option == "vector_invariant_form" :
125- u_adv = prognostic (vector_invariant_form (domain , w , u , u ), 'u' )
126- elif u_transport_option == "vector_advection_form" :
127- u_adv = prognostic (advection_form (w , u , u ), 'u' )
115+ u_adv = prognostic (vector_invariant_form (self .domain , w , u , u ), 'u' )
116+ # Manually add linearisation, as linearisation cannot handle the
117+ # perp function on the plane / vertical slice
118+ if self .linearisation_map (u_adv .terms [0 ]):
119+ linear_u_adv = linear_vector_invariant_form (self .domain , w , u_trial , u_bar )
120+ u_adv = linearisation (u_adv , linear_u_adv )
121+
128122 elif u_transport_option == "circulation_form" :
123+ # This is different to vector invariant form as the K.E. form
124+ # doesn't have a variable marked as "transporting velocity"
129125 ke_form = prognostic (kinetic_energy_form (w , u , u ), 'u' )
130- u_adv = prognostic (advection_equation_circulation_form (domain , w , u , u ), 'u' ) + ke_form
126+ circ_form = prognostic (advection_equation_circulation_form (self .domain , w , u , u ), 'u' )
127+ # Manually add linearisation, as linearisation cannot handle the
128+ # perp function on the plane / vertical slice
129+ if self .linearisation_map (circ_form .terms [0 ]):
130+ linear_circ_form = linear_circulation_form (self .domain , w , u_trial , u_bar )
131+ circ_form = linearisation (circ_form , linear_circ_form )
132+ u_adv = circ_form + ke_form
133+
134+ elif u_transport_option == "vector_advection_form" :
135+ u_adv = prognostic (advection_form (w , u , u ), 'u' )
136+
131137 else :
132- raise ValueError ("Invalid u_transport_option: %s" % u_transport_option )
138+ raise ValueError ("Invalid u_transport_option: %s" % self . u_transport_option )
133139
134140 # Buoyancy transport
135141 b_adv = prognostic (advection_form (gamma , b , u ), 'b' )
142+
143+ # TODO #651: we should remove this hand-coded linearisation
144+ # currently REXI can't handle generated transport linearisations
136145 if self .linearisation_map (b_adv .terms [0 ]):
137- linear_b_adv = linear_advection_form (gamma , b_bar , u_trial )
146+ linear_b_adv = linear_advection_form (gamma , b_trial , u_trial , b_bar , u_bar )
138147 b_adv = linearisation (b_adv , linear_b_adv )
139148
140149 if compressible :
141150 # Pressure transport
142151 p_adv = prognostic (advection_form (phi , p , u ), 'p' )
152+
153+ # TODO #651: we should remove this hand-coded linearisation
154+ # currently REXI can't handle generated transport linearisations
143155 if self .linearisation_map (p_adv .terms [0 ]):
144- linear_p_adv = linear_advection_form (phi , p_bar , u_trial )
156+ linear_p_adv = linear_advection_form (phi , p_trial , u_trial , p_bar , u_bar )
145157 p_adv = linearisation (p_adv , linear_p_adv )
146158 adv_form = subject (u_adv + p_adv + b_adv , self .X )
159+
147160 else :
148161 adv_form = subject (u_adv + b_adv , self .X )
149162
@@ -193,11 +206,11 @@ def __init__(self, domain, parameters,
193206 # Extra Terms (Coriolis)
194207 # -------------------------------------------------------------------- #
195208 if self .parameters .Omega is not None :
196- # TODO: add linearisation
197209 Omega = as_vector ((0 , 0 , self .parameters .Omega ))
198210 coriolis_form = coriolis (subject (prognostic (
199211 inner (w , cross (2 * Omega , u ))* dx , 'u' ), self .X ))
200212 residual += coriolis_form
213+
201214 # -------------------------------------------------------------------- #
202215 # Linearise equations
203216 # -------------------------------------------------------------------- #
@@ -230,7 +243,7 @@ class LinearBoussinesqEquations(BoussinesqEquations):
230243 def __init__ (self , domain , parameters ,
231244 compressible = True ,
232245 space_names = None ,
233- linearisation_map = 'default' ,
246+ linearisation_map = all_terms ,
234247 u_transport_option = "vector_invariant_form" ,
235248 no_normal_flow_bc_ids = None ,
236249 active_tracers = None ):
@@ -249,8 +262,8 @@ def __init__(self, domain, parameters,
249262 linearisation_map (func, optional): a function specifying which
250263 terms in the equation set to linearise. If None is specified
251264 then no terms are linearised. Defaults to the string 'default',
252- in which case the linearisation includes time derivatives and
253- scalar transport terms .
265+ in which case the linearisation drops terms for any active
266+ tracers .
254267 u_transport_option (str, optional): specifies the transport term
255268 used for the velocity equation. Supported options are:
256269 'vector_invariant_form', 'vector_advection_form' and
@@ -267,13 +280,6 @@ def __init__(self, domain, parameters,
267280 NotImplementedError: active tracers are not implemented.
268281 """
269282
270- if linearisation_map == 'default' :
271- # Default linearisation is time derivatives, pressure gradient,
272- # Coriolis and transport term from depth equation
273- linearisation_map = lambda t : \
274- (any (t .has_label (time_derivative , pressure_gradient , coriolis ,
275- gravity , divergence , incompressible ))
276- or (t .get (prognostic ) in ['p' , 'b' ] and t .has_label (transport )))
277283 super ().__init__ (domain = domain ,
278284 parameters = parameters ,
279285 compressible = compressible ,
0 commit comments