Skip to content

Commit 495d5a5

Browse files
Allow other name for FLEX preconditioner (#580)
* Allow other name for FLEX preconditioner * Split generators for implicit and explicit preconditioners
1 parent 8e9c149 commit 495d5a5

File tree

1 file changed

+21
-24
lines changed

1 file changed

+21
-24
lines changed

pySDC/core/sweeper.py

Lines changed: 21 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -83,30 +83,27 @@ def __init__(self, params, level):
8383
self.__level = level
8484
self.parallelizable = False
8585

86-
def setupGenerator(self, qd_type):
86+
def getGenerator(self, qd_type):
8787
coll = self.coll
88-
try:
89-
assert QDELTA_GENERATORS[qd_type] == type(self.generator)
90-
assert self.generator.QDelta.shape[0] == coll.Qmat.shape[0] - 1
91-
except (AssertionError, AttributeError):
92-
self.generator = QDELTA_GENERATORS[qd_type](
93-
# for algebraic types (LU, ...)
94-
Q=coll.generator.Q,
95-
# for MIN in tables, MIN-SR-S ...
96-
nNodes=coll.num_nodes,
97-
nodeType=coll.node_type,
98-
quadType=coll.quad_type,
99-
# for time-stepping types, MIN-SR-NS
100-
nodes=coll.nodes,
101-
tLeft=coll.tleft,
102-
)
103-
except Exception as e:
104-
raise ValueError(f"could not generate {qd_type=!r} with qmat, got error : {e}") from e
88+
89+
generator = QDELTA_GENERATORS[qd_type](
90+
# for algebraic types (LU, ...)
91+
Q=coll.generator.Q,
92+
# for MIN in tables, MIN-SR-S ...
93+
nNodes=coll.num_nodes,
94+
nodeType=coll.node_type,
95+
quadType=coll.quad_type,
96+
# for time-stepping types, MIN-SR-NS
97+
nodes=coll.nodes,
98+
tLeft=coll.tleft,
99+
)
100+
101+
return generator
105102

106103
def get_Qdelta_implicit(self, qd_type, k=None):
107104
QDmat = np.zeros_like(self.coll.Qmat)
108-
self.setupGenerator(qd_type)
109-
QDmat[1:, 1:] = self.generator.genCoeffs(k=k)
105+
self.genQI = self.getGenerator(qd_type)
106+
QDmat[1:, 1:] = self.genQI.genCoeffs(k=k)
110107

111108
err_msg = 'Lower triangular matrix expected!'
112109
np.testing.assert_array_equal(np.triu(QDmat, k=1), np.zeros(QDmat.shape), err_msg=err_msg)
@@ -117,8 +114,8 @@ def get_Qdelta_implicit(self, qd_type, k=None):
117114
def get_Qdelta_explicit(self, qd_type, k=None):
118115
coll = self.coll
119116
QDmat = np.zeros(coll.Qmat.shape, dtype=float)
120-
self.setupGenerator(qd_type)
121-
QDmat[1:, 1:], QDmat[1:, 0] = self.generator.genCoeffs(k=k, dTau=True)
117+
self.genQE = self.getGenerator(qd_type)
118+
QDmat[1:, 1:], QDmat[1:, 0] = self.genQE.genCoeffs(k=k, dTau=True)
122119

123120
err_msg = 'Strictly lower triangular matrix expected!'
124121
np.testing.assert_array_equal(np.triu(QDmat, k=0), np.zeros(QDmat.shape), err_msg=err_msg)
@@ -270,5 +267,5 @@ def updateVariableCoeffs(self, k):
270267
k : int
271268
Index of the sweep (0 for initial sweep, 1 for the first one, ...).
272269
"""
273-
if self.params.QI == 'MIN-SR-FLEX':
274-
self.QI = self.get_Qdelta_implicit(qd_type="MIN-SR-FLEX", k=k)
270+
if self.genQI._K_DEPENDENT:
271+
self.QI = self.get_Qdelta_implicit(qd_type=self.genQI.__class__.__name__, k=k)

0 commit comments

Comments
 (0)