Skip to content

Commit 1b20212

Browse files
committed
add tests
1 parent c55b648 commit 1b20212

File tree

1 file changed

+62
-0
lines changed

1 file changed

+62
-0
lines changed

tests/firedrake/regression/test_restricted_function_space.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -362,3 +362,65 @@ def test_restricted_function_space_extrusion_stokes(ncells):
362362
# -- Actually, the ordering is the same.
363363
assert np.allclose(sol_res.subfunctions[0].dat.data_ro_with_halos, sol.subfunctions[0].dat.data_ro_with_halos)
364364
assert np.allclose(sol_res.subfunctions[1].dat.data_ro_with_halos, sol.subfunctions[1].dat.data_ro_with_halos)
365+
366+
367+
@pytest.mark.parametrize("names", [(None, None), (None, "name1"), ("name0", "name1")])
368+
def test_restrict_fieldsplit(names):
369+
mesh = UnitSquareMesh(2, 2)
370+
V = FunctionSpace(mesh, "CG", 1, name=names[0])
371+
Q = FunctionSpace(mesh, "CG", 2, name=names[1])
372+
Z = V * Q
373+
374+
z = Function(Z)
375+
test = TestFunction(Z)
376+
z_exact = Constant([1, -1])
377+
378+
F = inner(z - z_exact, test) * dx
379+
bcs = [DirichletBC(Z.sub(i), z_exact[i], (i+1, i+3)) for i in range(len(Z))]
380+
381+
problem = NonlinearVariationalProblem(F, z, bcs=bcs, restrict=True)
382+
solver = NonlinearVariationalSolver(problem, solver_parameters={
383+
"snes_type": "ksponly",
384+
"ksp_type": "preonly",
385+
"pc_type": "fieldsplit",
386+
"pc_fieldsplit_type": "additive",
387+
f"fieldsplit_{names[0] or 0}_pc_type": "lu",
388+
f"fieldsplit_{names[1] or 1}_pc_type": "lu"},
389+
options_prefix="")
390+
solver.solve()
391+
392+
# Test prefixes for the restricted spaces
393+
pc = solver.snes.ksp.pc
394+
for field, ksp in enumerate(pc.getFieldSplitSubKSP()):
395+
name = Z[field].name or field
396+
assert ksp.getOptionsPrefix() == f"fieldsplit_{name}_"
397+
398+
assert errornorm(z_exact[0], z.subfunctions[0]) < 1E-10
399+
assert errornorm(z_exact[1], z.subfunctions[1]) < 1E-10
400+
401+
402+
def test_restrict_python_pc():
403+
mesh = UnitSquareMesh(2, 2)
404+
x, y = SpatialCoordinate(mesh)
405+
V = FunctionSpace(mesh, "CG", 1)
406+
407+
u = Function(V)
408+
test = TestFunction(V)
409+
u_exact = x + y
410+
g = Function(V).interpolate(u_exact)
411+
412+
F = inner(u - u_exact, test) * dx
413+
bcs = [DirichletBC(V, g, 1), DirichletBC(V, u_exact, 2)]
414+
415+
problem = NonlinearVariationalProblem(F, u, bcs=bcs, restrict=True)
416+
solver = NonlinearVariationalSolver(problem, solver_parameters={
417+
"snes_type": "ksponly",
418+
"mat_type": "matfree",
419+
"ksp_type": "preonly",
420+
"pc_type": "python",
421+
"pc_python_type": "firedrake.AssembledPC",
422+
"assembled_pc_type": "lu"},
423+
options_prefix="")
424+
solver.solve()
425+
426+
assert errornorm(u_exact, u) < 1E-10

0 commit comments

Comments
 (0)