@@ -549,8 +549,24 @@ def base_form_assembly_visitor(self, expr, tensor, *args):
549549 # 4) Interpolate(Argument(V1, 0), Cofunction(...)) -> Action of the Jacobian adjoint
550550 # This can be generalized to the case where the first slot is an arbitray expression.
551551 rank = len (expr .arguments ())
552- # If argument numbers have been swapped => Adjoint.
553552 arg_expression = ufl .algorithms .extract_arguments (expression )
553+
554+ # Handle interpolation of subfunctions
555+ parent_tensor = None
556+ if isinstance (expression , ufl .classes .Indexed ):
557+ assert rank == 1
558+ A , multiindex = expression .ufl_operands
559+ index , = map (int , multiindex )
560+ # TODO handle more general case with ufl.replace
561+ assert isinstance (A , firedrake .Argument )
562+ V = A .function_space ()
563+ # Symbolic indirection for the input expression
564+ expression = firedrake .Argument (V .sub (index ), number = A .number (), part = A .part ())
565+ # Symbolic indirection for the output tensor
566+ parent_tensor = tensor or firedrake .Cofunction (V .dual ())
567+ tensor = parent_tensor .sub (index )
568+
569+ # If argument numbers have been swapped => Adjoint.
554570 is_adjoint = (arg_expression and arg_expression [0 ].number () == 0 )
555571 # Workaround: Renumber argument when needed since Interpolator assumes it takes a zero-numbered argument.
556572 if not is_adjoint and rank != 1 :
@@ -559,20 +575,23 @@ def base_form_assembly_visitor(self, expr, tensor, *args):
559575 # Get the interpolator
560576 interp_data = expr .interp_data
561577 default_missing_val = interp_data .pop ('default_missing_val' , None )
578+
562579 interpolator = firedrake .Interpolator (expression , expr .function_space (), ** interp_data )
563580 # Assembly
564581 if rank == 1 :
565582 # Assembling the action of the Jacobian adjoint.
566583 if is_adjoint :
567584 output = tensor or firedrake .Cofunction (arg_expression [0 ].function_space ().dual ())
568- return interpolator ._interpolate (v , output = output , adjoint = True , default_missing_val = default_missing_val )
585+ result = interpolator ._interpolate (v , output = output , adjoint = True , default_missing_val = default_missing_val )
569586 # Assembling the Jacobian action.
570- if interpolator .nargs :
571- return interpolator ._interpolate (expression , output = tensor , default_missing_val = default_missing_val )
587+ elif interpolator .nargs :
588+ result = interpolator ._interpolate (expression , output = tensor , default_missing_val = default_missing_val )
572589 # Assembling the operator
573- if tensor is None :
574- return interpolator ._interpolate (default_missing_val = default_missing_val )
575- return firedrake .Interpolator (expression , tensor , ** interp_data )._interpolate (default_missing_val = default_missing_val )
590+ elif tensor is None :
591+ result = interpolator ._interpolate (default_missing_val = default_missing_val )
592+ else :
593+ result = firedrake .Interpolator (expression , tensor , ** interp_data )._interpolate (default_missing_val = default_missing_val )
594+ return parent_tensor or result
576595 elif rank == 2 :
577596 res = tensor .petscmat if tensor else PETSc .Mat ()
578597 # Get the interpolation matrix
0 commit comments