Skip to content

Commit 3749e57

Browse files
committed
bugfix 2d example
1 parent 161459e commit 3749e57

File tree

2 files changed

+140
-6
lines changed

2 files changed

+140
-6
lines changed

pySDC/implementations/problem_classes/HeatEquation_2D_FD_periodic.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import numpy as np
44
import scipy.sparse as sp
5-
from scipy.sparse.linalg import splu
5+
from scipy.sparse.linalg import cg
66

77
from pySDC.core.Problem import ptype
88
from pySDC.core.Errors import ParameterError, ProblemError
@@ -75,6 +75,9 @@ def __get_A(N, nu, dx):
7575
doffsets = np.concatenate((offsets, np.delete(offsets, zero_pos - 1) - N[0]))
7676

7777
A = sp.diags(dstencil, doffsets, shape=(N[0], N[0]), format='csc')
78+
# stencil = [1, -2, 1]
79+
# A = sp.diags(stencil, [-1, 0, 1], shape=(N[0], N[0]), format='csc')
80+
7881
A = sp.kron(A, sp.eye(N[0])) + sp.kron(sp.eye(N[1]), A)
7982
A *= nu / (dx ** 2)
8083

@@ -93,7 +96,7 @@ def eval_f(self, u, t):
9396
"""
9497

9598
f = self.dtype_f(self.init)
96-
f.values = self.A.dot(u.values)
99+
f.values = self.A.dot(u.values.flatten()).reshape(self.params.nvars)
97100
return f
98101

99102
def solve_system(self, rhs, factor, u0, t):
@@ -111,8 +114,9 @@ def solve_system(self, rhs, factor, u0, t):
111114
"""
112115

113116
me = self.dtype_u(self.init)
114-
L = splu(sp.eye(self.params.nvars, format='csc') - factor * self.A)
115-
me.values = L.solve(rhs.values)
117+
me.values = cg(sp.eye(self.params.nvars[0] * self.params.nvars[1], format='csc') - factor * self.A,
118+
rhs.values.flatten(), x0=u0.values.flatten(), tol=1E-12)[0]
119+
me.values = me.values.reshape(self.params.nvars)
116120
return me
117121

118122
def u_exact(self, t):
@@ -128,7 +132,7 @@ def u_exact(self, t):
128132

129133
me = self.dtype_u(self.init)
130134
xvalues = np.array([i * self.dx for i in range(self.params.nvars[0])])
131-
me.values = np.kron(np.sin(np.pi * self.params.freq * xvalues), np.sin(np.pi * self.params.freq * xvalues)) * \
135+
xv, yv = np.meshgrid(xvalues, xvalues)
136+
me.values = np.sin(np.pi * self.params.freq * xv) * np.sin(np.pi * self.params.freq * yv) * \
132137
np.exp(-t * self.params.nu * (np.pi * self.params.freq) ** 2)
133-
me.values = me.values.flatten()
134138
return me
Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,130 @@
1+
# import pySDC.helpers.plot_helper as plt_helper
2+
#
3+
# import pickle
4+
# import os
5+
import numpy as np
6+
7+
from pySDC.implementations.datatype_classes.mesh import mesh, rhs_imex_mesh
8+
from pySDC.implementations.sweeper_classes.imex_1st_order import imex_1st_order
9+
from pySDC.implementations.sweeper_classes.generic_implicit import generic_implicit
10+
from pySDC.implementations.collocation_classes.gauss_radau_right import CollGaussRadau_Right
11+
from pySDC.implementations.controller_classes.allinclusive_multigrid_nonMPI import allinclusive_multigrid_nonMPI
12+
from pySDC.implementations.problem_classes.HeatEquation_1D_FD_periodic import heat1d_periodic
13+
from pySDC.implementations.problem_classes.HeatEquation_1D_FD_forced import heat1d_forced
14+
from pySDC.implementations.transfer_classes.TransferMesh import mesh_to_mesh
15+
16+
from pySDC.helpers.stats_helper import filter_stats, sort_stats
17+
18+
19+
def main():
20+
# initialize level parameters
21+
level_params = dict()
22+
level_params['restol'] = 1E-12
23+
level_params['dt'] = None
24+
25+
# This comes as read-in for the step class (this is optional!)
26+
step_params = dict()
27+
step_params['maxiter'] = None
28+
29+
# This comes as read-in for the problem class
30+
problem_params = dict()
31+
problem_params['nu'] = 1.0
32+
problem_params['freq'] = 2
33+
problem_params['nvars'] = [2 ** 14 - 1]#, 2 ** 13]
34+
35+
# This comes as read-in for the sweeper class
36+
sweeper_params = dict()
37+
sweeper_params['collocation_class'] = CollGaussRadau_Right
38+
sweeper_params['num_nodes'] = 3
39+
sweeper_params['QI'] = 'IE'
40+
sweeper_params['spread'] = False
41+
sweeper_params['do_coll_update'] = False
42+
43+
# initialize space transfer parameters
44+
space_transfer_params = dict()
45+
space_transfer_params['rorder'] = 2
46+
space_transfer_params['iorder'] = 2
47+
space_transfer_params['periodic'] = False
48+
49+
# initialize controller parameters
50+
controller_params = dict()
51+
controller_params['logger_level'] = 30
52+
53+
# Fill description dictionary for easy hierarchy creation
54+
description = dict()
55+
description['problem_class'] = heat1d_forced
56+
description['dtype_u'] = mesh
57+
description['dtype_f'] = rhs_imex_mesh
58+
description['sweeper_class'] = imex_1st_order
59+
description['sweeper_params'] = sweeper_params
60+
description['step_params'] = step_params
61+
description['level_params'] = level_params
62+
description['problem_params'] = problem_params
63+
# description['space_transfer_class'] = mesh_to_mesh # pass spatial transfer class
64+
# description['space_transfer_params'] = space_transfer_params # pass paramters for spatial transfer
65+
66+
67+
# setup parameters "in time"
68+
t0 = 0
69+
Tend = 2.0
70+
71+
dt_list = [Tend / 2 ** i for i in range(0, 4)]
72+
niter_list = [100]#[1, 2, 3, 4]
73+
74+
for niter in niter_list:
75+
76+
err = 0
77+
for dt in dt_list:
78+
79+
print('Working with dt = %s and k = %s iterations...' % (dt, niter))
80+
81+
description['step_params']['maxiter'] = niter
82+
description['level_params']['dt'] = dt
83+
84+
# instantiate the controller
85+
controller = allinclusive_multigrid_nonMPI(num_procs=1, controller_params=controller_params,
86+
description=description)
87+
88+
# get initial values on finest level
89+
P = controller.MS[0].levels[0].prob
90+
uinit = P.u_exact(t0)
91+
92+
# call main function to get things done...
93+
uend, stats = controller.run(u0=uinit, t0=t0, Tend=Tend)
94+
95+
# compute exact solution and compare
96+
uex = P.u_exact(Tend)
97+
err_new = abs(uex - uend)
98+
99+
print(' error at time %s: %s' % (Tend, err_new))
100+
if err > 0:
101+
print(' order of accuracy: %6.4f' % (np.log(err / err_new) / np.log(2)))
102+
103+
err = err_new
104+
105+
# # filter statistics by type (number of iterations)
106+
# filtered_stats = filter_stats(stats, type='niter')
107+
#
108+
# # convert filtered statistics to list of iterations count, sorted by process
109+
# iter_counts = sort_stats(filtered_stats, sortby='time')
110+
#
111+
# # compute and print statistics
112+
# niters = np.array([item[1] for item in iter_counts])
113+
# out = ' Mean number of iterations: %4.2f' % np.mean(niters)
114+
# print(out)
115+
# out = ' Range of values for number of iterations: %2i ' % np.ptp(niters)
116+
# # f.write(out + '\n')
117+
# print(out)
118+
# out = ' Position of max/min number of iterations: %2i -- %2i' % \
119+
# (int(np.argmax(niters)), int(np.argmin(niters)))
120+
# # f.write(out + '\n')
121+
# print(out)
122+
# out = ' Std and var for number of iterations: %4.2f -- %4.2f' % \
123+
# (float(np.std(niters)), float(np.var(niters)))
124+
# # f.write(out + '\n')
125+
# # f.write(out + '\n')
126+
# print(out)
127+
print()
128+
129+
if __name__ == "__main__":
130+
main()

0 commit comments

Comments
 (0)