Skip to content

Commit da194e3

Browse files
committed
MatIS: support DirichletBC, add a test
1 parent 4d45c97 commit da194e3

File tree

3 files changed

+30
-5
lines changed

3 files changed

+30
-5
lines changed

pyop2/parloop.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -296,7 +296,8 @@ def replace_lgmaps(self):
296296
olgmaps = []
297297
for m, lgmaps in zip(pl_arg.data, pl_arg.lgmaps):
298298
olgmaps.append(m.handle.getLGMap())
299-
m.handle.setLGMap(*lgmaps)
299+
if m.handle.type != "is":
300+
m.handle.setLGMap(*lgmaps)
300301
orig_lgmaps.append(olgmaps)
301302
return tuple(orig_lgmaps)
302303

@@ -309,7 +310,8 @@ def restore_lgmaps(self, orig_lgmaps):
309310
for arg, d in reversed(list(zip(self.global_kernel.arguments, self.arguments))):
310311
if isinstance(arg, (MatKernelArg, MixedMatKernelArg)) and d.lgmaps is not None:
311312
for m, lgmaps in zip(d.data, orig_lgmaps.pop()):
312-
m.handle.setLGMap(*lgmaps)
313+
if m.handle.type != "is":
314+
m.handle.setLGMap(*lgmaps)
313315

314316
@cached_property
315317
def _has_mats(self):

pyop2/types/mat.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -825,9 +825,13 @@ def set_local_diagonal_entries(self, rows, diag_val=1.0, idx=None):
825825
rows = rows.reshape(-1, 1)
826826
self.change_assembly_state(Mat.INSERT_VALUES)
827827
if len(rows) > 0:
828-
values = np.full(rows.shape, diag_val, dtype=dtypes.ScalarType)
829-
self.handle.setValuesLocalRCV(rows, rows, values,
830-
addv=PETSc.InsertMode.INSERT_VALUES)
828+
if self.handle.type == "is":
829+
self.handle.assemble()
830+
self.handle.zeroRowsColumnsLocal(rows, diag_val)
831+
else:
832+
values = np.full(rows.shape, diag_val, dtype=dtypes.ScalarType)
833+
self.handle.setValuesLocalRCV(rows, rows, values,
834+
addv=PETSc.InsertMode.INSERT_VALUES)
831835

832836
@mpi.collective
833837
def assemble(self):

tests/firedrake/regression/test_assemble.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,25 @@ def test_mat_nest_real_block_assembler_correctly_reuses_tensor(mesh):
142142

143143
assert A2.M is A1.M
144144

145+
@pytest.mark.parametrize("dirichlet_bcs", [False, True])
146+
def test_assemble_matis(mesh, dirichlet_bcs):
147+
V = FunctionSpace(mesh, "CG", 1)
148+
u = TrialFunction(V)
149+
v = TestFunction(V)
150+
a = inner(grad(u), grad(v))*dx
151+
if dirichlet_bcs:
152+
bcs = DirichletBC(V, 0, (1, 3))
153+
else:
154+
bcs = None
155+
156+
ais = assemble(a, bcs=bcs, mat_type="is").petscmat
157+
aijnew = PETSc.Mat()
158+
ais.convert("aij", aijnew)
159+
160+
aij = assemble(a, bcs=bcs, mat_type="aij").petscmat
161+
aij.axpy(-1, aijnew)
162+
aij.view()
163+
assert np.allclose(aij[:, :], 0)
145164

146165
def test_assemble_diagonal(mesh):
147166
V = FunctionSpace(mesh, "P", 3)

0 commit comments

Comments
 (0)