Skip to content

Commit 32549d9

Browse files
authored
TL: first revision for the parallel SDC theory paper (#481)
* TL: dump version * TL: added MPI script for pSDC theory paper * TL: adapted script for parallel computation * TL: minor detail for output * TL: mini-fixes and black * TL: update for plots * TL: upload reference data from jusuf runs * TL: minor modifications * TL: prepared PR * TL: fix linting error * TL: added missing dependency for project tests * TL: another detail * TL: eventually <=> finally apparently
1 parent f46209f commit 32549d9

File tree

14 files changed

+402
-13
lines changed

14 files changed

+402
-13
lines changed

pySDC/core/sweeper.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -263,3 +263,15 @@ def level(self, L):
263263
@property
264264
def rank(self):
265265
return 0
266+
267+
def updateVariableCoeffs(self, k):
268+
"""
269+
Potentially update QDelta implicit coefficients if variable ...
270+
271+
Parameters
272+
----------
273+
k : int
274+
Index of the sweep (0 for initial sweep, 1 for the first one, ...).
275+
"""
276+
if self.params.QI == 'MIN-SR-FLEX':
277+
self.QI = self.get_Qdelta_implicit(qd_type="MIN-SR-FLEX", k=k)

pySDC/implementations/sweeper_classes/generic_implicit.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,8 +66,7 @@ def update_nodes(self):
6666
M = self.coll.num_nodes
6767

6868
# update the MIN-SR-FLEX preconditioner
69-
if self.params.QI == 'MIN-SR-FLEX':
70-
self.QI = self.get_Qdelta_implicit(qd_type="MIN-SR-FLEX", k=L.status.sweep)
69+
self.updateVariableCoeffs(L.status.sweep)
7170

7271
# gather all terms which are known already (e.g. from the previous iteration)
7372
# this corresponds to u0 + QF(u^k) - QdF(u^k) + tau

pySDC/implementations/sweeper_classes/generic_implicit_MPI.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -201,7 +201,8 @@ def update_nodes(self):
201201
# only if the level has been touched before
202202
assert L.status.unlocked
203203

204-
# get number of collocation nodes for easier access
204+
# update the MIN-SR-FLEX preconditioner
205+
self.updateVariableCoeffs(L.status.sweep)
205206

206207
# gather all terms which are known already (e.g. from the previous iteration)
207208
# this corresponds to u0 + QF(u^k) - QdF(u^k) + tau

pySDC/playgrounds/Diagonal/dahlquist.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ def setParameters(cls, M=None, nodeDistr=None, quadType=None,
7171
cls.QDeltaE, cls.dtauE = genQDelta(cls.nodes, explSweep, cls.Q)
7272
cls.explSweep = explSweep
7373

74-
# Eventually update nSweep, initSweep and forceProlongation
74+
# Potentially update nSweep, initSweep and forceProlongation
7575
cls.initSweep = cls.initSweep if initSweep is None else initSweep
7676
cls.nSweep = cls.nSweep if nSweep is None else nSweep
7777
cls.forceProl = cls.forceProl if forceProl is None else forceProl

pySDC/playgrounds/Diagonal/optim.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ def findRes(x):
5050

5151
funcEval = func(x0)
5252

53-
# Eventually add local minimum to results
53+
# Potentially add local minimum to results
5454
xOrig = findRes(x0)
5555
if xOrig:
5656
if funcEval < res[xOrig]:

pySDC/playgrounds/dedalus/advection.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ def getErr(nStep):
8181
sym = '^-' if useSDC else 'o-'
8282
plt.loglog(dt, err, sym, label=lbl)
8383

84-
# Eventually plot order curve
84+
# Potentially plot order curve
8585
if name in orderPlot:
8686
order = orderPlot[name]
8787
c = err[-1]/dt[-1]**order * 2

pySDC/playgrounds/dedalus/problem.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ def updateLHS(self, dt, qI, init=False):
117117
# Update LHS and LHS solvers for each subproblems
118118
for sp in solver.subproblems:
119119
if self.init:
120-
# Eventually instanciate list of solvers (ony first time step)
120+
# Instanciate list of solvers (ony first time step)
121121
sp.LHS_solvers = [None] * self.M
122122
self.init = False
123123
for i in range(self.M):

pySDC/playgrounds/dedalus/sdc.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -173,7 +173,7 @@ def setParameters(cls, nNodes=None, nodeType=None, quadType=None,
173173
Q=cls.Q, nodes=cls.nodes, nNodes=nNodes,
174174
nodeType=nodeType, quadType=quadType)
175175

176-
# Eventually update additional parameters
176+
# Potentially update additional parameters
177177
if forceProl is not None: cls.forceProl = forceProl
178178
if initSweep is not None: cls.initSweep = initSweep
179179
if not keepNSweeps:
@@ -302,7 +302,7 @@ def _updateLHS(self, dt, init=False):
302302
# Update LHS and LHS solvers for each subproblems
303303
for sp in solver.subproblems:
304304
if init:
305-
# Eventually instantiate list of solver (ony first time step)
305+
# Potentially instantiate list of solver (ony first time step)
306306
sp.LHS_solvers = [[None for _ in range(self.M)] for _ in range(self.nSweeps)]
307307
for k in range(self.nSweeps):
308308
for m in range(self.M):

pySDC/projects/parallelSDC_reloaded/environment.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@ dependencies:
99
- matplotlib>=3.0
1010
- dill>=0.2.6
1111
- scipy>=0.17.1
12+
- mpich
13+
- mpi4py>=3.0.0
1214
- pip
1315
- pip:
1416
- qmat>=0.1.8
Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
#!/usr/bin/env python3
2+
# -*- coding: utf-8 -*-
3+
"""
4+
Created on Thu Jan 11 11:14:01 2024
5+
6+
Figures with experiments on the Allen-Cahn problem (MPI runs)
7+
"""
8+
import os
9+
import sys
10+
import json
11+
import numpy as np
12+
from mpi4py import MPI
13+
14+
from pySDC.projects.parallelSDC_reloaded.utils import solutionExact, getParamsSDC, solutionSDC, getParamsRK
15+
from pySDC.implementations.sweeper_classes.generic_implicit_MPI import generic_implicit_MPI
16+
17+
PATH = '/' + os.path.join(*__file__.split('/')[:-1])
18+
SCRIPT = __file__.split('/')[-1].split('.')[0]
19+
20+
COMM_WORLD = MPI.COMM_WORLD
21+
22+
# SDC parameters
23+
nNodes = 4
24+
quadType = 'RADAU-RIGHT'
25+
nodeType = 'LEGENDRE'
26+
parEfficiency = 0.8 # 1/nNodes
27+
nSweeps = 4
28+
29+
# Problem parameters
30+
pName = "ALLEN-CAHN"
31+
tEnd = 50
32+
pParams = {
33+
"periodic": False,
34+
"nvars": 2**11 - 1,
35+
"epsilon": 0.04,
36+
}
37+
38+
# -----------------------------------------------------------------------------
39+
# %% Convergence and error VS cost plots
40+
# -----------------------------------------------------------------------------
41+
nStepsList = np.array([5, 10, 20, 50, 100, 200, 500])
42+
dtVals = tEnd / nStepsList
43+
44+
45+
def getError(uNum, uRef):
46+
if uNum is None:
47+
return np.inf
48+
return np.linalg.norm(uRef[-1, :] - uNum[-1, :], ord=2)
49+
50+
51+
def getCost(counters):
52+
_, _, tComp = counters
53+
return tComp
54+
55+
56+
try:
57+
qDelta = sys.argv[1]
58+
if qDelta.startswith("--"):
59+
qDelta = "MIN-SR-FLEX"
60+
except IndexError:
61+
qDelta = "MIN-SR-FLEX"
62+
63+
try:
64+
params = getParamsRK(qDelta)
65+
except KeyError:
66+
params = getParamsSDC(quadType=quadType, numNodes=nNodes, nodeType=nodeType, qDeltaI=qDelta, nSweeps=nSweeps)
67+
68+
useMPI = False
69+
if COMM_WORLD.Get_size() == 4 and qDelta in ["MIN-SR-NS", "MIN-SR-S", "MIN-SR-FLEX", "VDHS"]: # pragma: no cover
70+
params['sweeper_class'] = generic_implicit_MPI
71+
useMPI = True
72+
73+
errors = []
74+
costs = []
75+
76+
root = COMM_WORLD.Get_rank() == 0
77+
if root:
78+
print(f"Running simulation with {qDelta}")
79+
80+
for nSteps in nStepsList:
81+
if root:
82+
uRef = solutionExact(tEnd, nSteps, pName, **pParams)
83+
84+
uSDC, counters, parallel = solutionSDC(tEnd, nSteps, params, pName, verbose=root, noExcept=True, **pParams)
85+
86+
if root:
87+
err = getError(uSDC, uRef)
88+
errors.append(err)
89+
90+
cost = getCost(counters)
91+
costs.append(cost)
92+
93+
if COMM_WORLD.Get_rank() == 0:
94+
errors = [float(e) for e in errors]
95+
96+
print("errors : ", errors)
97+
print("tComps : ", costs)
98+
fileName = f"{PATH}/fig06_compTime.json"
99+
timings = {}
100+
if os.path.isfile(fileName):
101+
with open(fileName, "r") as f:
102+
timings = json.load(f)
103+
104+
timings[qDelta + "_MPI" * useMPI] = {"errors": errors, "costs": costs}
105+
106+
with open(fileName, 'w') as f:
107+
json.dump(timings, f, indent=4)

0 commit comments

Comments
 (0)