Skip to content

Commit 3811232

Browse files
tlunetbrownbaerchen
authored andcommitted
More genericity for QDelta coefficients (Parallel-in-Time#584)
* TL: more generic and efficients QDelta coefficients computation + some refactoring * TL: trying to solve circular imports * TL: bugfixes * TL: aaaaaah 😭 * TL: missing this
1 parent dcf2f68 commit 3811232

File tree

12 files changed

+49
-50
lines changed

12 files changed

+49
-50
lines changed

pySDC/core/level.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
from pySDC.core.sweeper import Sweeper
2+
from pySDC.core.problem import Problem
3+
14
from pySDC.helpers.pysdc_helper import FrozenClass
25

36

@@ -72,8 +75,8 @@ def __init__(self, problem_class, problem_params, sweeper_class, sweeper_params,
7275
self.status = _Status()
7376

7477
# instantiate sweeper, problem and hooks
75-
self.__sweep = sweeper_class(sweeper_params, self)
76-
self.__prob = problem_class(**problem_params)
78+
self.__sweep: Sweeper = sweeper_class(sweeper_params, self)
79+
self.__prob: Problem = problem_class(**problem_params)
7780

7881
# set name
7982
self.level_index = level_index
@@ -119,7 +122,7 @@ def reset_level(self, reset_status=True):
119122
self.tau = [None] * self.sweep.coll.num_nodes
120123

121124
@property
122-
def sweep(self):
125+
def sweep(self) -> Sweeper:
123126
"""
124127
Getter for the sweeper
125128
@@ -129,7 +132,7 @@ def sweep(self):
129132
return self.__sweep
130133

131134
@property
132-
def prob(self):
135+
def prob(self) -> Problem:
133136
"""
134137
Getter for the problem
135138

pySDC/core/step.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import logging
22

3-
from pySDC.core import level as levclass
3+
from pySDC.core.level import Level
44
from pySDC.core.base_transfer import BaseTransfer
55
from pySDC.core.errors import ParameterError
66
from pySDC.helpers.pysdc_helper import FrozenClass
@@ -73,7 +73,7 @@ def __init__(self, description):
7373
# empty attributes
7474
self.__transfer_dict = {}
7575
self.base_transfer = None
76-
self.levels = []
76+
self.levels: list[Level] = []
7777
self.__prev = None
7878
self.__next = None
7979

@@ -149,7 +149,7 @@ def __generate_hierarchy(self, descr):
149149

150150
# generate levels, register and connect if needed
151151
for l in range(len(descr_list)):
152-
L = levclass.Level(
152+
L = Level(
153153
problem_class=descr_list[l]['problem_class'],
154154
problem_params=descr_list[l]['problem_params'],
155155
sweeper_class=descr_list[l]['sweeper_class'],

pySDC/core/sweeper.py

Lines changed: 24 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,18 @@
11
import logging
22
import numpy as np
3-
from qmat import QDELTA_GENERATORS
3+
from qmat.qdelta import QDeltaGenerator, QDELTA_GENERATORS
44

55
from pySDC.core.errors import ParameterError
6-
from pySDC.core.level import Level
76
from pySDC.core.collocation import CollBase
87
from pySDC.helpers.pysdc_helper import FrozenClass
98

109

10+
# Organize QDeltaGenerator class in dict[type(QDeltaGenerator),set(str)] to retrieve aliases
11+
QDELTA_GENERATORS_ALIASES = {v: set() for v in set(QDELTA_GENERATORS.values())}
12+
for k, v in QDELTA_GENERATORS.items():
13+
QDELTA_GENERATORS_ALIASES[v].add(k)
14+
15+
1116
# short helper class to add params as attributes
1217
class _Pars(FrozenClass):
1318
def __init__(self, pars):
@@ -82,27 +87,17 @@ def __init__(self, params, level):
8287

8388
self.__level = level
8489
self.parallelizable = False
90+
for name in ["genQI", "genQE"]:
91+
if hasattr(self, name):
92+
delattr(self, name)
8593

86-
def getGenerator(self, qd_type):
87-
coll = self.coll
88-
89-
generator = QDELTA_GENERATORS[qd_type](
90-
# for algebraic types (LU, ...)
91-
Q=coll.generator.Q,
92-
# for MIN in tables, MIN-SR-S ...
93-
nNodes=coll.num_nodes,
94-
nodeType=coll.node_type,
95-
quadType=coll.quad_type,
96-
# for time-stepping types, MIN-SR-NS
97-
nodes=coll.nodes,
98-
tLeft=coll.tleft,
99-
)
100-
101-
return generator
94+
def buildGenerator(self, qdType: str) -> QDeltaGenerator:
95+
return QDELTA_GENERATORS[qdType](qGen=self.coll.generator, tLeft=self.coll.tleft)
10296

10397
def get_Qdelta_implicit(self, qd_type, k=None):
10498
QDmat = np.zeros_like(self.coll.Qmat)
105-
self.genQI = self.getGenerator(qd_type)
99+
if not hasattr(self, "genQI") or qd_type not in QDELTA_GENERATORS_ALIASES[type(self.genQI)]:
100+
self.genQI: QDeltaGenerator = self.buildGenerator(qd_type)
106101
QDmat[1:, 1:] = self.genQI.genCoeffs(k=k)
107102

108103
err_msg = 'Lower triangular matrix expected!'
@@ -114,7 +109,8 @@ def get_Qdelta_implicit(self, qd_type, k=None):
114109
def get_Qdelta_explicit(self, qd_type, k=None):
115110
coll = self.coll
116111
QDmat = np.zeros(coll.Qmat.shape, dtype=float)
117-
self.genQE = self.getGenerator(qd_type)
112+
if not hasattr(self, "genQE") or qd_type not in QDELTA_GENERATORS_ALIASES[type(self.genQE)]:
113+
self.genQE: QDeltaGenerator = self.buildGenerator(qd_type)
118114
QDmat[1:, 1:], QDmat[1:, 0] = self.genQE.genCoeffs(k=k, dTau=True)
119115

120116
err_msg = 'Strictly lower triangular matrix expected!'
@@ -251,6 +247,8 @@ def level(self, L):
251247
Args:
252248
L (pySDC.Level.level): current level
253249
"""
250+
from pySDC.core.level import Level
251+
254252
assert isinstance(L, Level)
255253
self.__level = L
256254

@@ -267,5 +265,9 @@ def updateVariableCoeffs(self, k):
267265
k : int
268266
Index of the sweep (0 for initial sweep, 1 for the first one, ...).
269267
"""
270-
if self.genQI._K_DEPENDENT:
271-
self.QI = self.get_Qdelta_implicit(qd_type=self.genQI.__class__.__name__, k=k)
268+
if hasattr(self, "genQI") and self.genQI.isKDependent():
269+
qdType = type(self.genQI).__name__
270+
self.QI = self.get_Qdelta_implicit(qdType, k=k)
271+
if hasattr(self, "genQE") and self.genQE.isKDependent():
272+
qdType = type(self.genQE).__name__
273+
self.QE = self.get_Qdelta_explicit(qdType, k=k)

pySDC/implementations/controller_classes/controller_MPI.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ def __init__(self, controller_params, description, comm):
2828
super().__init__(controller_params, description, useMPI=True)
2929

3030
# create single step per processor
31-
self.S = Step(description)
31+
self.S: Step = Step(description)
3232

3333
# pass communicator for future use
3434
self.comm = comm
@@ -688,8 +688,11 @@ def it_fine(self, comm, num_procs):
688688

689689
for hook in self.hooks:
690690
hook.pre_sweep(step=self.S, level_number=0)
691+
692+
self.S.levels[0].sweep.updateVariableCoeffs(k + 1) # update QDelta coefficients if variable preconditioner
691693
self.S.levels[0].sweep.update_nodes()
692694
self.S.levels[0].sweep.compute_residual(stage='IT_FINE')
695+
693696
for hook in self.hooks:
694697
hook.post_sweep(step=self.S, level_number=0)
695698

@@ -718,6 +721,7 @@ def it_down(self, comm, num_procs):
718721

719722
for hook in self.hooks:
720723
hook.pre_sweep(step=self.S, level_number=l)
724+
721725
self.S.levels[l].sweep.update_nodes()
722726
self.S.levels[l].sweep.compute_residual(stage='IT_DOWN')
723727
for hook in self.hooks:

pySDC/implementations/controller_classes/controller_nonMPI.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import dill
55

66
from pySDC.core.controller import Controller
7-
from pySDC.core import step as stepclass
7+
from pySDC.core.step import Step
88
from pySDC.core.errors import ControllerError, CommunicationError
99
from pySDC.implementations.convergence_controller_classes.basic_restarting import BasicRestarting
1010

@@ -32,7 +32,7 @@ def __init__(self, num_procs, controller_params, description):
3232
# call parent's initialization routine
3333
super().__init__(controller_params, description, useMPI=False)
3434

35-
self.MS = [stepclass.Step(description)]
35+
self.MS: list[Step] = [Step(description)]
3636

3737
# try to initialize via dill.copy (much faster for many time-steps)
3838
try:
@@ -42,7 +42,7 @@ def __init__(self, num_procs, controller_params, description):
4242
except (dill.PicklingError, TypeError, ValueError) as error:
4343
self.logger.warning(f'Need to initialize steps separately due to pickling error: {error}')
4444
for _ in range(num_procs - 1):
45-
self.MS.append(stepclass.Step(description))
45+
self.MS.append(Step(description))
4646

4747
self.base_convergence_controllers += [BasicRestarting.get_implementation(useMPI=False)]
4848
for convergence_controller in self.base_convergence_controllers:
@@ -542,7 +542,7 @@ def it_check(self, local_MS_running):
542542
for C in [self.convergence_controllers[i] for i in self.convergence_controller_order]:
543543
C.reset_buffers_nonMPI(self)
544544

545-
def it_fine(self, local_MS_running):
545+
def it_fine(self, local_MS_running: list[Step]):
546546
"""
547547
Fine sweeps
548548
@@ -567,8 +567,11 @@ def it_fine(self, local_MS_running):
567567
# standard sweep workflow: update nodes, compute residual, log progress
568568
for hook in self.hooks:
569569
hook.pre_sweep(step=S, level_number=0)
570+
571+
S.levels[0].sweep.updateVariableCoeffs(k + 1) # update QDelta coefficients if variable preconditioner
570572
S.levels[0].sweep.update_nodes()
571573
S.levels[0].sweep.compute_residual(stage='IT_FINE')
574+
572575
for hook in self.hooks:
573576
hook.post_sweep(step=S, level_number=0)
574577

pySDC/implementations/convergence_controller_classes/adaptive_collocation.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22

33
from qmat.lagrange import LagrangeApproximation
44
from pySDC.core.convergence_controller import ConvergenceController, Status
5-
from pySDC.core.collocation import CollBase
65

76

87
class AdaptiveCollocation(ConvergenceController):

pySDC/implementations/sweeper_classes/generic_implicit.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -65,9 +65,6 @@ def update_nodes(self):
6565
# get number of collocation nodes for easier access
6666
M = self.coll.num_nodes
6767

68-
# update the MIN-SR-FLEX preconditioner
69-
self.updateVariableCoeffs(L.status.sweep)
70-
7168
# gather all terms which are known already (e.g. from the previous iteration)
7269
# this corresponds to u0 + QF(u^k) - QdF(u^k) + tau
7370

pySDC/implementations/sweeper_classes/generic_implicit_MPI.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -209,9 +209,6 @@ def update_nodes(self):
209209
# only if the level has been touched before
210210
assert L.status.unlocked
211211

212-
# update the MIN-SR-FLEX preconditioner
213-
self.updateVariableCoeffs(L.status.sweep)
214-
215212
# gather all terms which are known already (e.g. from the previous iteration)
216213
# this corresponds to u0 + QF(u^k) - QdF(u^k) + tau
217214

pySDC/implementations/sweeper_classes/imex_1st_order.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -72,9 +72,6 @@ def update_nodes(self):
7272
# get number of collocation nodes for easier access
7373
M = self.coll.num_nodes
7474

75-
# update the MIN-SR-FLEX preconditioner
76-
self.updateVariableCoeffs(L.status.sweep)
77-
7875
# gather all terms which are known already (e.g. from the previous iteration)
7976
# this corresponds to u0 + QF(u^k) - QIFI(u^k) - QEFE(u^k) + tau
8077

pySDC/implementations/sweeper_classes/imex_1st_order_MPI.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -55,9 +55,6 @@ def update_nodes(self):
5555
# gather all terms which are known already (e.g. from the previous iteration)
5656
# this corresponds to u0 + QF(u^k) - QdF(u^k) + tau
5757

58-
# update the MIN-SR-FLEX preconditioner
59-
self.updateVariableCoeffs(L.status.sweep)
60-
6158
# get QF(u^k)
6259
rhs = self.integrate()
6360

0 commit comments

Comments
 (0)