Skip to content

Commit 91ab723

Browse files
committed
Added some more tests
1 parent 653e7e5 commit 91ab723

File tree

5 files changed

+51
-39
lines changed

5 files changed

+51
-39
lines changed

pySDC/projects/PinTSimE/battery_2condensators_model.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from pySDC.implementations.sweeper_classes.imex_1st_order import imex_1st_order
99
from pySDC.implementations.controller_classes.controller_nonMPI import controller_nonMPI
1010
from pySDC.implementations.transfer_classes.TransferMesh import mesh_to_mesh
11+
from pySDC.projects.PinTSimE.battery_model import get_recomputed
1112
from pySDC.projects.PinTSimE.piline_model import setup_mpl
1213
import pySDC.helpers.plot_helper as plt_helper
1314
from pySDC.core.Hooks import hooks
@@ -167,12 +168,14 @@ def main(use_switch_estimator=True):
167168
assert np.mean(niters) <= 4, "Mean number of iterations is too high, got %s" % np.mean(niters)
168169
f.close()
169170

170-
plot_voltages(description, use_switch_estimator)
171+
recomputed = False
172+
173+
plot_voltages(description, recomputed, use_switch_estimator)
171174

172175
return np.mean(niters)
173176

174177

175-
def plot_voltages(description, use_switch_estimator, cwd='./'):
178+
def plot_voltages(description, recomputed, use_switch_estimator, cwd='./'):
176179
"""
177180
Routine to plot the numerical solution of the model
178181
"""
@@ -182,9 +185,9 @@ def plot_voltages(description, use_switch_estimator, cwd='./'):
182185
f.close()
183186

184187
# convert filtered statistics to list of iterations count, sorted by process
185-
cL = get_sorted(stats, type='current L', sortby='time')
186-
vC1 = get_sorted(stats, type='voltage C1', sortby='time')
187-
vC2 = get_sorted(stats, type='voltage C2', sortby='time')
188+
cL = get_sorted(stats, type='current L', recomputed=recomputed, sortby='time')
189+
vC1 = get_sorted(stats, type='voltage C1', recomputed=recomputed, sortby='time')
190+
vC2 = get_sorted(stats, type='voltage C2', recomputed=recomputed, sortby='time')
188191

189192
times = [v[0] for v in cL]
190193

@@ -195,7 +198,10 @@ def plot_voltages(description, use_switch_estimator, cwd='./'):
195198
ax.plot(times, [v[1] for v in vC2], label='$v_{C_2}$')
196199

197200
if use_switch_estimator:
198-
switches = get_sorted(stats, type='switch', sortby='time')
201+
switches = get_recomputed(stats, type='switch', sortby='time')
202+
203+
if recomputed is not None:
204+
assert len(switches) >= 1 and len(switches) >= 2, "No switches found"
199205
t_switches = [v[1] for v in switches]
200206

201207
for i in range(len(t_switches)):
@@ -222,11 +228,11 @@ def proof_assertions_description(description, use_switch_estimator):
222228
description['problem_params']['alpha'] > description['problem_params']['V_ref'][1]
223229
), 'Please set "alpha" greater than "V_ref2"'
224230

231+
assert type(description['problem_params']['V_ref']) == np.ndarray, '"V_ref" needs to be an array (of type float)'
232+
225233
assert description['problem_params']['V_ref'][0] > 0, 'Please set "V_ref1" greater than 0'
226234
assert description['problem_params']['V_ref'][1] > 0, 'Please set "V_ref2" greater than 0'
227235

228-
assert type(description['problem_params']['V_ref']) == np.ndarray, '"V_ref" needs to be an array (of type float)'
229-
230236
assert 'errtol' not in description['step_params'].keys(), 'No exact solution known to compute error'
231237
assert 'alpha' in description['problem_params'].keys(), 'Please supply "alpha" in the problem parameters'
232238
assert 'V_ref' in description['problem_params'].keys(), 'Please supply "V_ref" in the problem parameters'

pySDC/projects/PinTSimE/battery_model.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,8 @@ def main(dt, problem, sweeper, use_switch_estimator, use_adaptivity):
132132
t0 = 0.0
133133
Tend = 0.3
134134

135+
assert dt < Tend, "Time step is too large for the time domain!"
136+
135137
# instantiate controller
136138
controller = controller_nonMPI(num_procs=1, controller_params=controller_params, description=description)
137139

@@ -167,7 +169,7 @@ def main(dt, problem, sweeper, use_switch_estimator, use_adaptivity):
167169
min_iter = min(min_iter, item[1])
168170
max_iter = max(max_iter, item[1])
169171

170-
assert np.mean(niters) <= 5, "Mean number of iterations is too high, got %s" % np.mean(niters)
172+
assert np.mean(niters) <= 4, "Mean number of iterations is too high, got %s" % np.mean(niters)
171173
f.close()
172174

173175
return description
@@ -182,6 +184,7 @@ def run():
182184
dt = 1e-3
183185
problem_classes = [battery, battery_implicit]
184186
sweeper_classes = [imex_1st_order, generic_implicit]
187+
recomputed = False
185188
use_switch_estimator = [True]
186189
use_adaptivity = [True]
187190

@@ -196,10 +199,10 @@ def run():
196199
use_adaptivity=use_A,
197200
)
198201

199-
plot_voltages(description, problem.__name__, sweeper.__name__, use_SE, use_A)
202+
plot_voltages(description, problem.__name__, sweeper.__name__, recomputed, use_SE, use_A)
200203

201204

202-
def plot_voltages(description, problem, sweeper, use_switch_estimator, use_adaptivity, cwd='./'):
205+
def plot_voltages(description, problem, sweeper, recomputed, use_switch_estimator, use_adaptivity, cwd='./'):
203206
"""
204207
Routine to plot the numerical solution of the model
205208
"""
@@ -221,8 +224,10 @@ def plot_voltages(description, problem, sweeper, use_switch_estimator, use_adapt
221224
ax.plot(times, [v[1] for v in vC], label=r'$v_C$')
222225

223226
if use_switch_estimator:
224-
val_switch = get_sorted(stats, type='switch', sortby='time')
225-
t_switch = [v[1] for v in val_switch]
227+
switches = get_recomputed(stats, type='switch', sortby='time')
228+
229+
assert len(switches) >= 1, 'No switches found!'
230+
t_switch = [v[1] for v in switches]
226231
ax.axvline(x=t_switch[-1], linestyle='--', linewidth=0.8, color='r', label='Switch')
227232

228233
if use_adaptivity:
@@ -241,7 +246,7 @@ def plot_voltages(description, problem, sweeper, use_switch_estimator, use_adapt
241246

242247
fig.savefig('data/{}_model_solution_{}.png'.format(problem, sweeper), dpi=300, bbox_inches='tight')
243248
plt_helper.plt.close(fig)
244-
249+
245250
def get_recomputed(stats, type, sortby):
246251
"""
247252
Function that filters statistics after a recomputation

pySDC/projects/PinTSimE/estimation_check.py

Lines changed: 18 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,8 @@ def run(dt, problem, sweeper, use_switch_estimator, use_adaptivity, V_ref):
9696
t0 = 0.0
9797
Tend = 0.3
9898

99+
assert dt < Tend, "Time step is too large for the time domain!"
100+
99101
# instantiate controller
100102
controller = controller_nonMPI(num_procs=1, controller_params=controller_params, description=description)
101103

@@ -156,6 +158,8 @@ def check(cwd='./'):
156158
V_ref=V_ref,
157159
)
158160

161+
assert len(get_recomputed(stats, type='switch', sortby='time')) >= 1, 'No switches found!'
162+
159163
fname = 'data/battery_dt{}_USE{}_USA{}_{}.dat'.format(dt_item, use_SE, use_A, sweeper.__name__)
160164
f = open(fname, 'wb')
161165
dill.dump(stats, f)
@@ -225,8 +229,8 @@ def accuracy_check(dt_list, problem, sweeper, V_ref, cwd='./'):
225229
stats_adapt = dill.load(f4)
226230
f4.close()
227231

228-
val_switch_SE_adapt = get_recomputed(stats_SE_adapt, type='switch', sortby='time')
229-
t_switch_SE_adapt = [v[1] for v in val_switch_SE_adapt]
232+
switches_SE_adapt = get_recomputed(stats_SE_adapt, type='switch', sortby='time')
233+
t_switch_SE_adapt = [v[1] for v in switches_SE_adapt]
230234
t_switch_SE_adapt = t_switch_SE_adapt[-1]
231235

232236
dt_SE_adapt_val = get_sorted(stats_SE_adapt, type='dt', recomputed=False)
@@ -349,12 +353,12 @@ def differences_around_switch(
349353
stats_adapt = dill.load(f4)
350354
f4.close()
351355

352-
val_switch_SE = get_recomputed(stats_SE, type='switch', sortby='time')
353-
t_switch = [v[1] for v in val_switch_SE]
356+
switches_SE = get_recomputed(stats_SE, type='switch', sortby='time')
357+
t_switch = [v[1] for v in switches_SE]
354358
t_switch = t_switch[-1] # battery has only one single switch
355359

356-
val_switch_SE_adapt = get_recomputed(stats_SE_adapt, type='switch', sortby='time')
357-
t_switch_SE_adapt = [v[1] for v in val_switch_SE_adapt]
360+
switches_SE_adapt = get_recomputed(stats_SE_adapt, type='switch', sortby='time')
361+
t_switch_SE_adapt = [v[1] for v in switches_SE_adapt]
358362
t_switch_SE_adapt = t_switch_SE_adapt[-1]
359363

360364
vC_SE = get_sorted(stats_SE, type='voltage C', recomputed=False, sortby='time')
@@ -493,12 +497,12 @@ def differences_over_time(dt_list, problem, sweeper, V_ref, cwd='./'):
493497
stats_adapt = dill.load(f4)
494498
f4.close()
495499

496-
val_switch_SE = get_recomputed(stats_SE, type='switch', sortby='time')
497-
t_switch_SE = [v[1] for v in val_switch_SE]
500+
switches_SE = get_recomputed(stats_SE, type='switch', sortby='time')
501+
t_switch_SE = [v[1] for v in switches_SE]
498502
t_switch_SE = t_switch_SE[-1] # battery has only one single switch
499503

500-
val_switch_SE_adapt = get_recomputed(stats_SE_adapt, type='switch', sortby='time')
501-
t_switch_SE_adapt = [v[1] for v in val_switch_SE_adapt]
504+
switches_SE_adapt = get_recomputed(stats_SE_adapt, type='switch', sortby='time')
505+
t_switch_SE_adapt = [v[1] for v in switches_SE_adapt]
502506
t_switch_SE_adapt = t_switch_SE_adapt[-1]
503507

504508
dt_adapt = np.array(get_sorted(stats_adapt, type='dt', recomputed=False, sortby='time'))
@@ -654,12 +658,12 @@ def iterations_over_time(dt_list, maxiter, problem, sweeper, cwd='./'):
654658
times_adapt.append([v[0] for v in iter_counts_adapt_val])
655659
times.append([v[0] for v in iter_counts_val])
656660

657-
val_switch_SE = get_recomputed(stats_SE, type='switch', sortby='time')
658-
t_switch_SE = [v[1] for v in val_switch_SE]
661+
switches_SE = get_recomputed(stats_SE, type='switch', sortby='time')
662+
t_switch_SE = [v[1] for v in switches_SE]
659663
t_switches_SE.append(t_switch_SE[-1])
660664

661-
val_switch_SE_adapt = get_recomputed(stats_SE_adapt, type='switch', sortby='time')
662-
t_switch_SE_adapt = [v[1] for v in val_switch_SE_adapt]
665+
switches_SE_adapt = get_recomputed(stats_SE_adapt, type='switch', sortby='time')
666+
t_switch_SE_adapt = [v[1] for v in switches_SE_adapt]
663667
t_switches_SE_adapt.append(t_switch_SE_adapt[-1])
664668

665669
if len(dt_list) > 1:

pySDC/projects/PinTSimE/estimation_check_extended.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from pySDC.implementations.sweeper_classes.imex_1st_order import imex_1st_order
99
from pySDC.implementations.controller_classes.controller_nonMPI import controller_nonMPI
1010
from pySDC.implementations.transfer_classes.TransferMesh import mesh_to_mesh
11+
from pySDC.projects.PinTSimE.battery_model import get_recomputed
1112
from pySDC.projects.PinTSimE.piline_model import setup_mpl
1213
from pySDC.projects.PinTSimE.battery_2condensators_model import log_data, proof_assertions_description
1314
import pySDC.helpers.plot_helper as plt_helper
@@ -128,6 +129,8 @@ def check(cwd='./'):
128129
for item in use_switch_estimator:
129130
stats, description = run(dt=dt_item, use_switch_estimator=item)
130131

132+
assert len(get_recomputed(stats, type='switch', sortby='time')) >= 1, 'No switches found!'
133+
131134
fname = 'data/battery_2condensators_dt{}_USE{}.dat'.format(dt_item, item)
132135
f = open(fname, 'wb')
133136
dill.dump(stats, f)
@@ -159,13 +162,8 @@ def check(cwd='./'):
159162
stats_false = dill.load(f2)
160163
f2.close()
161164

162-
val_switch = get_sorted(stats_true, type='switch', sortby='time')
163-
#val_switch2 = get_sorted(stats_true, type='switch2', sortby='time')
164-
t_switch = [v[1] for v in val_switch]
165-
#t_switch2 = [v[1] for v in val_switch2]
166-
167-
#t_switch1 = t_switch1[-1]
168-
#t_switch2 = t_switch2[-1]
165+
switches = get_recomputed(stats_true, type='switch', sortby='time')
166+
t_switch = [v[1] for v in switches]
169167

170168
val_switch_all.append([t_switch[0], t_switch[1]])
171169

pySDC/projects/PinTSimE/switch_estimator.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -35,11 +35,10 @@ def setup(self, controller, params, description):
3535
defaults = {
3636
'control_order': 100,
3737
'tol': description['level_params']['dt'],
38-
'coll_nodes_local': coll.nodes,
38+
'coll_nodes': coll.nodes,
3939
'switch_detected': False,
4040
'switch_detected_step': False,
4141
't_switch': None,
42-
'count_switches': 0,
4342
'dt_initial': description['level_params']['dt'],
4443
}
4544
return {**defaults, **params}
@@ -66,7 +65,7 @@ def get_new_step_size(self, controller, S):
6665

6766
if self.params.switch_detected:
6867
t_interp = [
69-
L.time + L.dt * self.params.coll_nodes_local[m] for m in range(len(self.params.coll_nodes_local))
68+
L.time + L.dt * self.params.coll_nodes[m] for m in range(len(self.params.coll_nodes))
7069
]
7170

7271
# only find root if vc_switch[0], vC_switch[-1] have opposite signs (intermediate value theorem)
@@ -158,7 +157,7 @@ def post_step_processing(self, controller, S):
158157
self.params.t_switch = None
159158
self.params.switch_detected_step = False
160159

161-
dt_planned = L.status.dt_new if L.status.dt_new is not None else self.params.dt_initial
160+
dt_planned = L.status.dt_new if L.status.dt_new is not None else L.params.dt
162161
L.status.dt_new = dt_planned
163162

164163
super(SwitchEstimator, self).post_step_processing(controller, S)

0 commit comments

Comments
 (0)