@@ -19,15 +19,15 @@ def test_MIN_SR(node_type, quad_type, M):
1919
2020 # Check non-stiff limit
2121 QDelta = sweeper .get_Qdelta_implicit ('MIN-SR-NS' )[1 :, 1 :]
22- assert np .all (np .diag (np .diag (QDelta )) == QDelta ), "no diagonal QDelta "
22+ assert np .all (np .diag (np .diag (QDelta )) == QDelta ), "QDelta not diagonal "
2323 K = Q - QDelta
2424 Km = np .linalg .matrix_power (K , M )
2525 nilpotency = np .linalg .norm (Km , ord = np .inf )
2626 assert nilpotency < 1e-10 , "Q-QDelta not nilpotent " f"(M={ M } , norm={ nilpotency } )"
2727
2828 # Check stiff limit
2929 QDelta = sweeper .get_Qdelta_implicit ('MIN-SR-S' )[1 :, 1 :]
30- assert np .all (np .diag (np .diag (QDelta )) == QDelta ), "no diagonal QDelta "
30+ assert np .all (np .diag (np .diag (QDelta )) == QDelta ), "QDelta not diagonal "
3131
3232 if params ['quad_type' ] in ['LOBATTO' , 'RADAU-LEFT' ]:
3333 QDelta = np .diag (1 / np .diag (QDelta [1 :, 1 :]))
@@ -41,6 +41,92 @@ def test_MIN_SR(node_type, quad_type, M):
4141 assert nilpotency < 1e-10 , "I-QDelta^{-1}Q not nilpotent " f"(M={ M } , norm={ nilpotency } )"
4242
4343
44+ @pytest .mark .base
45+ @pytest .mark .parametrize ("node_type" , node_types )
46+ @pytest .mark .parametrize ("quad_type" , quad_types )
47+ @pytest .mark .parametrize ("M" , num_nodes )
48+ def test_MIN_SR_FLEX (node_type , quad_type , M ):
49+ params = {'num_nodes' : M , 'quad_type' : quad_type , 'node_type' : node_type }
50+ sweeper = Sweeper (params )
51+
52+ start_idx = 1
53+ for i in range (M ):
54+ if sweeper .coll .nodes [i ] == 0 :
55+ start_idx += 1
56+ else :
57+ break
58+
59+ Q = sweeper .coll .Qmat [start_idx :, start_idx :]
60+
61+ QDelta = [sweeper .get_Qdelta_implicit ('MIN-SR-FLEX' , k = i + 1 )[start_idx :, start_idx :] for i in range (M )]
62+ for QD in QDelta :
63+ assert np .all (np .diag (np .diag (QD )) == QD ), "QDelta not diagonal"
64+
65+ I = np .eye (M + 1 - start_idx )
66+ K = np .eye (M + 1 - start_idx )
67+ for QD in QDelta :
68+ K = (I - np .linalg .inv (QD ) @ Q ) @ K
69+
70+ nilpotency = np .linalg .norm (K , ord = np .inf )
71+ assert (
72+ nilpotency < 1e-10
73+ ), f"Applying FLEX preconditioner does not give nilpotent SDC iteration matrix after { M } iterations! (M={ M } , norm={ nilpotency } )"
74+
75+
76+ @pytest .mark .base
77+ @pytest .mark .parametrize ('imex' , [True , False ])
78+ @pytest .mark .parametrize ('num_nodes' , num_nodes )
79+ def test_FLEX_preconditioner_in_sweepers (imex , num_nodes , MPI = False ):
80+ from pySDC .core .level import Level
81+
82+ if imex :
83+ from pySDC .implementations .problem_classes .TestEquation_0D import test_equation_IMEX as problem_class
84+
85+ if MPI :
86+ from pySDC .implementations .sweeper_classes .imex_1st_order_MPI import imex_1st_order_MPI as sweeper_class
87+ else :
88+ from pySDC .implementations .sweeper_classes .imex_1st_order import imex_1st_order as sweeper_class
89+ else :
90+ from pySDC .implementations .problem_classes .TestEquation_0D import testequation0d as problem_class
91+
92+ if MPI :
93+ from pySDC .implementations .sweeper_classes .generic_implicit_MPI import generic_implicit_MPI as sweeper_class
94+ else :
95+ from pySDC .implementations .sweeper_classes .generic_implicit import generic_implicit as sweeper_class
96+
97+ sweeper_params = {'quad_type' : 'RADAU-RIGHT' , 'num_nodes' : num_nodes , 'QI' : 'MIN-SR-FLEX' , 'QE' : 'PIC' }
98+ if MPI :
99+ from mpi4py import MPI
100+
101+ sweeper_params ['comm' ] = MPI .COMM_WORLD
102+ level_params = {'nsweeps' : num_nodes , 'dt' : 1 }
103+
104+ lvl = Level (problem_class , {}, sweeper_class , sweeper_params , level_params , 0 )
105+
106+ lvl .status .unlocked = True
107+ lvl .u [0 ] = lvl .prob .u_exact (0 )
108+ lvl .status .time = 0
109+
110+ sweep = lvl .sweep
111+ sweep .predict ()
112+
113+ for k in range (1 , level_params ['nsweeps' ] + 1 ):
114+ lvl .status .sweep = k
115+ sweep .update_nodes ()
116+ assert np .allclose (
117+ sweep .QI , sweep .get_Qdelta_implicit (sweeper_params ['QI' ], k )
118+ ), f'Got incorrect FLEX preconditioner in sweep { k } '
119+
120+
121+ @pytest .mark .mpi4py
122+ @pytest .mark .parametrize ('imex' , [True , False ])
123+ @pytest .mark .mpi (ranks = [3 ])
124+ def test_FLEX_preconditioner_in_MPI_sweepers (mpi_ranks , imex ):
125+ from mpi4py import MPI
126+
127+ test_FLEX_preconditioner_in_sweepers (imex , num_nodes = MPI .COMM_WORLD .size , MPI = True )
128+
129+
44130@pytest .mark .base
45131@pytest .mark .parametrize ("node_type" , node_types )
46132@pytest .mark .parametrize ("quad_type" , quad_types )
@@ -122,8 +208,12 @@ def test_PIC(node_type, quad_type, M):
122208
123209
124210if __name__ == '__main__' :
211+ test_MIN_SR_FLEX ('LEGENDRE' , 'LOBATTO' , 4 )
212+
125213 test_MIN_SR ('LEGENDRE' , 'RADAU-RIGHT' , 4 )
126214 test_MIN_SR ('EQUID' , 'LOBATTO' , 5 )
127215
128216 test_LU ('LEGENDRE' , 'RADAU-RIGHT' , 4 )
129217 test_LU ('EQUID' , 'LOBATTO' , 5 )
218+
219+ test_FLEX_preconditioner_in_sweepers (True )
0 commit comments