1818import finat .ufl
1919from firedrake import (extrusion_utils as eutils , matrix , parameters , solving ,
2020 tsfc_interface , utils )
21+ from firedrake .formmanipulation import split_form
2122from firedrake .adjoint_utils import annotate_assemble
2223from firedrake .ufl_expr import extract_unique_domain
2324from firedrake .bcs import DirichletBC , EquationBC , EquationBCSplit
@@ -230,34 +231,36 @@ def assemble(self, tensor=None, current_state=None):
230231 # Only Expr resulting in a Matrix if assembled are BaseFormOperator
231232 if not all (isinstance (op , matrix .AssembledMatrix ) for op in (a , b )):
232233 raise TypeError ('Mismatching Sum shapes' )
233- return get_assembler (ufl .FormSum ((a , 1 ), (b , 1 ))). assemble ( )
234+ return assemble (ufl .FormSum ((a , 1 ), (b , 1 )), tensor = tensor )
234235 elif isinstance (expr , ufl .algebra .Product ):
235236 a , b = expr .ufl_operands
236237 scalar = [e for e in expr .ufl_operands if is_scalar_constant_expression (e )]
237238 if scalar :
238239 base_form = a if a is scalar else b
239240 assembled_mat = assemble (base_form )
240- return get_assembler (ufl .FormSum ((assembled_mat , scalar [0 ]))). assemble ( )
241+ return assemble (ufl .FormSum ((assembled_mat , scalar [0 ])), tensor = tensor )
241242 a , b = [assemble (e ) for e in (a , b )]
242- return get_assembler (ufl .action (a , b )). assemble ( )
243+ return assemble (ufl .action (a , b ), tensor = tensor )
243244 # -- Linear combination of Functions and 1-form BaseFormOperators -- #
244245 # Example: a * u1 + b * u2 + c * N(u1; v*) + d * N(u2; v*)
245246 # with u1, u2 Functions, N a BaseFormOperator, and a, b, c, d scalars or 0-form BaseFormOperators.
246247 else :
247248 base_form_operators = extract_base_form_operators (expr )
248- assembled_bfops = [firedrake .assemble (e ) for e in base_form_operators ]
249249 # Substitute base form operators with their output before examining the expression
250250 # which avoids conflict when determining function space, for example:
251251 # extract_coefficients(Interpolate(u, V2)) with u \in V1 will result in an output function space V1
252252 # instead of V2.
253253 if base_form_operators :
254- expr = ufl .replace (expr , dict (zip (base_form_operators , assembled_bfops )))
255- try :
256- coefficients = ufl .algorithms .extract_coefficients (expr )
257- V , = set (c .function_space () for c in coefficients ) - {None }
258- except ValueError :
259- raise ValueError ("Cannot deduce correct target space from pointwise expression" )
260- return firedrake .Function (V ).assign (expr )
254+ assembled_bfops = {e : firedrake .assemble (e ) for e in base_form_operators }
255+ expr = ufl .replace (expr , assembled_bfops )
256+ if tensor is None :
257+ try :
258+ coefficients = ufl .algorithms .extract_coefficients (expr )
259+ V , = set (c .function_space () for c in coefficients ) - {None }
260+ except ValueError :
261+ raise ValueError ("Cannot deduce correct target space from pointwise expression" )
262+ tensor = firedrake .Function (V )
263+ return tensor .assign (expr )
261264
262265
263266class AbstractFormAssembler (abc .ABC ):
@@ -493,9 +496,10 @@ def base_form_assembly_visitor(self, expr, tensor, *args):
493496 if all (isinstance (op , numbers .Complex ) for op in args ):
494497 result = sum (weight * arg for weight , arg in zip (expr .weights (), args ))
495498 return tensor .assign (result ) if tensor else result
496- elif all (isinstance (op , firedrake .Cofunction ) for op in args ):
499+ elif (all (isinstance (op , firedrake .Cofunction ) for op in args )
500+ or all (isinstance (op , firedrake .Function ) for op in args )):
497501 V , = set (a .function_space () for a in args )
498- result = tensor if tensor else firedrake .Cofunction (V )
502+ result = tensor if tensor else firedrake .Function (V )
499503 result .dat .maxpy (expr .weights (), [a .dat for a in args ])
500504 return result
501505 elif all (isinstance (op , ufl .Matrix ) for op in args ):
@@ -540,7 +544,10 @@ def base_form_assembly_visitor(self, expr, tensor, *args):
540544 if assembled_expression :
541545 # Occur in situations such as Interpolate composition
542546 expression = assembled_expression [0 ]
543- expr = expr ._ufl_expr_reconstruct_ (expression , v )
547+
548+ reconstruct_interp = expr ._ufl_expr_reconstruct_
549+ if (v , expression ) != expr .argument_slots ():
550+ expr = reconstruct_interp (expression , v = v )
544551
545552 # Different assembly procedures:
546553 # 1) Interpolate(Argument(V1, 1), Argument(V2.dual(), 0)) -> Jacobian (Interpolate matrix)
@@ -552,27 +559,59 @@ def base_form_assembly_visitor(self, expr, tensor, *args):
552559 # If argument numbers have been swapped => Adjoint.
553560 arg_expression = ufl .algorithms .extract_arguments (expression )
554561 is_adjoint = (arg_expression and arg_expression [0 ].number () == 0 )
562+
563+ # Dual interpolation from mixed source
564+ if is_adjoint and len (expr .function_space ()) > 1 :
565+ cur = 0
566+ sub_expressions = []
567+ components = numpy .reshape (expression , (- 1 ,))
568+ for Vi in expr .function_space ():
569+ sub_expressions .append (ufl .as_tensor (components [cur :cur + Vi .value_size ].reshape (Vi .value_shape )))
570+ cur += Vi .value_size
571+
572+ # Component-split of the primal expression interpolated into the dual argument-split
573+ split_interp = sum (reconstruct_interp (sub_expressions [i ], v = vi ) for (i ,), vi in split_form (v ))
574+ return assemble (split_interp , tensor = tensor )
575+
576+ # Dual interpolation into mixed target
577+ if is_adjoint and len (arg_expression [0 ].function_space ()) > 1 and rank == 1 :
578+ V = arg_expression [0 ].function_space ()
579+ tensor = tensor or firedrake .Cofunction (V .dual ())
580+
581+ # Argument-split of the Interpolate gets assembled into the corresponding sub-tensor
582+ for (i ,), sub_interp in split_form (expr ):
583+ assemble (sub_interp , tensor = tensor .subfunctions [i ])
584+ return tensor
585+
586+ # Get the primal space
587+ V = expr .function_space ()
588+ if is_adjoint or rank == 0 :
589+ V = V .dual ()
590+
555591 # Workaround: Renumber argument when needed since Interpolator assumes it takes a zero-numbered argument.
556- if not is_adjoint and rank != 1 :
592+ if not is_adjoint and rank == 2 :
557593 _ , v1 = expr .arguments ()
558- expression = ufl .replace (expression , {v1 : firedrake . Argument ( v1 .function_space (), number = 0 , part = v1 . part () )})
594+ expression = ufl .replace (expression , {v1 : v1 .reconstruct ( number = 0 )})
559595 # Get the interpolator
560596 interp_data = expr .interp_data
561597 default_missing_val = interp_data .pop ('default_missing_val' , None )
562- interpolator = firedrake .Interpolator (expression , expr . function_space () , ** interp_data )
598+ interpolator = firedrake .Interpolator (expression , V , ** interp_data )
563599 # Assembly
564- if rank == 1 :
600+ if rank == 0 :
601+ Iu = interpolator ._interpolate (default_missing_val = default_missing_val )
602+ return assemble (ufl .Action (v , Iu ), tensor = tensor )
603+ elif rank == 1 :
565604 # Assembling the action of the Jacobian adjoint.
566605 if is_adjoint :
567- output = tensor or firedrake .Cofunction (arg_expression [0 ].function_space ().dual ())
568- return interpolator ._interpolate (v , output = output , adjoint = True , default_missing_val = default_missing_val )
606+ return interpolator ._interpolate (v , output = tensor , adjoint = True , default_missing_val = default_missing_val )
569607 # Assembling the Jacobian action.
570- if interpolator .nargs :
608+ elif interpolator .nargs :
571609 return interpolator ._interpolate (expression , output = tensor , default_missing_val = default_missing_val )
572610 # Assembling the operator
573- if tensor is None :
611+ elif tensor is None :
574612 return interpolator ._interpolate (default_missing_val = default_missing_val )
575- return firedrake .Interpolator (expression , tensor , ** interp_data )._interpolate (default_missing_val = default_missing_val )
613+ else :
614+ return firedrake .Interpolator (expression , tensor , ** interp_data )._interpolate (default_missing_val = default_missing_val )
576615 elif rank == 2 :
577616 res = tensor .petscmat if tensor else PETSc .Mat ()
578617 # Get the interpolation matrix
@@ -799,7 +838,7 @@ def restructure_base_form(expr, visited=None):
799838 replace_map = {arg : left }
800839 # Decrease number for all the other arguments since the lowest numbered argument will be replaced.
801840 other_args = [a for a in right .arguments () if a is not arg ]
802- new_args = [firedrake . Argument ( a . function_space (), number = a .number ()- 1 , part = a . part () ) for a in other_args ]
841+ new_args = [a . reconstruct ( number = a .number ()- 1 ) for a in other_args ]
803842 replace_map .update (dict (zip (other_args , new_args )))
804843 # Replace arguments
805844 return ufl .replace (right , replace_map )
@@ -810,13 +849,13 @@ def restructure_base_form(expr, visited=None):
810849 u , v = B .arguments ()
811850 # Let V1 and V2 be primal spaces, B: V1 -> V2 and B*: V2* -> V1*:
812851 # Adjoint(B(Argument(V1, 1), Argument(V2.dual(), 0))) = B(Argument(V1, 0), Argument(V2.dual(), 1))
813- reordered_arguments = ( firedrake . Argument ( u . function_space (), number = v .number (), part = v . part ()),
814- firedrake . Argument ( v . function_space (), number = u .number (), part = u . part ()))
852+ reordered_arguments = { u : u . reconstruct ( number = v .number ()),
853+ v : v . reconstruct ( number = u .number ())}
815854 # Replace arguments in argument slots
816- return ufl .replace (B , dict ( zip (( u , v ), reordered_arguments )) )
855+ return ufl .replace (B , reordered_arguments )
817856
818857 # -- Case (5) -- #
819- if isinstance (expr , ufl .core .base_form_operator .BaseFormOperator ) and not expr .arguments ():
858+ if isinstance (expr , ufl .core .base_form_operator .BaseFormOperator ) and len ( expr .arguments ()) == 0 :
820859 # We are assembling a BaseFormOperator of rank 0 (no arguments).
821860 # B(f, u*) be a BaseFormOperator with u* a Cofunction and f a Coefficient, then:
822861 # B(f, u*) <=> Action(B(f, v*), f) where v* is a Coargument
0 commit comments