diff --git a/pySDC/core/sweeper.py b/pySDC/core/sweeper.py index 43e734aa31..1db9eaf356 100644 --- a/pySDC/core/sweeper.py +++ b/pySDC/core/sweeper.py @@ -83,30 +83,27 @@ def __init__(self, params, level): self.__level = level self.parallelizable = False - def setupGenerator(self, qd_type): + def getGenerator(self, qd_type): coll = self.coll - try: - assert QDELTA_GENERATORS[qd_type] == type(self.generator) - assert self.generator.QDelta.shape[0] == coll.Qmat.shape[0] - 1 - except (AssertionError, AttributeError): - self.generator = QDELTA_GENERATORS[qd_type]( - # for algebraic types (LU, ...) - Q=coll.generator.Q, - # for MIN in tables, MIN-SR-S ... - nNodes=coll.num_nodes, - nodeType=coll.node_type, - quadType=coll.quad_type, - # for time-stepping types, MIN-SR-NS - nodes=coll.nodes, - tLeft=coll.tleft, - ) - except Exception as e: - raise ValueError(f"could not generate {qd_type=!r} with qmat, got error : {e}") from e + + generator = QDELTA_GENERATORS[qd_type]( + # for algebraic types (LU, ...) + Q=coll.generator.Q, + # for MIN in tables, MIN-SR-S ... + nNodes=coll.num_nodes, + nodeType=coll.node_type, + quadType=coll.quad_type, + # for time-stepping types, MIN-SR-NS + nodes=coll.nodes, + tLeft=coll.tleft, + ) + + return generator def get_Qdelta_implicit(self, qd_type, k=None): QDmat = np.zeros_like(self.coll.Qmat) - self.setupGenerator(qd_type) - QDmat[1:, 1:] = self.generator.genCoeffs(k=k) + self.genQI = self.getGenerator(qd_type) + QDmat[1:, 1:] = self.genQI.genCoeffs(k=k) err_msg = 'Lower triangular matrix expected!' np.testing.assert_array_equal(np.triu(QDmat, k=1), np.zeros(QDmat.shape), err_msg=err_msg) @@ -117,8 +114,8 @@ def get_Qdelta_implicit(self, qd_type, k=None): def get_Qdelta_explicit(self, qd_type, k=None): coll = self.coll QDmat = np.zeros(coll.Qmat.shape, dtype=float) - self.setupGenerator(qd_type) - QDmat[1:, 1:], QDmat[1:, 0] = self.generator.genCoeffs(k=k, dTau=True) + self.genQE = self.getGenerator(qd_type) + QDmat[1:, 1:], QDmat[1:, 0] = self.genQE.genCoeffs(k=k, dTau=True) err_msg = 'Strictly lower triangular matrix expected!' np.testing.assert_array_equal(np.triu(QDmat, k=0), np.zeros(QDmat.shape), err_msg=err_msg) @@ -270,5 +267,5 @@ def updateVariableCoeffs(self, k): k : int Index of the sweep (0 for initial sweep, 1 for the first one, ...). """ - if self.params.QI == 'MIN-SR-FLEX': - self.QI = self.get_Qdelta_implicit(qd_type="MIN-SR-FLEX", k=k) + if self.genQI._K_DEPENDENT: + self.QI = self.get_Qdelta_implicit(qd_type=self.genQI.__class__.__name__, k=k)