Skip to content

Commit c4d388f

Browse files
committed
starting Pi-line model
1 parent 6069797 commit c4d388f

File tree

4 files changed

+239
-0
lines changed

4 files changed

+239
-0
lines changed
Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
import numpy as np
2+
3+
from pySDC.core.Errors import ParameterError, ProblemError
4+
from pySDC.core.Problem import ptype
5+
from pySDC.implementations.datatype_classes.mesh import mesh
6+
7+
8+
# noinspection PyUnusedLocal
9+
class piline(ptype):
10+
"""
11+
Example implementing the Piline model as in the description in the PinTSimE project
12+
13+
Attributes:
14+
A: system matrix, representing the 9 ODEs
15+
"""
16+
17+
def __init__(self, problem_params, dtype_u=mesh, dtype_f=mesh):
18+
"""
19+
Initialization routine
20+
21+
Args:
22+
problem_params (dict): custom parameters for the example
23+
dtype_u: mesh data type for solution
24+
dtype_f: mesh data type for RHS
25+
"""
26+
27+
problem_params['nvars'] = 9
28+
29+
# these parameters will be used later, so assert their existence
30+
essential_keys = ['Vs', 'Rs', 'C1', 'Rpi', 'Lpi', 'C2', 'Rl']
31+
for key in essential_keys:
32+
if key not in problem_params:
33+
msg = 'need %s to instantiate problem, only got %s' % (key, str(problem_params.keys()))
34+
raise ParameterError(msg)
35+
36+
# invoke super init, passing number of dofs, dtype_u and dtype_f
37+
super(piline, self).__init__(init=(problem_params['nvars'], None, np.dtype('float64')),
38+
dtype_u=dtype_u, dtype_f=dtype_f, params=problem_params)
39+
40+
# compute dx and get discretization matrix A
41+
self.A = np.zeros((9, 9))
42+
self.A[0, 4] = 1 / self.params.C1
43+
self.A[1, 1] = -self.params.Rpi / self.params.Lpi
44+
self.A[1, 2] = self.params.Rpi / self.params.Lpi
45+
self.A[1, 4] = 1 / self.params.C1
46+
self.A[2, 7] = 1 / self.params.C2
47+
self.A[3, 4] = 1 / (self.params.Rs * self.params.C1)
48+
self.A[4, 1] = 1 / self.params.Lpi
49+
self.A[4, 2] = -1 / self.params.Lpi
50+
self.A[4, 4] = -1 / (self.params.Rs * self.params.C1)
51+
self.A[5, 1] = -1 / self.params.Lpi
52+
self.A[5, 2] = 1 / self.params.Lpi
53+
self.A[6, 1] = -1 / self.params.Lpi
54+
self.A[6, 2] = 1 / self.params.Lpi
55+
self.A[7, 1] = -1 / self.params.Lpi
56+
self.A[7, 2] = 1 / self.params.Lpi
57+
self.A[7, 7] = -1 / (self.params.Rl * self.params.C2)
58+
self.A[8, 7] = 1 / (self.params.Rl * self.params.C2)
59+
60+
def eval_f(self, u, t):
61+
"""
62+
Routine to evaluate the RHS
63+
64+
Args:
65+
u (dtype_u): current values
66+
t (float): current time
67+
68+
Returns:
69+
dtype_f: the RHS
70+
"""
71+
72+
f = self.dtype_f(self.init)
73+
f[:] = self.A.dot(u)
74+
return f
75+
76+
def solve_system(self, rhs, factor, u0, t):
77+
"""
78+
Simple linear solver for (I-factor*A)u = rhs
79+
80+
Args:
81+
rhs (dtype_f): right-hand side for the linear system
82+
factor (float): abbrev. for the local stepsize (or any other factor required)
83+
u0 (dtype_u): initial guess for the iterative solver
84+
t (float): current time (e.g. for time-dependent BCs)
85+
86+
Returns:
87+
dtype_u: solution as mesh
88+
"""
89+
90+
me = self.dtype_u(self.init)
91+
me[:] = np.linalg.solve(np.eye(self.params.nvars) - factor * self.A, rhs)
92+
return me
93+
94+
def u_exact(self, t):
95+
"""
96+
Routine to compute the exact solution at time t
97+
98+
Args:
99+
t (float): current time
100+
101+
Returns:
102+
dtype_u: exact solution
103+
"""
104+
105+
me = self.dtype_u(self.init)
106+
107+
me[0] = 0.0 # v1
108+
me[1] = 0.0 # v2
109+
me[2] = 0.0 # v3
110+
me[3] = 0.0 # i_Vs
111+
me[4] = self.params.Vs / self.params.Rs # i_C1
112+
me[5] = 0.0 # i_Rpi
113+
me[6] = 0.0 # i_Lpi
114+
me[7] = 0.0 # i_C2
115+
me[8] = 0.0 # i_Rl
116+
117+
return me

pySDC/playgrounds/Piline/__init__.py

Whitespace-only changes.
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
from pySDC.core.Hooks import hooks
2+
3+
4+
class log_data(hooks):
5+
6+
def post_step(self, step, level_number):
7+
8+
super(log_data, self).post_step(step, level_number)
9+
10+
# some abbreviations
11+
L = step.levels[level_number]
12+
13+
L.sweep.compute_end_point()
14+
15+
self.add_to_stats(process=step.status.slot, time=L.time, level=L.level_index, iter=0,
16+
sweep=L.status.sweep, type='v1', value=L.uend[0])
17+
self.add_to_stats(process=step.status.slot, time=L.time, level=L.level_index, iter=0,
18+
sweep=L.status.sweep, type='v2', value=L.uend[1])
19+
self.add_to_stats(process=step.status.slot, time=L.time, level=L.level_index, iter=0,
20+
sweep=L.status.sweep, type='v3', value=L.uend[2])
Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
import numpy as np
2+
import dill
3+
4+
from pySDC.helpers.stats_helper import filter_stats, sort_stats
5+
from pySDC.implementations.collocation_classes.gauss_radau_right import CollGaussRadau_Right
6+
from pySDC.implementations.controller_classes.controller_nonMPI import controller_nonMPI
7+
from pySDC.implementations.problem_classes.Piline import piline
8+
# from pySDC.implementations.sweeper_classes.imex_1st_order import imex_1st_order
9+
from pySDC.implementations.sweeper_classes.generic_LU import generic_LU
10+
from pySDC.playgrounds.Piline.log_data import log_data
11+
import pySDC.helpers.plot_helper as plt_helper
12+
13+
def main():
14+
"""
15+
A simple test program to do PFASST runs for the heat equation
16+
"""
17+
18+
# initialize level parameters
19+
level_params = dict()
20+
level_params['restol'] = 1E-10
21+
level_params['dt'] = 1
22+
23+
# initialize sweeper parameters
24+
sweeper_params = dict()
25+
sweeper_params['collocation_class'] = CollGaussRadau_Right
26+
sweeper_params['num_nodes'] = 1
27+
# sweeper_params['QI'] = 'LU' # For the IMEX sweeper, the LU-trick can be activated for the implicit part
28+
29+
# initialize problem parameters
30+
problem_params = dict()
31+
problem_params['Vs'] = 100.0
32+
problem_params['Rs'] = 1.0
33+
problem_params['C1'] = 1.0
34+
problem_params['Rpi'] = 0.2
35+
problem_params['C2'] = 1.0
36+
problem_params['Lpi'] = 1.0
37+
problem_params['Rl'] = 5.0
38+
39+
# initialize step parameters
40+
step_params = dict()
41+
step_params['maxiter'] = 1
42+
43+
# initialize controller parameters
44+
controller_params = dict()
45+
controller_params['logger_level'] = 20
46+
controller_params['hook_class'] = log_data
47+
48+
# fill description dictionary for easy step instantiation
49+
description = dict()
50+
description['problem_class'] = piline # pass problem class
51+
description['problem_params'] = problem_params # pass problem parameters
52+
description['sweeper_class'] = generic_LU # pass sweeper
53+
description['sweeper_params'] = sweeper_params # pass sweeper parameters
54+
description['level_params'] = level_params # pass level parameters
55+
description['step_params'] = step_params # pass step parameters
56+
57+
# set time parameters
58+
t0 = 0.0
59+
Tend = 10
60+
61+
# instantiate controller
62+
controller = controller_nonMPI(num_procs=1, controller_params=controller_params,
63+
description=description)
64+
65+
# get initial values on finest level
66+
P = controller.MS[0].levels[0].prob
67+
uinit = P.u_exact(t0)
68+
69+
# call main function to get things done...
70+
uend, stats = controller.run(u0=uinit, t0=t0, Tend=Tend)
71+
72+
fname = 'data/piline.dat'
73+
f = open(fname, 'wb')
74+
dill.dump(stats, f)
75+
f.close()
76+
77+
78+
def plot_voltages(cwd='./'):
79+
f = open(cwd + 'data/piline.dat', 'rb')
80+
stats = dill.load(f)
81+
f.close()
82+
83+
# convert filtered statistics to list of iterations count, sorted by process
84+
v1 = sort_stats(filter_stats(stats, type='v1'), sortby='time')
85+
v2 = sort_stats(filter_stats(stats, type='v2'), sortby='time')
86+
v3 = sort_stats(filter_stats(stats, type='v3'), sortby='time')
87+
88+
times = [v[0] for v in v1]
89+
90+
# plt_helper.setup_mpl()
91+
plt_helper.plt.plot(times, [v[1] for v in v1], label='v1')
92+
plt_helper.plt.plot(times, [v[1] for v in v2], label='v2')
93+
plt_helper.plt.plot(times, [v[1] for v in v3], label='v3')
94+
plt_helper.plt.legend()
95+
96+
plt_helper.plt.show()
97+
98+
99+
100+
if __name__ == "__main__":
101+
main()
102+
plot_voltages()

0 commit comments

Comments
 (0)