Skip to content

Commit e3c7328

Browse files
committed
One hook for everything + list comprehensions
1 parent 64f0f90 commit e3c7328

File tree

4 files changed

+104
-148
lines changed

4 files changed

+104
-148
lines changed

pySDC/projects/PinTSimE/battery_2capacitors_model.py

Lines changed: 11 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
controller_run,
1212
generate_description,
1313
get_recomputed,
14+
log_data,
1415
proof_assertions_description,
1516
)
1617
from pySDC.projects.PinTSimE.piline_model import setup_mpl
@@ -20,54 +21,6 @@
2021
from pySDC.projects.PinTSimE.switch_estimator import SwitchEstimator
2122

2223

23-
class log_data(hooks):
24-
def post_step(self, step, level_number):
25-
26-
super(log_data, self).post_step(step, level_number)
27-
28-
# some abbreviations
29-
L = step.levels[level_number]
30-
31-
L.sweep.compute_end_point()
32-
33-
self.add_to_stats(
34-
process=step.status.slot,
35-
time=L.time + L.dt,
36-
level=L.level_index,
37-
iter=0,
38-
sweep=L.status.sweep,
39-
type='current L',
40-
value=L.uend[0],
41-
)
42-
self.add_to_stats(
43-
process=step.status.slot,
44-
time=L.time + L.dt,
45-
level=L.level_index,
46-
iter=0,
47-
sweep=L.status.sweep,
48-
type='voltage C1',
49-
value=L.uend[1],
50-
)
51-
self.add_to_stats(
52-
process=step.status.slot,
53-
time=L.time + L.dt,
54-
level=L.level_index,
55-
iter=0,
56-
sweep=L.status.sweep,
57-
type='voltage C2',
58-
value=L.uend[2],
59-
)
60-
self.add_to_stats(
61-
process=step.status.slot,
62-
time=L.time,
63-
level=L.level_index,
64-
iter=0,
65-
sweep=L.status.sweep,
66-
type='restart',
67-
value=int(step.status.get('restart')),
68-
)
69-
70-
7124
def run():
7225
"""
7326
Executes the simulation for the battery model using the IMEX sweeper and plot the results
@@ -126,17 +79,17 @@ def plot_voltages(description, problem, sweeper, recomputed, use_switch_estimato
12679
f.close()
12780

12881
# convert filtered statistics to list of iterations count, sorted by process
129-
cL = get_sorted(stats, type='current L', recomputed=recomputed, sortby='time')
130-
vC1 = get_sorted(stats, type='voltage C1', recomputed=recomputed, sortby='time')
131-
vC2 = get_sorted(stats, type='voltage C2', recomputed=recomputed, sortby='time')
82+
cL = np.array([me[1][0] for me in get_sorted(stats, type='u', recomputed=False, sortby='time')])
83+
vC1 = np.array([me[1][1] for me in get_sorted(stats, type='u', recomputed=False, sortby='time')])
84+
vC2 = np.array([me[1][2] for me in get_sorted(stats, type='u', recomputed=False, sortby='time')])
13285

133-
times = [v[0] for v in cL]
86+
times = np.array([me[0] for me in get_sorted(stats, type='u', recomputed=False, sortby='time')])
13487

13588
setup_mpl()
13689
fig, ax = plt_helper.plt.subplots(1, 1, figsize=(4.5, 3))
137-
ax.plot(times, [v[1] for v in cL], label='$i_L$')
138-
ax.plot(times, [v[1] for v in vC1], label='$v_{C_1}$')
139-
ax.plot(times, [v[1] for v in vC2], label='$v_{C_2}$')
90+
ax.plot(times, cL, label='$i_L$')
91+
ax.plot(times, vC1, label='$v_{C_1}$')
92+
ax.plot(times, vC2, label='$v_{C_2}$')
14093

14194
if use_switch_estimator:
14295
switches = get_recomputed(stats, type='switch', sortby='time')
@@ -242,9 +195,9 @@ def get_data_dict(stats, use_switch_estimator, recomputed=False):
242195
"""
243196

244197
data = dict()
245-
data['cL'] = np.array(get_sorted(stats, type='current L', recomputed=recomputed, sortby='time'))[:, 1]
246-
data['vC1'] = np.array(get_sorted(stats, type='voltage C1', recomputed=recomputed, sortby='time'))[:, 1]
247-
data['vC2'] = np.array(get_sorted(stats, type='voltage C2', recomputed=recomputed, sortby='time'))[:, 1]
198+
data['cL'] = np.array([me[1][0] for me in get_sorted(stats, type='u', recomputed=False, sortby='time')])
199+
data['vC1'] = np.array([me[1][1] for me in get_sorted(stats, type='u', recomputed=False, sortby='time')])
200+
data['vC2'] = np.array([me[1][2] for me in get_sorted(stats, type='u', recomputed=False, sortby='time')])
248201
data['switch1'] = np.array(get_recomputed(stats, type='switch', sortby='time'))[0, 1]
249202
data['switch2'] = np.array(get_recomputed(stats, type='switch', sortby='time'))[-1, 1]
250203
data['restarts'] = np.sum(np.array(get_sorted(stats, type='restart', recomputed=None, sortby='time'))[:, 1])

pySDC/projects/PinTSimE/battery_model.py

Lines changed: 9 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -32,17 +32,8 @@ def post_step(self, step, level_number):
3232
level=L.level_index,
3333
iter=0,
3434
sweep=L.status.sweep,
35-
type='current L',
36-
value=L.uend[0],
37-
)
38-
self.add_to_stats(
39-
process=step.status.slot,
40-
time=L.time + L.dt,
41-
level=L.level_index,
42-
iter=0,
43-
sweep=L.status.sweep,
44-
type='voltage C',
45-
value=L.uend[1],
35+
type='u',
36+
value=L.uend,
4637
)
4738
self.add_to_stats(
4839
process=step.status.slot,
@@ -266,16 +257,16 @@ def plot_voltages(description, problem, sweeper, recomputed, use_switch_estimato
266257
f.close()
267258

268259
# convert filtered statistics to list of iterations count, sorted by process
269-
cL = get_sorted(stats, type='current L', recomputed=False, sortby='time')
270-
vC = get_sorted(stats, type='voltage C', recomputed=False, sortby='time')
260+
cL = np.array([me[1][0] for me in get_sorted(stats, type='u', recomputed=False, sortby='time')])
261+
vC = np.array([me[1][1] for me in get_sorted(stats, type='u', recomputed=False, sortby='time')])
271262

272-
times = [v[0] for v in cL]
263+
times = np.array([me[0] for me in get_sorted(stats, type='u', recomputed=recomputed)])
273264

274265
setup_mpl()
275266
fig, ax = plt_helper.plt.subplots(1, 1, figsize=(3, 3))
276267
ax.set_title('Simulation of {} using {}'.format(problem, sweeper), fontsize=10)
277-
ax.plot(times, [v[1] for v in cL], label=r'$i_L$')
278-
ax.plot(times, [v[1] for v in vC], label=r'$v_C$')
268+
ax.plot(times, cL, label=r'$i_L$')
269+
ax.plot(times, vC, label=r'$v_C$')
279270

280271
if use_switch_estimator:
281272
switches = get_recomputed(stats, type='switch', sortby='time')
@@ -574,8 +565,8 @@ def get_data_dict(stats, use_adaptivity=True, use_switch_estimator=True, recompu
574565

575566
data = dict()
576567

577-
data['cL'] = np.array(get_sorted(stats, type='current L', recomputed=recomputed, sortby='time'))[:, 1]
578-
data['vC'] = np.array(get_sorted(stats, type='voltage C', recomputed=recomputed, sortby='time'))[:, 1]
568+
data['cL'] = np.array([me[1][0] for me in get_sorted(stats, type='u', recomputed=False, sortby='time')])
569+
data['vC'] = np.array([me[1][1] for me in get_sorted(stats, type='u', recomputed=False, sortby='time')])
579570
if use_adaptivity:
580571
data['dt'] = np.array(get_sorted(stats, type='dt', recomputed=recomputed, sortby='time'))[:, 1]
581572
data['e_em'] = np.array(

pySDC/projects/PinTSimE/estimation_check.py

Lines changed: 28 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -125,8 +125,8 @@ def accuracy_check(dt_list, problem, sweeper, V_ref, cwd='./'):
125125
126126
Args:
127127
dt_list (list): list of considered (initial) step sizes
128-
problem (problem.__name__): Problem class used to consider (the class name)
129-
sweeper (sweeper.__name__): Sweeper used to solve (the class name)
128+
problem (pySDC.core.Problem.ptype): Problem class used to consider (the class name)
129+
sweeper (pySDC.core.Sweeper.sweeper): Sweeper used to solve (the class name)
130130
V_ref (np.float): reference value for the switch
131131
cwd: current working directory
132132
"""
@@ -240,11 +240,11 @@ def differences_around_switch(
240240
241241
Args:
242242
dt_list (list): list of considered (initial) step sizes
243-
problem (problem.__name__): Problem class used to consider (the class name)
243+
problem (pySDC.core.Problem.ptype): Problem class used to consider (the class name)
244244
restarts_SE (list): Restarts for the solve only using the switch estimator
245245
restarts_adapt (list): Restarts for the solve of only using adaptivity
246246
restarts_SE_adapt (list): Restarts for the solve of using both, switch estimator and adaptivity
247-
sweeper (sweeper.__name__): Sweeper used to solve (the class name)
247+
sweeper (pySDC.core.Sweeper.sweeper): Sweeper used to solve (the class name)
248248
V_ref (np.float): reference value for the switch
249249
cwd: current working directory
250250
"""
@@ -284,16 +284,18 @@ def differences_around_switch(
284284
t_switch_SE_adapt = [v[1] for v in switches_SE_adapt]
285285
t_switch_SE_adapt = t_switch_SE_adapt[-1]
286286

287-
vC_SE = get_sorted(stats_SE, type='voltage C', recomputed=False, sortby='time')
288-
vC_adapt = get_sorted(stats_adapt, type='voltage C', recomputed=False, sortby='time')
289-
vC_SE_adapt = get_sorted(stats_SE_adapt, type='voltage C', recomputed=False, sortby='time')
290-
vC = get_sorted(stats, type='voltage C', sortby='time')
287+
vC_SE = [me[1][1] for me in get_sorted(stats_SE, type='u', recomputed=False, sortby='time')]
288+
vC_adapt = [me[1][1] for me in get_sorted(stats_adapt, type='u', recomputed=False, sortby='time')]
289+
vC_SE_adapt = [me[1][1] for me in get_sorted(stats_SE_adapt, type='u', recomputed=False, sortby='time')]
290+
vC = [me[1][1] for me in get_sorted(stats, type='u', recomputed=False, sortby='time')]
291291

292-
diff_SE, diff = [v[1] - V_ref[0] for v in vC_SE], [v[1] - V_ref[0] for v in vC]
293-
times_SE, times = [v[0] for v in vC_SE], [v[0] for v in vC]
292+
diff_SE, diff = vC_SE - V_ref[0], vC - V_ref[0]
293+
times_SE = [me[0] for me in get_sorted(stats_SE, type='u', recomputed=False, sortby='time')]
294+
times = [me[0] for me in get_sorted(stats, type='u', recomputed=False, sortby='time')]
294295

295-
diff_adapt, diff_SE_adapt = [v[1] - V_ref[0] for v in vC_adapt], [v[1] - V_ref[0] for v in vC_SE_adapt]
296-
times_adapt, times_SE_adapt = [v[0] for v in vC_adapt], [v[0] for v in vC_SE_adapt]
296+
diff_adapt, diff_SE_adapt = vC_adapt - V_ref[0], vC_SE_adapt - V_ref[0]
297+
times_adapt = [me[0] for me in get_sorted(stats_adapt, type='u', recomputed=False, sortby='time')]
298+
times_SE_adapt = [me[0] for me in get_sorted(stats_SE_adapt, type='u', recomputed=False, sortby='time')]
297299

298300
for m in range(len(times_SE)):
299301
if np.round(times_SE[m], 15) == np.round(t_switch, 15):
@@ -387,8 +389,8 @@ def differences_over_time(dt_list, problem, sweeper, V_ref, cwd='./'):
387389
388390
Args:
389391
dt_list (list): list of considered (initial) step sizes
390-
problem (problem.__name__): Problem class used to consider (the class name)
391-
sweeper (sweeper.__name__): Sweeper used to solve (the class name)
392+
problem (pySDC.core.Problem.ptype): Problem class used to consider (the class name)
393+
sweeper (pySDC.core.Sweeper.sweeper): Sweeper used to solve (the class name)
392394
V_ref (np.float): reference value for the switch
393395
cwd: current working directory
394396
"""
@@ -435,16 +437,18 @@ def differences_over_time(dt_list, problem, sweeper, V_ref, cwd='./'):
435437
restart_adapt = np.array(get_sorted(stats_adapt, type='restart', recomputed=None, sortby='time'))
436438
restart_SE_adapt = np.array(get_sorted(stats_SE_adapt, type='restart', recomputed=None, sortby='time'))
437439

438-
vC_SE = get_sorted(stats_SE, type='voltage C', recomputed=False, sortby='time')
439-
vC_adapt = get_sorted(stats_adapt, type='voltage C', recomputed=False, sortby='time')
440-
vC_SE_adapt = get_sorted(stats_SE_adapt, type='voltage C', recomputed=False, sortby='time')
441-
vC = get_sorted(stats, type='voltage C', sortby='time')
440+
vC_SE = [me[1][1] for me in get_sorted(stats_SE, type='u', recomputed=False, sortby='time')]
441+
vC_adapt = [me[1][1] for me in get_sorted(stats_adapt, type='u', recomputed=False, sortby='time')]
442+
vC_SE_adapt = [me[1][1] for me in get_sorted(stats_SE_adapt, type='u', recomputed=False, sortby='time')]
443+
vC = [me[1][1] for me in get_sorted(stats, type='u', recomputed=False, sortby='time')]
442444

443-
diff_SE, diff = [v[1] - V_ref[0] for v in vC_SE], [v[1] - V_ref[0] for v in vC]
444-
times_SE, times = [v[0] for v in vC_SE], [v[0] for v in vC]
445+
diff_SE, diff = vC_SE - V_ref[0], vC - V_ref[0]
446+
times_SE = [me[0] for me in get_sorted(stats_SE, type='u', recomputed=False, sortby='time')]
447+
times = [me[0] for me in get_sorted(stats, type='u', recomputed=False, sortby='time')]
445448

446-
diff_adapt, diff_SE_adapt = [v[1] - V_ref[0] for v in vC_adapt], [v[1] - V_ref[0] for v in vC_SE_adapt]
447-
times_adapt, times_SE_adapt = [v[0] for v in vC_adapt], [v[0] for v in vC_SE_adapt]
449+
diff_adapt, diff_SE_adapt = vC_adapt - V_ref[0], vC_SE_adapt - V_ref[0]
450+
times_adapt = [me[0] for me in get_sorted(stats_adapt, type='u', recomputed=False, sortby='time')]
451+
times_SE_adapt = [me[0] for me in get_sorted(stats_SE_adapt, type='u', recomputed=False, sortby='time')]
448452

449453
if len(dt_list) > 1:
450454
ax_diffs[0, count_ax].set_title(r'$\Delta t$=%s' % dt_item)
@@ -534,8 +538,8 @@ def iterations_over_time(dt_list, maxiter, problem, sweeper, cwd='./'):
534538
Args:
535539
dt_list (list): list of considered (initial) step sizes
536540
maxiter (np.int): maximum number of iterations
537-
problem (problem.__name__): Problem class used to consider (the class name)
538-
sweeper (sweeper.__name__): Sweeper used to solve (the class name)
541+
problem (pySDC.core.Problem.ptype): Problem class used to consider (the class name)
542+
sweeper (pySDC.core.Sweeper.sweeper): Sweeper used to solve (the class name)
539543
cwd: current working directory
540544
"""
541545

0 commit comments

Comments
 (0)