@@ -362,3 +362,63 @@ def test_restricted_function_space_extrusion_stokes(ncells):
362362 # -- Actually, the ordering is the same.
363363 assert np .allclose (sol_res .subfunctions [0 ].dat .data_ro_with_halos , sol .subfunctions [0 ].dat .data_ro_with_halos )
364364 assert np .allclose (sol_res .subfunctions [1 ].dat .data_ro_with_halos , sol .subfunctions [1 ].dat .data_ro_with_halos )
365+
366+
367+ @pytest .mark .parametrize ("names" , [(None , None ), (None , "name1" ), ("name0" , "name1" )])
368+ def test_restrict_fieldsplit (names ):
369+ mesh = UnitSquareMesh (2 , 2 )
370+ V = FunctionSpace (mesh , "CG" , 1 , name = names [0 ])
371+ Q = FunctionSpace (mesh , "CG" , 2 , name = names [1 ])
372+ Z = V * Q
373+
374+ z = Function (Z )
375+ test = TestFunction (Z )
376+ z_exact = Constant ([1 , - 1 ])
377+
378+ F = inner (z - z_exact , test ) * dx
379+ bcs = [DirichletBC (Z .sub (i ), z_exact [i ], (i + 1 , i + 3 )) for i in range (len (Z ))]
380+
381+ problem = NonlinearVariationalProblem (F , z , bcs = bcs , restrict = True )
382+ solver = NonlinearVariationalSolver (problem , solver_parameters = {
383+ "ksp_type" : "preonly" ,
384+ "pc_type" : "fieldsplit" ,
385+ "pc_fieldsplit_type" : "additive" ,
386+ f"fieldsplit_{ names [0 ] or 0 } _pc_type" : "lu" ,
387+ f"fieldsplit_{ names [1 ] or 1 } _pc_type" : "lu" },
388+ options_prefix = "" )
389+ solver .solve ()
390+
391+ # Test prefixes for the restricted spaces
392+ pc = solver .snes .ksp .pc
393+ for field , ksp in enumerate (pc .getFieldSplitSubKSP ()):
394+ name = Z [field ].name or field
395+ assert ksp .getOptionsPrefix () == f"fieldsplit_{ name } _"
396+
397+ assert errornorm (z_exact [0 ], z .subfunctions [0 ]) < 1E-10
398+ assert errornorm (z_exact [1 ], z .subfunctions [1 ]) < 1E-10
399+
400+
401+ def test_restrict_python_pc ():
402+ mesh = UnitSquareMesh (2 , 2 )
403+ x , y = SpatialCoordinate (mesh )
404+ V = FunctionSpace (mesh , "CG" , 1 )
405+
406+ u = Function (V )
407+ test = TestFunction (V )
408+ u_exact = x + y
409+ g = Function (V ).interpolate (u_exact )
410+
411+ F = inner (u - u_exact , test ) * dx
412+ bcs = [DirichletBC (V , g , 1 ), DirichletBC (V , u_exact , 2 )]
413+
414+ problem = NonlinearVariationalProblem (F , u , bcs = bcs , restrict = True )
415+ solver = NonlinearVariationalSolver (problem , solver_parameters = {
416+ "mat_type" : "matfree" ,
417+ "ksp_type" : "preonly" ,
418+ "pc_type" : "python" ,
419+ "pc_python_type" : "firedrake.AssembledPC" ,
420+ "assembled_pc_type" : "lu" },
421+ options_prefix = "" )
422+ solver .solve ()
423+
424+ assert errornorm (u_exact , u ) < 1E-10
0 commit comments