@@ -95,6 +95,10 @@ def assemble(expr, *args, **kwargs):
9595 `matrix.Matrix`.
9696 is_base_form_preprocessed : bool
9797 If `True`, skip preprocessing of the form.
98+ current_state : firedrake.function.Function or None
99+ If provided and ``zero_bc_nodes == False``, the boundary condition
100+ nodes of the output are set to the residual of the boundary conditions
101+ computed as ``current_state`` minus the boundary condition value.
98102
99103 Returns
100104 -------
@@ -130,16 +134,21 @@ def assemble(expr, *args, **kwargs):
130134 """
131135 if args :
132136 raise RuntimeError (f"Got unexpected args: { args } " )
133- tensor = kwargs .pop ("tensor" , None )
134- return get_assembler (expr , * args , ** kwargs ).assemble (tensor = tensor )
137+
138+ assemble_kwargs = {}
139+ for key in ("tensor" , "current_state" ):
140+ if key in kwargs :
141+ assemble_kwargs [key ] = kwargs .pop (key , None )
142+ return get_assembler (expr , * args , ** kwargs ).assemble (** assemble_kwargs )
135143
136144
137145def get_assembler (form , * args , ** kwargs ):
138146 """Create an assembler.
139147
140148 Notes
141149 -----
142- See `assemble` for descriptions of the parameters. ``tensor`` should not be passed to this function.
150+ See `assemble` for descriptions of the parameters. ``tensor`` and
151+ ``current_state`` should not be passed to this function.
143152
144153 """
145154 is_base_form_preprocessed = kwargs .pop ('is_base_form_preprocessed' , False )
@@ -187,13 +196,15 @@ class ExprAssembler(object):
187196 def __init__ (self , expr ):
188197 self ._expr = expr
189198
190- def assemble (self , tensor = None ):
199+ def assemble (self , tensor = None , current_state = None ):
191200 """Assemble the pointwise expression.
192201
193202 Parameters
194203 ----------
195204 tensor : firedrake.function.Function or firedrake.cofunction.Cofunction or matrix.MatrixBase
196205 Output tensor.
206+ current_state : None
207+ Ignored by this class.
197208
198209 Returns
199210 -------
@@ -205,6 +216,7 @@ def assemble(self, tensor=None):
205216 from ufl .checks import is_scalar_constant_expression
206217
207218 assert tensor is None
219+ assert current_state is None
208220 expr = self ._expr
209221 # Get BaseFormOperators (e.g. `Interpolate` or `ExternalOperator`)
210222 base_form_operators = extract_base_form_operators (expr )
@@ -274,13 +286,16 @@ def allocate(self):
274286 """Allocate memory for the output tensor."""
275287
276288 @abc .abstractmethod
277- def assemble (self , tensor = None ):
289+ def assemble (self , tensor = None , current_state = None ):
278290 """Assemble the form.
279291
280292 Parameters
281293 ----------
282294 tensor : firedrake.cofunction.Cofunction or firedrake.function.Function or matrix.MatrixBase
283295 Output tensor to contain the result of assembly; if `None`, a tensor of appropriate type is created.
296+ current_state : firedrake.function.Function or None
297+ If provided, the boundary condition nodes are set to the boundary condition residual
298+ computed as ``current_state`` minus the boundary condition value.
284299
285300 Returns
286301 -------
@@ -358,13 +373,16 @@ def allocation_integral_types(self):
358373 else :
359374 return self ._allocation_integral_types
360375
361- def assemble (self , tensor = None ):
376+ def assemble (self , tensor = None , current_state = None ):
362377 """Assemble the form.
363378
364379 Parameters
365380 ----------
366381 tensor : firedrake.cofunction.Cofunction or firedrake.function.Function or matrix.MatrixBase
367382 Output tensor to contain the result of assembly.
383+ current_state : firedrake.function.Function or None
384+ If provided, the boundary condition nodes are set to the boundary condition residual
385+ computed as ``current_state`` minus the boundary condition value.
368386
369387 Returns
370388 -------
@@ -389,7 +407,7 @@ def visitor(e, *operands):
389407 rank = len (self ._form .arguments ())
390408 if rank == 1 and not isinstance (result , ufl .ZeroBaseForm ):
391409 for bc in self ._bcs :
392- bc . zero ( result )
410+ OneFormAssembler . _apply_bc ( self , result , bc , u = current_state )
393411
394412 if tensor :
395413 BaseFormAssembler .update_tensor (result , tensor )
@@ -968,13 +986,16 @@ def __init__(self, form, bcs=None, form_compiler_parameters=None, needs_zeroing=
968986 super ().__init__ (form , bcs = bcs , form_compiler_parameters = form_compiler_parameters )
969987 self ._needs_zeroing = needs_zeroing
970988
971- def assemble (self , tensor = None ):
989+ def assemble (self , tensor = None , current_state = None ):
972990 """Assemble the form.
973991
974992 Parameters
975993 ----------
976994 tensor : firedrake.cofunction.Cofunction or matrix.MatrixBase
977995 Output tensor to contain the result of assembly; if `None`, a tensor of appropriate type is created.
996+ current_state : firedrake.function.Function or None
997+ If provided, the boundary condition nodes are set to the boundary condition residual
998+ computed as ``current_state`` minus the boundary condition value.
978999
9791000 Returns
9801001 -------
@@ -998,12 +1019,12 @@ def assemble(self, tensor=None):
9981019 self .execute_parloops (tensor )
9991020
10001021 for bc in self ._bcs :
1001- self ._apply_bc (tensor , bc )
1022+ self ._apply_bc (tensor , bc , u = current_state )
10021023
10031024 return self .result (tensor )
10041025
10051026 @abc .abstractmethod
1006- def _apply_bc (self , tensor , bc ):
1027+ def _apply_bc (self , tensor , bc , u = None ):
10071028 """Apply boundary condition."""
10081029
10091030 @abc .abstractmethod
@@ -1138,7 +1159,7 @@ def allocate(self):
11381159 comm = self ._form .ufl_domains ()[0 ]._comm
11391160 )
11401161
1141- def _apply_bc (self , tensor , bc ):
1162+ def _apply_bc (self , tensor , bc , u = None ):
11421163 pass
11431164
11441165 def _check_tensor (self , tensor ):
@@ -1199,26 +1220,29 @@ def allocate(self):
11991220 else :
12001221 raise RuntimeError (f"Not expected: found rank = { rank } and diagonal = { self ._diagonal } " )
12011222
1202- def _apply_bc (self , tensor , bc ):
1223+ def _apply_bc (self , tensor , bc , u = None ):
12031224 # TODO Maybe this could be a singledispatchmethod?
12041225 if isinstance (bc , DirichletBC ):
1205- self ._apply_dirichlet_bc (tensor , bc )
1226+ if self ._diagonal :
1227+ bc .set (tensor , self ._weight )
1228+ elif self ._zero_bc_nodes :
1229+ bc .zero (tensor )
1230+ else :
1231+ # The residual belongs to a mixed space that is dual on the boundary nodes
1232+ # and primal on the interior nodes. Therefore, this is a type-safe operation.
1233+ r = tensor .riesz_representation ("l2" )
1234+ bc .apply (r , u = u )
12061235 elif isinstance (bc , EquationBCSplit ):
12071236 bc .zero (tensor )
1208- type (self )(bc .f , bcs = bc .bcs , form_compiler_parameters = self ._form_compiler_params , needs_zeroing = False ,
1209- zero_bc_nodes = self ._zero_bc_nodes , diagonal = self ._diagonal , weight = self ._weight ).assemble (tensor = tensor )
1237+ OneFormAssembler (bc .f , bcs = bc .bcs ,
1238+ form_compiler_parameters = self ._form_compiler_params ,
1239+ needs_zeroing = False ,
1240+ zero_bc_nodes = self ._zero_bc_nodes ,
1241+ diagonal = self ._diagonal ,
1242+ weight = self ._weight ).assemble (tensor = tensor , current_state = u )
12101243 else :
12111244 raise AssertionError
12121245
1213- def _apply_dirichlet_bc (self , tensor , bc ):
1214- if self ._diagonal :
1215- bc .set (tensor , self ._weight )
1216- elif not self ._zero_bc_nodes :
1217- # NOTE this only works if tensor is a Function and not a Cofunction
1218- bc .apply (tensor )
1219- else :
1220- bc .zero (tensor )
1221-
12221246 def _check_tensor (self , tensor ):
12231247 if tensor .function_space () != self ._form .arguments ()[0 ].function_space ().dual ():
12241248 raise ValueError ("Form's argument does not match provided result tensor" )
@@ -1430,7 +1454,8 @@ def _all_assemblers(self):
14301454 all_assemblers .extend (_assembler ._all_assemblers )
14311455 return tuple (all_assemblers )
14321456
1433- def _apply_bc (self , tensor , bc ):
1457+ def _apply_bc (self , tensor , bc , u = None ):
1458+ assert u is None
14341459 op2tensor = tensor .M
14351460 spaces = tuple (a .function_space () for a in tensor .a .arguments ())
14361461 V = bc .function_space ()
@@ -1534,7 +1559,7 @@ def allocate(self):
15341559 options_prefix = self ._options_prefix ,
15351560 appctx = self ._appctx or {})
15361561
1537- def assemble (self , tensor = None ):
1562+ def assemble (self , tensor = None , current_state = None ):
15381563 if tensor is None :
15391564 tensor = self .allocate ()
15401565 else :
0 commit comments