Skip to content

Commit a822639

Browse files
committed
porting fwsw project (WIP, 3/7)
1 parent 002bdbb commit a822639

File tree

7 files changed

+324
-222
lines changed

7 files changed

+324
-222
lines changed

examples/fwsw/plot_stab_vs_k.py

Lines changed: 0 additions & 97 deletions
This file was deleted.

examples/fwsw/plot_stability.py

Lines changed: 0 additions & 113 deletions
This file was deleted.
Lines changed: 140 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,140 @@
1+
import matplotlib
2+
3+
matplotlib.use('Agg')
4+
5+
import numpy as np
6+
from matplotlib import pyplot as plt
7+
from pylab import rcParams
8+
from matplotlib.ticker import ScalarFormatter
9+
10+
from pySDC.implementations.problem_classes.FastWaveSlowWave_0D import swfw_scalar
11+
from pySDC.implementations.datatype_classes.complex_mesh import mesh, rhs_imex_mesh
12+
from pySDC.implementations.sweeper_classes.imex_1st_order import imex_1st_order
13+
from pySDC.implementations.collocation_classes.gauss_radau_right import CollGaussRadau_Right
14+
15+
from pySDC.core.Step import step
16+
17+
18+
# noinspection PyShadowingNames
19+
def compute_stab_vs_k(slow_resolved):
20+
"""
21+
Routine to compute modulus of the stability function
22+
23+
Args:
24+
slow_resolved (bool): switch to compute lambda_slow = 1 or lambda_slow = 4
25+
26+
Returns:
27+
numpy.ndarray: number of nodes
28+
numpy.ndarray: number of iterations
29+
numpy.ndarray: moduli
30+
"""
31+
32+
mvals = [2, 3, 4]
33+
kvals = np.arange(1, 10)
34+
lambda_fast = 10j
35+
36+
# PLOT EITHER FOR lambda_slow = 1 (resolved) OR lambda_slow = 4 (unresolved)
37+
if slow_resolved:
38+
lambda_slow = 1j
39+
else:
40+
lambda_slow = 4j
41+
stabval = np.zeros((np.size(mvals), np.size(kvals)))
42+
43+
problem_params = dict()
44+
# SET VALUE FOR lambda_slow AND VALUES FOR lambda_fast ###
45+
problem_params['lambda_s'] = np.array([0.0])
46+
problem_params['lambda_f'] = np.array([0.0])
47+
problem_params['u0'] = 1.0
48+
49+
# initialize sweeper parameters
50+
sweeper_params = dict()
51+
# SET TYPE AND NUMBER OF QUADRATURE NODES ###
52+
sweeper_params['collocation_class'] = CollGaussRadau_Right
53+
sweeper_params['do_coll_update'] = True
54+
55+
# initialize level parameters
56+
level_params = dict()
57+
level_params['dt'] = 1.0
58+
59+
# fill description dictionary for easy step instantiation
60+
description = dict()
61+
description['problem_class'] = swfw_scalar # pass problem class
62+
description['problem_params'] = problem_params # pass problem parameters
63+
description['dtype_u'] = mesh # pass data type for u
64+
description['dtype_f'] = rhs_imex_mesh # pass data type for f
65+
description['sweeper_class'] = imex_1st_order # pass sweeper (see part B)
66+
67+
description['level_params'] = level_params # pass level parameters
68+
description['step_params'] = dict() # pass step parameters
69+
70+
for i in range(0, np.size(mvals)):
71+
72+
sweeper_params['num_nodes'] = mvals[i]
73+
description['sweeper_params'] = sweeper_params # pass sweeper parameters
74+
75+
# now the description contains more or less everything we need to create a step
76+
S = step(description=description)
77+
78+
L = S.levels[0]
79+
80+
nnodes = L.sweep.coll.num_nodes
81+
82+
for k in range(0, np.size(kvals)):
83+
Kmax = kvals[k]
84+
Mat_sweep = L.sweep.get_scalar_problems_manysweep_mat(nsweeps=Kmax, lambdas=[lambda_fast, lambda_slow])
85+
if L.sweep.params.do_coll_update:
86+
stab_fh = 1.0 + (lambda_fast + lambda_slow) * L.sweep.coll.weights.dot(Mat_sweep.dot(np.ones(nnodes)))
87+
else:
88+
q = np.zeros(nnodes)
89+
q[nnodes - 1] = 1.0
90+
stab_fh = q.dot(Mat_sweep.dot(np.ones(nnodes)))
91+
stabval[i, k] = np.absolute(stab_fh)
92+
93+
return mvals, kvals, stabval
94+
95+
96+
# noinspection PyShadowingNames
97+
def plot_stab_vs_k(slow_resolved, mvals, kvals, stabval):
98+
"""
99+
Plotting routine for moduli
100+
101+
Args:
102+
slow_resolved (bool): switch for lambda_slow
103+
mvals (numpy.ndarray): number of nodes
104+
kvals (numpy.ndarray): number of iterations
105+
stabval (numpy.ndarray): moduli
106+
"""
107+
108+
rcParams['figure.figsize'] = 2.5, 2.5
109+
fig = plt.figure()
110+
fs = 8
111+
plt.plot(kvals, stabval[0, :], 'o-', color='b', label=("M=%2i" % mvals[0]), markersize=fs - 2)
112+
plt.plot(kvals, stabval[1, :], 's-', color='r', label=("M=%2i" % mvals[1]), markersize=fs - 2)
113+
plt.plot(kvals, stabval[2, :], 'd-', color='g', label=("M=%2i" % mvals[2]), markersize=fs - 2)
114+
plt.plot(kvals, 1.0 + 0.0 * kvals, '--', color='k')
115+
plt.xlabel('Number of iterations K', fontsize=fs)
116+
plt.ylabel(r'Modulus of stability function $\left| R \right|$', fontsize=fs)
117+
plt.ylim([0.0, 1.2])
118+
if slow_resolved:
119+
plt.legend(loc='upper right', fontsize=fs, prop={'size': fs})
120+
else:
121+
plt.legend(loc='lower left', fontsize=fs, prop={'size': fs})
122+
123+
plt.gca().get_xaxis().get_major_formatter().labelOnlyBase = False
124+
plt.gca().get_xaxis().set_major_formatter(ScalarFormatter())
125+
# plt.show()
126+
if slow_resolved:
127+
filename = 'stab_vs_k_resolved.png'
128+
else:
129+
filename = 'stab_vs_k_unresolved.png'
130+
131+
fig.savefig(filename, bbox_inches='tight')
132+
133+
134+
if __name__ == "__main__":
135+
mvals, kvals, stabval = compute_stab_vs_k(slow_resolved=True)
136+
print(np.amax(stabval))
137+
plot_stab_vs_k(True, mvals, kvals, stabval)
138+
mvals, kvals, stabval = compute_stab_vs_k(slow_resolved=False)
139+
print(np.amax(stabval))
140+
plot_stab_vs_k(False, mvals, kvals, stabval)

0 commit comments

Comments
 (0)