@@ -83,30 +83,27 @@ def __init__(self, params, level):
8383 self .__level = level
8484 self .parallelizable = False
8585
86- def setupGenerator (self , qd_type ):
86+ def getGenerator (self , qd_type ):
8787 coll = self .coll
88- try :
89- assert QDELTA_GENERATORS [qd_type ] == type (self .generator )
90- assert self .generator .QDelta .shape [0 ] == coll .Qmat .shape [0 ] - 1
91- except (AssertionError , AttributeError ):
92- self .generator = QDELTA_GENERATORS [qd_type ](
93- # for algebraic types (LU, ...)
94- Q = coll .generator .Q ,
95- # for MIN in tables, MIN-SR-S ...
96- nNodes = coll .num_nodes ,
97- nodeType = coll .node_type ,
98- quadType = coll .quad_type ,
99- # for time-stepping types, MIN-SR-NS
100- nodes = coll .nodes ,
101- tLeft = coll .tleft ,
102- )
103- except Exception as e :
104- raise ValueError (f"could not generate { qd_type = !r} with qmat, got error : { e } " ) from e
88+
89+ generator = QDELTA_GENERATORS [qd_type ](
90+ # for algebraic types (LU, ...)
91+ Q = coll .generator .Q ,
92+ # for MIN in tables, MIN-SR-S ...
93+ nNodes = coll .num_nodes ,
94+ nodeType = coll .node_type ,
95+ quadType = coll .quad_type ,
96+ # for time-stepping types, MIN-SR-NS
97+ nodes = coll .nodes ,
98+ tLeft = coll .tleft ,
99+ )
100+
101+ return generator
105102
106103 def get_Qdelta_implicit (self , qd_type , k = None ):
107104 QDmat = np .zeros_like (self .coll .Qmat )
108- self .setupGenerator (qd_type )
109- QDmat [1 :, 1 :] = self .generator .genCoeffs (k = k )
105+ self .genQI = self . getGenerator (qd_type )
106+ QDmat [1 :, 1 :] = self .genQI .genCoeffs (k = k )
110107
111108 err_msg = 'Lower triangular matrix expected!'
112109 np .testing .assert_array_equal (np .triu (QDmat , k = 1 ), np .zeros (QDmat .shape ), err_msg = err_msg )
@@ -117,8 +114,8 @@ def get_Qdelta_implicit(self, qd_type, k=None):
117114 def get_Qdelta_explicit (self , qd_type , k = None ):
118115 coll = self .coll
119116 QDmat = np .zeros (coll .Qmat .shape , dtype = float )
120- self .setupGenerator (qd_type )
121- QDmat [1 :, 1 :], QDmat [1 :, 0 ] = self .generator .genCoeffs (k = k , dTau = True )
117+ self .genQE = self . getGenerator (qd_type )
118+ QDmat [1 :, 1 :], QDmat [1 :, 0 ] = self .genQE .genCoeffs (k = k , dTau = True )
122119
123120 err_msg = 'Strictly lower triangular matrix expected!'
124121 np .testing .assert_array_equal (np .triu (QDmat , k = 0 ), np .zeros (QDmat .shape ), err_msg = err_msg )
@@ -270,5 +267,5 @@ def updateVariableCoeffs(self, k):
270267 k : int
271268 Index of the sweep (0 for initial sweep, 1 for the first one, ...).
272269 """
273- if self .params . QI == 'MIN-SR-FLEX' :
274- self .QI = self .get_Qdelta_implicit (qd_type = "MIN-SR-FLEX" , k = k )
270+ if self .genQI . _K_DEPENDENT :
271+ self .QI = self .get_Qdelta_implicit (qd_type = self . genQI . __class__ . __name__ , k = k )
0 commit comments