Skip to content

Commit a055c4d

Browse files
committed
Status variables, short switching logics and other improvements, thanks to @pancetta and @brownbaerchen!
1 parent 2acf596 commit a055c4d

File tree

7 files changed

+178
-163
lines changed

7 files changed

+178
-163
lines changed

pySDC/implementations/problem_classes/Battery.py

Lines changed: 46 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
class battery(ptype):
99
"""
1010
Example implementing the battery drain model as in the description in the PinTSimE project
11+
1112
Attributes:
1213
A: system matrix, representing the 2 ODEs
1314
t_switch: time point of the switch
@@ -17,21 +18,22 @@ class battery(ptype):
1718
def __init__(self, problem_params, dtype_u=mesh, dtype_f=imex_mesh):
1819
"""
1920
Initialization routine
21+
2022
Args:
2123
problem_params (dict): custom parameters for the example
2224
dtype_u: mesh data type for solution
2325
dtype_f: mesh data type for RHS
2426
"""
2527

26-
problem_params['nvars'] = 2
27-
2828
# these parameters will be used later, so assert their existence
29-
essential_keys = ['Vs', 'Rs', 'C', 'R', 'L', 'alpha', 'V_ref']
29+
essential_keys = ['ncondensators', 'Vs', 'Rs', 'C', 'R', 'L', 'alpha', 'V_ref']
3030
for key in essential_keys:
3131
if key not in problem_params:
3232
msg = 'need %s to instantiate problem, only got %s' % (key, str(problem_params.keys()))
3333
raise ParameterError(msg)
3434

35+
problem_params['nvars'] = problem_params['ncondensators'] + 1
36+
3537
# invoke super init, passing number of dofs, dtype_u and dtype_f
3638
super(battery, self).__init__(
3739
init=(problem_params['nvars'], None, np.dtype('float64')),
@@ -48,56 +50,50 @@ def __init__(self, problem_params, dtype_u=mesh, dtype_f=imex_mesh):
4850
def eval_f(self, u, t):
4951
"""
5052
Routine to evaluate the RHS
53+
5154
Args:
5255
u (dtype_u): current values
5356
t (float): current time
57+
5458
Returns:
5559
dtype_f: the RHS
5660
"""
5761

5862
f = self.dtype_f(self.init, val=0.0)
5963
f.impl[:] = self.A.dot(u)
6064

61-
if self.t_switch is not None:
62-
if t >= self.t_switch:
63-
f.expl[0] = self.params.Vs / self.params.L
65+
t_switch = np.inf if self.t_switch is None else self.t_switch
6466

65-
else:
66-
f.expl[0] = 0
67+
if u[1] <= self.params.V_ref or t >= t_switch:
68+
f.expl[0] = self.params.Vs / self.params.L
6769

6870
else:
69-
if u[1] <= self.params.V_ref:
70-
f.expl[0] = self.params.Vs / self.params.L
71-
72-
else:
73-
f.expl[0] = 0
71+
f.expl[0] = 0
7472

7573
return f
7674

7775
def solve_system(self, rhs, factor, u0, t):
7876
"""
7977
Simple linear solver for (I-factor*A)u = rhs
78+
8079
Args:
8180
rhs (dtype_f): right-hand side for the linear system
8281
factor (float): abbrev. for the local stepsize (or any other factor required)
8382
u0 (dtype_u): initial guess for the iterative solver
8483
t (float): current time (e.g. for time-dependent BCs)
84+
8585
Returns:
8686
dtype_u: solution as mesh
8787
"""
8888
self.A = np.zeros((2, 2))
8989

90-
if self.t_switch is not None:
91-
if t >= self.t_switch:
92-
self.A[0, 0] = -(self.params.Rs + self.params.R) / self.params.L
93-
else:
94-
self.A[1, 1] = -1 / (self.params.C * self.params.R)
90+
t_switch = np.inf if self.t_switch is None else self.t_switch
91+
92+
if rhs[1] <= self.params.V_ref or t >= t_switch:
93+
self.A[0, 0] = -(self.params.Rs + self.params.R) / self.params.L
9594

9695
else:
97-
if rhs[1] <= self.params.V_ref:
98-
self.A[0, 0] = -(self.params.Rs + self.params.R) / self.params.L
99-
else:
100-
self.A[1, 1] = -1 / (self.params.C * self.params.R)
96+
self.A[1, 1] = -1 / (self.params.C * self.params.R)
10197

10298
me = self.dtype_u(self.init)
10399
me[:] = np.linalg.solve(np.eye(self.params.nvars) - factor * self.A, rhs)
@@ -106,8 +102,10 @@ def solve_system(self, rhs, factor, u0, t):
106102
def u_exact(self, t):
107103
"""
108104
Routine to compute the exact solution at time t
105+
109106
Args:
110107
t (float): current time
108+
111109
Returns:
112110
dtype_u: exact solution
113111
"""
@@ -123,12 +121,14 @@ def u_exact(self, t):
123121
def get_switching_info(self, u, t):
124122
"""
125123
Provides information about a discrete event for one subinterval.
124+
126125
Args:
127126
u (dtype_u): current values
128127
t (float): current time
128+
129129
Returns:
130130
switch_detected (bool): Indicates if a switch is found or not
131-
m_guess (np.int): Index of where the discrete event would found
131+
m_guess (np.int): Index of collocation node inside one subinterval of where the discrete event was found
132132
vC_switch (list): Contains function values of switching condition (for interpolation)
133133
"""
134134

@@ -141,10 +141,7 @@ def get_switching_info(self, u, t):
141141
m_guess = m - 1
142142
break
143143

144-
vC_switch = []
145-
if switch_detected:
146-
for m in range(1, len(u)):
147-
vC_switch.append(u[m][1] - self.params.V_ref)
144+
vC_switch = [u[m][1] - self.params.V_ref for m in range(1, len(u))] if switch_detected else []
148145

149146
return switch_detected, m_guess, vC_switch
150147

@@ -158,21 +155,21 @@ def flip_switches(self):
158155

159156

160157
class battery_implicit(battery):
161-
162158
def __init__(self, problem_params, dtype_u=mesh, dtype_f=mesh):
163159

164-
essential_keys = ['newton_maxiter', 'newton_tol']
160+
essential_keys = ['newton_maxiter', 'newton_tol', 'ncondensators', 'Vs', 'Rs', 'C', 'R', 'L', 'alpha', 'V_ref']
165161
for key in essential_keys:
166162
if key not in problem_params:
167163
msg = 'need %s to instantiate problem, only got %s' % (key, str(problem_params.keys()))
168164
raise ParameterError(msg)
169165

166+
problem_params['nvars'] = problem_params['ncondensators'] + 1
167+
170168
# invoke super init, passing number of dofs, dtype_u and dtype_f
171169
super(battery_implicit, self).__init__(
172-
init=(problem_params['nvars'], None, np.dtype('float64')),
170+
problem_params,
173171
dtype_u=dtype_u,
174172
dtype_f=dtype_f,
175-
params=problem_params,
176173
)
177174

178175
self.newton_itercount = 0
@@ -183,43 +180,41 @@ def __init__(self, problem_params, dtype_u=mesh, dtype_f=mesh):
183180
def eval_f(self, u, t):
184181
"""
185182
Routine to evaluate the RHS
183+
186184
Args:
187185
u (dtype_u): current values
188186
t (float): current time
187+
189188
Returns:
190189
dtype_f: the RHS
191190
"""
192191

193192
f = self.dtype_f(self.init, val=0.0)
194193
non_f = np.zeros(2)
195194

196-
if self.t_switch is not None:
197-
if t >= self.t_switch:
198-
self.A[0, 0] = -(self.params.Rs + self.params.R) / self.params.L
199-
non_f[0] = self.params.Vs / self.params.L
200-
else:
201-
self.A[1, 1] = -1 / (self.params.C * self.params.R)
202-
non_f[0] = 0
195+
t_switch = np.inf if self.t_switch is None else self.t_switch
196+
197+
if u[1] <= self.params.V_ref or t >= t_switch:
198+
self.A[0, 0] = -(self.params.Rs + self.params.R) / self.params.L
199+
non_f[0] = self.params.Vs
203200

204201
else:
205-
if u[1] <= self.params.V_ref:
206-
self.A[0, 0] = -(self.params.Rs + self.params.R) / self.params.L
207-
non_f[0] = self.params.Vs / self.params.L
208-
else:
209-
self.A[1, 1] = -1 / (self.params.C * self.params.R)
210-
non_f[0] = 0
202+
self.A[1, 1] = -1 / (self.params.C * self.params.R)
203+
non_f[0] = 0
211204

212205
f[:] = self.A.dot(u) + non_f
213206
return f
214207

215208
def solve_system(self, rhs, factor, u0, t):
216209
"""
217210
Simple Newton solver
211+
218212
Args:
219213
rhs (dtype_f): right-hand side for the linear system
220214
factor (float): abbrev. for the local stepsize (or any other factor required)
221215
u0 (dtype_u): initial guess for the iterative solver
222216
t (float): current time (e.g. for time-dependent BCs)
217+
223218
Returns:
224219
dtype_u: solution as mesh
225220
"""
@@ -228,21 +223,15 @@ def solve_system(self, rhs, factor, u0, t):
228223
non_f = np.zeros(2)
229224
self.A = np.zeros((2, 2))
230225

231-
if self.t_switch is not None:
232-
if t >= self.t_switch:
233-
self.A[0, 0] = -(self.params.Rs + self.params.R) / self.params.L
234-
non_f[0] = self.params.Vs / self.params.L
235-
else:
236-
self.A[1, 1] = -1 / (self.params.C * self.params.R)
237-
non_f[0] = 0
226+
t_switch = np.inf if self.t_switch is None else self.t_switch
227+
228+
if rhs[1] <= self.params.V_ref or t >= t_switch:
229+
self.A[0, 0] = -(self.params.Rs + self.params.R) / self.params.L
230+
non_f[0] = self.params.Vs
238231

239232
else:
240-
if rhs[1] <= self.params.V_ref:
241-
self.A[0, 0] = -(self.params.Rs + self.params.R) / self.params.L
242-
non_f[0] = self.params.Vs / self.params.L
243-
else:
244-
self.A[1, 1] = -1 / (self.params.C * self.params.R)
245-
non_f[0] = 0
233+
self.A[1, 1] = -1 / (self.params.C * self.params.R)
234+
non_f[0] = 0
246235

247236
# start newton iteration
248237
n = 0

pySDC/implementations/problem_classes/Battery_2Condensators.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -164,10 +164,9 @@ def get_switching_info(self, u, t):
164164
msg = 'A discrete event is already found! Multiple switching handling in the same interval is not yet implemented!'
165165
raise AssertionError(msg)
166166

167-
vC_switch = []
168-
if switch_detected:
169-
for m in range(1, len(u)):
170-
vC_switch.append(u[m][k_detected] - self.params.V_ref[k_detected - 1])
167+
vC_switch = (
168+
[u[m][k_detected] - self.params.V_ref[k_detected - 1] for m in range(1, len(u))] if switch_detected else []
169+
)
171170

172171
return switch_detected, m_guess, vC_switch
173172

pySDC/projects/PinTSimE/battery_2condensators_model.py

Lines changed: 27 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,12 @@ def post_step(self, step, level_number):
6767
def main(use_switch_estimator=True):
6868
"""
6969
A simple test program to do SDC/PFASST runs for the battery drain model using 2 condensators
70+
71+
Args:
72+
use_switch_estimator (bool): flag if the switch estimator wants to be used or not
73+
74+
Returns:
75+
description (dict): contains all information for a controller run
7076
"""
7177

7278
# initialize level parameters
@@ -83,6 +89,7 @@ def main(use_switch_estimator=True):
8389

8490
# initialize problem parameters
8591
problem_params = dict()
92+
problem_params['ncondensators'] = 2
8693
problem_params['Vs'] = 5.0
8794
problem_params['Rs'] = 0.5
8895
problem_params['C1'] = 1.0
@@ -98,7 +105,7 @@ def main(use_switch_estimator=True):
98105

99106
# initialize controller parameters
100107
controller_params = dict()
101-
controller_params['logger_level'] = 15
108+
controller_params['logger_level'] = 30
102109
controller_params['hook_class'] = log_data
103110

104111
# convergence controllers
@@ -171,12 +178,18 @@ def main(use_switch_estimator=True):
171178

172179
plot_voltages(description, recomputed, use_switch_estimator)
173180

174-
return np.mean(niters)
181+
return description
175182

176183

177184
def plot_voltages(description, recomputed, use_switch_estimator, cwd='./'):
178185
"""
179186
Routine to plot the numerical solution of the model
187+
188+
Args:
189+
description(dict): contains all information for a controller run
190+
recomputed (bool): flag if the values after a restart are used or before
191+
use_switch_estimator (bool): flag if the switch estimator wants to be used or not
192+
cwd: current working directory
180193
"""
181194

182195
f = open(cwd + 'data/battery_2condensators.dat', 'rb')
@@ -200,7 +213,7 @@ def plot_voltages(description, recomputed, use_switch_estimator, cwd='./'):
200213
switches = get_recomputed(stats, type='switch', sortby='time')
201214

202215
if recomputed is not None:
203-
assert len(switches) >= 1 and len(switches) >= 2, "No switches found"
216+
assert len(switches) >= 2, f"Expected at least 2 switches, got {len(switches)}!"
204217
t_switches = [v[1] for v in switches]
205218

206219
for i in range(len(t_switches)):
@@ -218,6 +231,10 @@ def plot_voltages(description, recomputed, use_switch_estimator, cwd='./'):
218231
def proof_assertions_description(description, use_switch_estimator):
219232
"""
220233
Function to proof the assertions (function to get cleaner code)
234+
235+
Args:
236+
description(dict): contains all information for a controller run
237+
use_switch_estimator (bool): flag if the switch estimator wants to be used or not
221238
"""
222239

223240
assert (
@@ -227,7 +244,13 @@ def proof_assertions_description(description, use_switch_estimator):
227244
description['problem_params']['alpha'] > description['problem_params']['V_ref'][1]
228245
), 'Please set "alpha" greater than "V_ref2"'
229246

230-
assert type(description['problem_params']['V_ref']) == np.ndarray, '"V_ref" needs to be an array (of type float)'
247+
if description['problem_params']['ncondensators'] > 1:
248+
assert (
249+
type(description['problem_params']['V_ref']) == np.ndarray
250+
), '"V_ref" needs to be an array (of type float)'
251+
assert (
252+
description['problem_params']['ncondensators'] == np.shape(description['problem_params']['V_ref'])[0]
253+
), 'Number of reference values needs to be equal to number of condensators'
231254

232255
assert description['problem_params']['V_ref'][0] > 0, 'Please set "V_ref1" greater than 0'
233256
assert description['problem_params']['V_ref'][1] > 0, 'Please set "V_ref2" greater than 0'

0 commit comments

Comments
 (0)