@@ -466,13 +466,11 @@ def base_form_assembly_visitor(self, expr, tensor, *args):
466466 petsc_mat .mult (v_vec , res_vec )
467467 return res
468468 elif isinstance (rhs , matrix .MatrixBase ):
469- petsc_mat = lhs .petscmat
470- (row , col ) = lhs .arguments ()
471- res = tensor .petscmat if tensor else PETSc .Mat ()
472- petsc_mat .matMult (rhs .petscmat , result = res )
473- return tensor if tensor else matrix .AssembledMatrix (expr , self ._bcs , res ,
474- appctx = self ._appctx ,
475- options_prefix = self ._options_prefix )
469+ result = tensor .petscmat if tensor else PETSc .Mat ()
470+ lhs .petsc_mat .matMult (rhs .petscmat , result = result )
471+ if tensor is None :
472+ tensor = self .assembled_matrix (expr , result )
473+ return tensor
476474 else :
477475 raise TypeError ("Incompatible RHS for Action." )
478476 elif isinstance (lhs , (firedrake .Cofunction , firedrake .Function )):
@@ -501,17 +499,18 @@ def base_form_assembly_visitor(self, expr, tensor, *args):
501499 result .dat .maxpy (expr .weights (), [a .dat for a in args ])
502500 return result
503501 elif all (isinstance (op , ufl .Matrix ) for op in args ):
504- res = tensor .petscmat if tensor else PETSc .Mat ()
502+ result = tensor .petscmat if tensor else PETSc .Mat ()
505503 for (op , w ) in zip (args , expr .weights ()):
506- if res :
507- res .axpy (w , op .petscmat )
504+ if result :
505+ # If result is not void, then accumulate on it
506+ result .axpy (w , op .petscmat )
508507 else :
509- # Make a copy to avoid in-place scaling
510- res = op .petscmat .copy ()
511- res .scale (w )
512- return tensor if tensor else matrix . AssembledMatrix ( expr , self . _bcs , res ,
513- appctx = self ._appctx ,
514- options_prefix = self . _options_prefix )
508+ # If result is void, then allocate it with first term
509+ op .petscmat .copy (result = result )
510+ result .scale (w )
511+ if tensor is None :
512+ tensor = self .assembled_matrix ( expr , result )
513+ return tensor
515514 else :
516515 raise TypeError ("Mismatching FormSum shapes" )
517516 elif isinstance (expr , ufl .ExternalOperator ):
@@ -585,9 +584,9 @@ def base_form_assembly_visitor(self, expr, tensor, *args):
585584 else :
586585 # Copy the interpolation matrix into the output tensor
587586 petsc_mat .copy (result = res )
588- return matrix . AssembledMatrix ( expr . arguments (), self . _bcs , res ,
589- appctx = self ._appctx ,
590- options_prefix = self . _options_prefix )
587+ if tensor is None :
588+ tensor = self .assembled_matrix ( expr , res )
589+ return tensor
591590 else :
592591 # The case rank == 0 is handled via the DAG restructuring
593592 raise ValueError ("Incompatible number of arguments." )
@@ -600,6 +599,10 @@ def base_form_assembly_visitor(self, expr, tensor, *args):
600599 else :
601600 raise TypeError (f"Unrecognised BaseForm instance: { expr } " )
602601
602+ def assembled_matrix (self , expr , petscmat ):
603+ return matrix .AssembledMatrix (expr .arguments (), self ._bcs , petscmat ,
604+ options_prefix = self ._options_prefix )
605+
603606 @staticmethod
604607 def base_form_postorder_traversal (expr , visitor , visited = {}):
605608 if expr in visited :
0 commit comments