Skip to content

Commit 4a4cab0

Browse files
committed
TL: continued refactoring
1 parent 635af46 commit 4a4cab0

File tree

6 files changed

+130
-117
lines changed

6 files changed

+130
-117
lines changed

pySDC/playgrounds/dedalus/demo_timestepper_burger.py

Lines changed: 0 additions & 93 deletions
This file was deleted.

pySDC/playgrounds/dedalus/demos/demo_interface_advDiff.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
Quick demo to solve the 1D advection-diffusion with Dedalus
55
using the pySDC interface
66
"""
7-
# Base user imports
7+
# Base python imports
88
import numpy as np
99
import matplotlib.pyplot as plt
1010

pySDC/playgrounds/dedalus/demo_interface_burger.py renamed to pySDC/playgrounds/dedalus/demos/demo_interface_burger.py

Lines changed: 23 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
#!/usr/bin/env python3
22
# -*- coding: utf-8 -*-
33
"""
4-
Demo script for the KdV-Burgers equation
4+
Demo script for the KdV-Burgers equation, using the pySDC interface
55
"""
66
import numpy as np
77
import matplotlib.pyplot as plt
@@ -12,20 +12,28 @@
1212
from pySDC.playgrounds.dedalus.interface import DedalusProblem, DedalusSweeperIMEX
1313
from pySDC.implementations.controller_classes.controller_nonMPI import controller_nonMPI
1414

15-
# Space parameters
16-
xEnd = 10
17-
nX = 512
18-
nu = 1e-4
19-
b = 2e-4
15+
# -----------------------------------------------------------------------------
16+
# User parameters
17+
# -----------------------------------------------------------------------------
18+
xEnd = 10 # space domain size
19+
nX = 512 # number of point in space
20+
nu = 1e-4 # diffusion coefficient
21+
b = 2e-4 # hyper-diffusion coefficient
2022

21-
# Time-integration parameters
23+
24+
# -- time integration
2225
nSweeps = 4
2326
nNodes = 4
2427
tEnd = 10
2528
nSteps = 5000
29+
30+
# -----------------------------------------------------------------------------
31+
# Solver setup
32+
# -----------------------------------------------------------------------------
2633
timeStep = tEnd / nSteps
2734

2835
pData = buildKdVBurgerProblem(nX, xEnd, nu, b)
36+
problem, u, x = [pData[key] for key in ["problem", "u", "x"]]
2937

3038
description = {
3139
# Sweeper and its parameters
@@ -36,7 +44,7 @@
3644
"node_type": "LEGENDRE",
3745
"initial_guess": "copy",
3846
"do_coll_update": False,
39-
"QI": "MIN-SR-S",
47+
"QI": "MIN-SR-FLEX",
4048
"QE": "PIC",
4149
'skip_residual_computation':
4250
('IT_CHECK', 'IT_DOWN', 'IT_UP', 'IT_FINE', 'IT_COARSE'),
@@ -53,13 +61,14 @@
5361
},
5462
"problem_class": DedalusProblem,
5563
"problem_params": {
56-
'problem': pData["problem"],
64+
'problem': problem,
5765
'nNodes': nNodes,
5866
}
5967
}
6068

61-
# Main loop
62-
u, x = [pData[key] for key in ["u", "x"]]
69+
# -----------------------------------------------------------------------------
70+
# Simulation run
71+
# -----------------------------------------------------------------------------
6372
u.change_scales(1)
6473
u_list = [np.copy(u['g'])]
6574
t_list = [0]
@@ -84,8 +93,9 @@
8493
u_list.append(np.copy(u['g']))
8594
t_list.append(tVals[i])
8695

87-
88-
# Plot
96+
# -----------------------------------------------------------------------------
97+
# Plotting solution in real space
98+
# -----------------------------------------------------------------------------
8999
plt.figure(figsize=(6, 4))
90100
plt.pcolormesh(
91101
x.ravel(), np.array(t_list), np.array(u_list), cmap='RdBu_r',
Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
#!/usr/bin/env python3
2+
# -*- coding: utf-8 -*-
3+
"""
4+
Demo script for the KdV-Burgers equation, using the pySDC timestepper
5+
"""
6+
# Base python imports
7+
import numpy as np
8+
import matplotlib.pyplot as plt
9+
10+
import logging
11+
logger = logging.getLogger(__name__)
12+
13+
# pySDC imports
14+
from pySDC.playgrounds.dedalus.problems import buildKdVBurgerProblem
15+
from pySDC.playgrounds.dedalus.timestepper import SpectralDeferredCorrectionIMEX
16+
17+
# Dedalus import (for alternative time-stepper)
18+
import dedalus.public as d3
19+
20+
# -----------------------------------------------------------------------------
21+
# User parameters
22+
# -----------------------------------------------------------------------------
23+
xEnd = 10 # space domain size
24+
nX = 512 # number of point in space
25+
nu = 1e-4 # diffusion coefficient
26+
b = 2e-4 # hyper-diffusion coefficient
27+
28+
# -- time integration
29+
tEnd = 10
30+
nSteps = 5000
31+
SpectralDeferredCorrectionIMEX.setParameters(
32+
nSweeps=4,
33+
nNodes=4,
34+
implSweep="MIN-SR-FLEX",
35+
explSweep="PIC")
36+
useSDC = True
37+
38+
# -----------------------------------------------------------------------------
39+
# Solver setup
40+
# -----------------------------------------------------------------------------
41+
timestepper = SpectralDeferredCorrectionIMEX if useSDC else d3.RK443
42+
timestep = tEnd/nSteps
43+
44+
pData = buildKdVBurgerProblem(nX, xEnd, nu, b)
45+
problem, u, x = [pData[key] for key in ["problem", "u", "x"]]
46+
47+
solver = problem.build_solver(timestepper)
48+
solver.stop_sim_time = tEnd
49+
50+
# -----------------------------------------------------------------------------
51+
# Simulation run
52+
# -----------------------------------------------------------------------------
53+
u.change_scales(1)
54+
u_list = [np.copy(u['g'])]
55+
t_list = [solver.sim_time]
56+
i = 0
57+
while solver.proceed:
58+
solver.step(timestep)
59+
if solver.iteration % 100 == 0:
60+
print(f"step {solver.iteration}/{nSteps}")
61+
logger.info('Iteration=%i, Time=%e, dt=%e' %(solver.iteration, solver.sim_time, timestep))
62+
if solver.iteration % 25 == 0:
63+
u.change_scales(1)
64+
u_list.append(np.copy(u['g']))
65+
t_list.append(solver.sim_time)
66+
i += 1
67+
solver.log_stats()
68+
69+
# -----------------------------------------------------------------------------
70+
# Plotting solution in real space
71+
# -----------------------------------------------------------------------------
72+
plt.figure(figsize=(6, 4))
73+
plt.pcolormesh(x.ravel(), np.array(t_list), np.array(u_list), cmap='RdBu_r', shading='gouraud', rasterized=True, clim=(-0.8, 0.8))
74+
plt.xlim(0, xEnd)
75+
plt.ylim(0, tEnd)
76+
plt.xlabel('x')
77+
plt.ylabel('t')
78+
plt.title(r'KdV-Burgers, $(\nu,b)='f'({nu},{b})$')
79+
plt.tight_layout()
80+
plt.savefig(f"demo_timestepper_burger{'_SDC' if useSDC else ''}.png")

pySDC/playgrounds/dedalus/interface/sweeper.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -96,8 +96,7 @@ def update_nodes(self):
9696

9797
# Add quadrature terms
9898
for j in range(M):
99-
axpy(a=dt*q[m, j], x=Fk[j].data, y=RHS.data)
100-
axpy(a=-dt*q[m, j], x=LXk[j].data, y=RHS.data)
99+
axpy(a=dt*q[m, j], x=(Fk[j].data - LXk[j].data), y=RHS.data)
101100

102101
# Add F and LX terms from iteration k+1
103102
for j in range(m):
@@ -111,6 +110,8 @@ def update_nodes(self):
111110
axpy(a=dt*qI[m, m], x=LXk[m].data, y=RHS.data)
112111

113112
# Solve system and store node solution in solver state
113+
if self.genQI.isKDependent():
114+
P.updateLHS(dt, qI)
114115
P.solveAndStoreState(m)
115116

116117
# Evaluate and store LX with current state

pySDC/playgrounds/dedalus/timestepper/__init__.py

Lines changed: 23 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -221,6 +221,10 @@ def leftIsNode(cls):
221221
def doProlongation(cls):
222222
return not cls.rightIsNode or cls.forceProl
223223

224+
@classmethod
225+
def isKDependent(cls):
226+
return not np.all(cls.QDeltaI == cls.QDeltaI[0])
227+
224228

225229
# -----------------------------------------------------------------------------
226230
# Dedalus based IMEX timeintegrator class
@@ -312,15 +316,26 @@ def _updateLHS(self, dt, init=False):
312316
if init:
313317
# Potentially instantiate list of solver (ony first time step)
314318
sp.LHS_solvers = [[None for _ in range(self.M)] for _ in range(self.nSweeps)]
315-
for k in range(self.nSweeps):
319+
if self.isKDependent():
320+
for k in range(self.nSweeps):
321+
for m in range(self.M):
322+
if solver.store_expanded_matrices:
323+
raise NotImplementedError("code correction required")
324+
else:
325+
sp.LHS = (sp.M_min + dt*qI[k, m, m]*sp.L_min)
326+
sp.LHS_solvers[k][m] = solver.matsolver(sp.LHS, solver)
327+
else:
316328
for m in range(self.M):
317-
if solver.store_expanded_matrices:
318-
raise NotImplementedError("code correction required")
319-
np.copyto(sp.LHS.data, sp.M_exp.data)
320-
self.axpy(a=dt*qI[k, m, m], x=sp.L_exp.data, y=sp.LHS.data)
321-
else:
322-
sp.LHS = (sp.M_min + dt*qI[k, m, m]*sp.L_min)
323-
sp.LHS_solvers[k][m] = solver.matsolver(sp.LHS, solver)
329+
for k in range(self.nSweeps):
330+
if k == 0:
331+
if solver.store_expanded_matrices:
332+
raise NotImplementedError("code correction required")
333+
else:
334+
sp.LHS = (sp.M_min + dt*qI[k, m, m]*sp.L_min)
335+
sp.LHS_solvers[k][m] = solver.matsolver(sp.LHS, solver)
336+
else:
337+
sp.LHS_solvers[k][m] = sp.LHS_solvers[0][m]
338+
324339
if self.initSweep == "QDELTA":
325340
raise NotImplementedError()
326341

0 commit comments

Comments
 (0)