Skip to content

Commit b815048

Browse files
committed
address review comments
1 parent 7ea3dfc commit b815048

File tree

2 files changed

+32
-25
lines changed

2 files changed

+32
-25
lines changed

firedrake/assemble.py

Lines changed: 22 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -466,13 +466,11 @@ def base_form_assembly_visitor(self, expr, tensor, *args):
466466
petsc_mat.mult(v_vec, res_vec)
467467
return res
468468
elif isinstance(rhs, matrix.MatrixBase):
469-
petsc_mat = lhs.petscmat
470-
(row, col) = lhs.arguments()
471-
res = tensor.petscmat if tensor else PETSc.Mat()
472-
petsc_mat.matMult(rhs.petscmat, result=res)
473-
return tensor if tensor else matrix.AssembledMatrix(expr, self._bcs, res,
474-
appctx=self._appctx,
475-
options_prefix=self._options_prefix)
469+
result = tensor.petscmat if tensor else PETSc.Mat()
470+
lhs.petsc_mat.matMult(rhs.petscmat, result=result)
471+
if tensor is None:
472+
tensor = self.assembled_matrix(expr, result)
473+
return tensor
476474
else:
477475
raise TypeError("Incompatible RHS for Action.")
478476
elif isinstance(lhs, (firedrake.Cofunction, firedrake.Function)):
@@ -501,17 +499,18 @@ def base_form_assembly_visitor(self, expr, tensor, *args):
501499
result.dat.maxpy(expr.weights(), [a.dat for a in args])
502500
return result
503501
elif all(isinstance(op, ufl.Matrix) for op in args):
504-
res = tensor.petscmat if tensor else PETSc.Mat()
502+
result = tensor.petscmat if tensor else PETSc.Mat()
505503
for (op, w) in zip(args, expr.weights()):
506-
if res:
507-
res.axpy(w, op.petscmat)
504+
if result:
505+
# If result is not void, then accumulate on it
506+
result.axpy(w, op.petscmat)
508507
else:
509-
# Make a copy to avoid in-place scaling
510-
res = op.petscmat.copy()
511-
res.scale(w)
512-
return tensor if tensor else matrix.AssembledMatrix(expr, self._bcs, res,
513-
appctx=self._appctx,
514-
options_prefix=self._options_prefix)
508+
# If result is void, then allocate it with first term
509+
op.petscmat.copy(result=result)
510+
result.scale(w)
511+
if tensor is None:
512+
tensor = self.assembled_matrix(expr, result)
513+
return tensor
515514
else:
516515
raise TypeError("Mismatching FormSum shapes")
517516
elif isinstance(expr, ufl.ExternalOperator):
@@ -585,9 +584,9 @@ def base_form_assembly_visitor(self, expr, tensor, *args):
585584
else:
586585
# Copy the interpolation matrix into the output tensor
587586
petsc_mat.copy(result=res)
588-
return matrix.AssembledMatrix(expr.arguments(), self._bcs, res,
589-
appctx=self._appctx,
590-
options_prefix=self._options_prefix)
587+
if tensor is None:
588+
tensor = self.assembled_matrix(expr, res)
589+
return tensor
591590
else:
592591
# The case rank == 0 is handled via the DAG restructuring
593592
raise ValueError("Incompatible number of arguments.")
@@ -600,6 +599,10 @@ def base_form_assembly_visitor(self, expr, tensor, *args):
600599
else:
601600
raise TypeError(f"Unrecognised BaseForm instance: {expr}")
602601

602+
def assembled_matrix(self, expr, petscmat):
603+
return matrix.AssembledMatrix(expr.arguments(), self._bcs, petscmat,
604+
options_prefix=self._options_prefix)
605+
603606
@staticmethod
604607
def base_form_postorder_traversal(expr, visitor, visited={}):
605608
if expr in visited:

firedrake/matrix.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,8 @@ def __init__(self, a, bcs, mat_type):
4242
self._analyze_form_arguments()
4343
self._arguments = arguments
4444

45+
if bcs is None:
46+
bcs = ()
4547
self.bcs = bcs
4648
self.comm = test.function_space().comm
4749
self._comm = internal_comm(self.comm, self)
@@ -97,6 +99,12 @@ def __str__(self):
9799
return "assembled %s(a=%s, bcs=%s)" % (type(self).__name__,
98100
self.a, self.bcs)
99101

102+
def __add__(self, other):
103+
if isinstance(other, MatrixBase):
104+
return self.petscmat + other.petscmat
105+
else:
106+
return NotImplemented
107+
100108
def assign(self, val):
101109
"""Set matrix entries."""
102110
if isinstance(val, MatrixBase):
@@ -212,15 +220,11 @@ def __init__(self, a, bcs, petscmat, *args, **kwargs):
212220
super(AssembledMatrix, self).__init__(a, bcs, "assembled")
213221

214222
self.petscmat = petscmat
223+
options_prefix = kwargs.pop("options_prefix")
224+
self.petscmat.setOptionsPrefix(options_prefix)
215225

216226
# this allows call to self.M.handle without a new class
217227
self.M = SimpleNamespace(handle=self.mat())
218228

219229
def mat(self):
220230
return self.petscmat
221-
222-
def __add__(self, other):
223-
if isinstance(other, MatrixBase):
224-
return self.petscmat + other.petscmat
225-
else:
226-
return NotImplemented

0 commit comments

Comments
 (0)