11import logging
22import numpy as np
3- from qmat import QDELTA_GENERATORS
3+ from qmat . qdelta import QDeltaGenerator , QDELTA_GENERATORS
44
55from pySDC .core .errors import ParameterError
66from pySDC .core .level import Level
77from pySDC .core .collocation import CollBase
88from pySDC .helpers .pysdc_helper import FrozenClass
99
1010
11+ # Organize QDeltaGenerator class in dict[type(QDeltaGenerator),set(str)] to retrieve aliases
12+ QDELTA_GENERATORS_ALIASES = {v : [] for v in set (QDELTA_GENERATORS .values ())}
13+ for k , v in QDELTA_GENERATORS .items ():
14+ QDELTA_GENERATORS_ALIASES [v ].add (k )
15+
16+
1117# short helper class to add params as attributes
1218class _Pars (FrozenClass ):
1319 def __init__ (self , pars ):
@@ -83,26 +89,13 @@ def __init__(self, params, level):
8389 self .__level = level
8490 self .parallelizable = False
8591
86- def getGenerator (self , qd_type ):
87- coll = self .coll
88-
89- generator = QDELTA_GENERATORS [qd_type ](
90- # for algebraic types (LU, ...)
91- Q = coll .generator .Q ,
92- # for MIN in tables, MIN-SR-S ...
93- nNodes = coll .num_nodes ,
94- nodeType = coll .node_type ,
95- quadType = coll .quad_type ,
96- # for time-stepping types, MIN-SR-NS
97- nodes = coll .nodes ,
98- tLeft = coll .tleft ,
99- )
100-
101- return generator
92+ def buildGenerator (self , qdType : str ) -> QDeltaGenerator :
93+ return QDELTA_GENERATORS [qdType ](qGen = self .coll .generator , tLeft = self .coll .tleft )
10294
10395 def get_Qdelta_implicit (self , qd_type , k = None ):
10496 QDmat = np .zeros_like (self .coll .Qmat )
105- self .genQI = self .getGenerator (qd_type )
97+ if not hasattr (self , "genQI" ) or qd_type not in QDELTA_GENERATORS_ALIASES [type (self .genQI )]:
98+ self .genQI : QDeltaGenerator = self .buildGenerator (qd_type )
10699 QDmat [1 :, 1 :] = self .genQI .genCoeffs (k = k )
107100
108101 err_msg = 'Lower triangular matrix expected!'
@@ -114,7 +107,8 @@ def get_Qdelta_implicit(self, qd_type, k=None):
114107 def get_Qdelta_explicit (self , qd_type , k = None ):
115108 coll = self .coll
116109 QDmat = np .zeros (coll .Qmat .shape , dtype = float )
117- self .genQE = self .getGenerator (qd_type )
110+ if not hasattr (self , "genQE" ) or qd_type not in QDELTA_GENERATORS_ALIASES [type (self .genQE )]:
111+ self .genQE : QDeltaGenerator = self .buildGenerator (qd_type )
118112 QDmat [1 :, 1 :], QDmat [1 :, 0 ] = self .genQE .genCoeffs (k = k , dTau = True )
119113
120114 err_msg = 'Strictly lower triangular matrix expected!'
@@ -267,5 +261,9 @@ def updateVariableCoeffs(self, k):
267261 k : int
268262 Index of the sweep (0 for initial sweep, 1 for the first one, ...).
269263 """
270- if self .genQI ._K_DEPENDENT :
271- self .QI = self .get_Qdelta_implicit (qd_type = self .genQI .__class__ .__name__ , k = k )
264+ if self .genQI .isKDependent ():
265+ qdType = type (self .genQI ).__name__
266+ self .QI = self .get_Qdelta_implicit (qdType , k = k )
267+ if self .genQE .isKDependent ():
268+ qdType = type (self .genQE ).__name__
269+ self .QE = self .get_Qdelta_explicit (qdType , k = k )
0 commit comments