11import logging
22import numpy as np
3- from qmat import QDELTA_GENERATORS
3+ from qmat . qdelta import QDeltaGenerator , QDELTA_GENERATORS
44
55from pySDC .core .errors import ParameterError
6- from pySDC .core .level import Level
76from pySDC .core .collocation import CollBase
87from pySDC .helpers .pysdc_helper import FrozenClass
98
109
10+ # Organize QDeltaGenerator class in dict[type(QDeltaGenerator),set(str)] to retrieve aliases
11+ QDELTA_GENERATORS_ALIASES = {v : set () for v in set (QDELTA_GENERATORS .values ())}
12+ for k , v in QDELTA_GENERATORS .items ():
13+ QDELTA_GENERATORS_ALIASES [v ].add (k )
14+
15+
1116# short helper class to add params as attributes
1217class _Pars (FrozenClass ):
1318 def __init__ (self , pars ):
@@ -82,27 +87,17 @@ def __init__(self, params, level):
8287
8388 self .__level = level
8489 self .parallelizable = False
90+ for name in ["genQI" , "genQE" ]:
91+ if hasattr (self , name ):
92+ delattr (self , name )
8593
86- def getGenerator (self , qd_type ):
87- coll = self .coll
88-
89- generator = QDELTA_GENERATORS [qd_type ](
90- # for algebraic types (LU, ...)
91- Q = coll .generator .Q ,
92- # for MIN in tables, MIN-SR-S ...
93- nNodes = coll .num_nodes ,
94- nodeType = coll .node_type ,
95- quadType = coll .quad_type ,
96- # for time-stepping types, MIN-SR-NS
97- nodes = coll .nodes ,
98- tLeft = coll .tleft ,
99- )
100-
101- return generator
94+ def buildGenerator (self , qdType : str ) -> QDeltaGenerator :
95+ return QDELTA_GENERATORS [qdType ](qGen = self .coll .generator , tLeft = self .coll .tleft )
10296
10397 def get_Qdelta_implicit (self , qd_type , k = None ):
10498 QDmat = np .zeros_like (self .coll .Qmat )
105- self .genQI = self .getGenerator (qd_type )
99+ if not hasattr (self , "genQI" ) or qd_type not in QDELTA_GENERATORS_ALIASES [type (self .genQI )]:
100+ self .genQI : QDeltaGenerator = self .buildGenerator (qd_type )
106101 QDmat [1 :, 1 :] = self .genQI .genCoeffs (k = k )
107102
108103 err_msg = 'Lower triangular matrix expected!'
@@ -114,7 +109,8 @@ def get_Qdelta_implicit(self, qd_type, k=None):
114109 def get_Qdelta_explicit (self , qd_type , k = None ):
115110 coll = self .coll
116111 QDmat = np .zeros (coll .Qmat .shape , dtype = float )
117- self .genQE = self .getGenerator (qd_type )
112+ if not hasattr (self , "genQE" ) or qd_type not in QDELTA_GENERATORS_ALIASES [type (self .genQE )]:
113+ self .genQE : QDeltaGenerator = self .buildGenerator (qd_type )
118114 QDmat [1 :, 1 :], QDmat [1 :, 0 ] = self .genQE .genCoeffs (k = k , dTau = True )
119115
120116 err_msg = 'Strictly lower triangular matrix expected!'
@@ -251,6 +247,8 @@ def level(self, L):
251247 Args:
252248 L (pySDC.Level.level): current level
253249 """
250+ from pySDC .core .level import Level
251+
254252 assert isinstance (L , Level )
255253 self .__level = L
256254
@@ -267,5 +265,9 @@ def updateVariableCoeffs(self, k):
267265 k : int
268266 Index of the sweep (0 for initial sweep, 1 for the first one, ...).
269267 """
270- if self .genQI ._K_DEPENDENT :
271- self .QI = self .get_Qdelta_implicit (qd_type = self .genQI .__class__ .__name__ , k = k )
268+ if hasattr (self , "genQI" ) and self .genQI .isKDependent ():
269+ qdType = type (self .genQI ).__name__
270+ self .QI = self .get_Qdelta_implicit (qdType , k = k )
271+ if hasattr (self , "genQE" ) and self .genQE .isKDependent ():
272+ qdType = type (self .genQE ).__name__
273+ self .QE = self .get_Qdelta_explicit (qdType , k = k )
0 commit comments