Skip to content

Commit d1411cb

Browse files
authored
ImplicitMatrixContext: create submatrix via MatCreateSubMatrixVirtual (#4693)
* ImplicitMatrixContext: create submatrix via MatCreateSubMatrixVirtual
1 parent 395b8ea commit d1411cb

File tree

2 files changed

+56
-5
lines changed

2 files changed

+56
-5
lines changed

firedrake/matrix_free/operators.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -377,11 +377,20 @@ def createSubMatrix(self, mat, row_is, col_is, target=None):
377377
row_ises = self._y.function_space().dof_dset.field_ises
378378
col_ises = self._x.function_space().dof_dset.field_ises
379379

380-
row_inds = find_sub_block(row_is, row_ises, comm=self.comm)
381-
if row_is == col_is and row_ises == col_ises:
382-
col_inds = row_inds
383-
else:
384-
col_inds = find_sub_block(col_is, col_ises, comm=self.comm)
380+
try:
381+
row_inds = find_sub_block(row_is, row_ises, comm=self.comm)
382+
if row_is == col_is and row_ises == col_ises:
383+
col_inds = row_inds
384+
else:
385+
col_inds = find_sub_block(col_is, col_ises, comm=self.comm)
386+
except LookupError:
387+
# Attemping to extract a submatrix that does not match with a subfield.
388+
# Use default PETSc implementation (MatCreateSubMatrixVirtual) via MATSHELL instead.
389+
popmethod = self.createSubMatrix
390+
self.createSubMatrix = None
391+
submat = mat.createSubMatrix(row_is, col_is)
392+
self.createSubMatrix = popmethod
393+
return submat
385394

386395
splitter = ExtractSubBlock()
387396
asub = splitter.split(self.a,

tests/firedrake/regression/test_matrix_free.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -366,3 +366,45 @@ def test_matrix_free_fieldsplit_with_real():
366366
}}
367367
stokes_solver = LinearVariationalSolver(stokes_problem, solver_parameters=opts)
368368
stokes_solver.solve()
369+
370+
371+
@pytest.mark.parametrize("shape", ["scalar", "mixed"])
372+
def test_sub_matrix_not_subfield(shape):
373+
mesh = UnitSquareMesh(2, 2)
374+
if shape == "mixed":
375+
V = VectorFunctionSpace(mesh, "CG", 2)
376+
Q = FunctionSpace(mesh, "CG", 1)
377+
Z = V * Q
378+
u, p = TrialFunctions(Z)
379+
v, q = TestFunctions(Z)
380+
a = inner(grad(u), grad(v)) * dx - inner(p, div(v))*dx - inner(div(u), q)*dx
381+
bcs = DirichletBC(Z.sub(0), 0, (1, 3))
382+
383+
elif shape == "scalar":
384+
V = FunctionSpace(mesh, "CG", 1)
385+
u = TrialFunction(V)
386+
v = TestFunction(V)
387+
a = inner(grad(u), grad(v)) * dx
388+
bcs = DirichletBC(V, 0, (1, 3))
389+
390+
args = a.arguments()
391+
rows = PETSc.IS().createGeneral(range(0, args[0].function_space().dim(), 2))
392+
cols = PETSc.IS().createGeneral(range(1, args[1].function_space().dim(), 2))
393+
394+
A = assemble(a, bcs=bcs, mat_type="matfree")
395+
Amat = A.petscmat
396+
Asub = Amat.createSubMatrix(rows, cols)
397+
x, y = Asub.createVecs()
398+
399+
m, n = Asub.getSize()
400+
Asub_dense = np.zeros((m, n))
401+
for i in range(n):
402+
x.set(0.0)
403+
x[i] = 1.0
404+
Asub.mult(x, y)
405+
Asub_dense[:, i] = y[:]
406+
407+
A = assemble(a, bcs=bcs, mat_type="aij")
408+
Amat = A.petscmat
409+
Asub_aij = Amat.createSubMatrix(rows, cols)
410+
assert np.allclose(Asub_aij[:, :], Asub_dense)

0 commit comments

Comments
 (0)