@@ -73,6 +73,60 @@ def test_MIN_SR_FLEX(node_type, quad_type, M):
7373 ), f"Applying FLEX preconditioner does not give nilpotent SDC iteration matrix after { M } iterations! (M={ M } , norm={ nilpotency } )"
7474
7575
76+ @pytest .mark .base
77+ @pytest .mark .parametrize ('imex' , [True , False ])
78+ @pytest .mark .parametrize ('num_nodes' , num_nodes )
79+ def test_FLEX_preconditioner_in_sweepers (imex , num_nodes , MPI = False ):
80+ from pySDC .core .level import Level
81+
82+ if imex :
83+ from pySDC .implementations .problem_classes .TestEquation_0D import test_equation_IMEX as problem_class
84+
85+ if MPI :
86+ from pySDC .implementations .sweeper_classes .imex_1st_order_MPI import imex_1st_order_MPI as sweeper_class
87+ else :
88+ from pySDC .implementations .sweeper_classes .imex_1st_order import imex_1st_order as sweeper_class
89+ else :
90+ from pySDC .implementations .problem_classes .TestEquation_0D import testequation0d as problem_class
91+
92+ if MPI :
93+ from pySDC .implementations .sweeper_classes .generic_implicit_MPI import generic_implicit_MPI as sweeper_class
94+ else :
95+ from pySDC .implementations .sweeper_classes .generic_implicit import generic_implicit as sweeper_class
96+
97+ sweeper_params = {'quad_type' : 'RADAU-RIGHT' , 'num_nodes' : num_nodes , 'QI' : 'MIN-SR-FLEX' , 'QE' : 'PIC' }
98+ if MPI :
99+ from mpi4py import MPI
100+
101+ sweeper_params ['comm' ] = MPI .COMM_WORLD
102+ level_params = {'nsweeps' : num_nodes , 'dt' : 1 }
103+
104+ lvl = Level (problem_class , {}, sweeper_class , sweeper_params , level_params , 0 )
105+
106+ lvl .status .unlocked = True
107+ lvl .u [0 ] = lvl .prob .u_exact (0 )
108+ lvl .status .time = 0
109+
110+ sweep = lvl .sweep
111+ sweep .predict ()
112+
113+ for k in range (1 , level_params ['nsweeps' ] + 1 ):
114+ lvl .status .sweep = k
115+ sweep .update_nodes ()
116+ assert np .allclose (
117+ sweep .QI , sweep .get_Qdelta_implicit (sweeper_params ['QI' ], k )
118+ ), f'Got incorrect FLEX preconditioner in sweep { k } '
119+
120+
121+ @pytest .mark .mpi4py
122+ @pytest .mark .parametrize ('imex' , [True , False ])
123+ @pytest .mark .mpi (ranks = [3 ])
124+ def test_FLEX_preconditioner_in_MPI_sweepers (mpi_ranks , imex ):
125+ from mpi4py import MPI
126+
127+ test_FLEX_preconditioner_in_sweepers (imex , num_nodes = MPI .COMM_WORLD .size , MPI = True )
128+
129+
76130@pytest .mark .base
77131@pytest .mark .parametrize ("node_type" , node_types )
78132@pytest .mark .parametrize ("quad_type" , quad_types )
@@ -161,3 +215,5 @@ def test_PIC(node_type, quad_type, M):
161215
162216 test_LU ('LEGENDRE' , 'RADAU-RIGHT' , 4 )
163217 test_LU ('EQUID' , 'LOBATTO' , 5 )
218+
219+ test_FLEX_preconditioner_in_sweepers (True )
0 commit comments