Skip to content

Commit 8ec74ef

Browse files
committed
Hardcoded references for battery_2condensators
1 parent d78106a commit 8ec74ef

File tree

3 files changed

+134
-6
lines changed

3 files changed

+134
-6
lines changed

pySDC/projects/PinTSimE/battery_2condensators_model.py

Lines changed: 66 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
from pySDC.implementations.problem_classes.Battery_2Condensators import battery_2condensators
88
from pySDC.implementations.sweeper_classes.imex_1st_order import imex_1st_order
99
from pySDC.implementations.controller_classes.controller_nonMPI import controller_nonMPI
10-
from pySDC.implementations.transfer_classes.TransferMesh import mesh_to_mesh
1110
from pySDC.projects.PinTSimE.battery_model import get_recomputed
1211
from pySDC.projects.PinTSimE.piline_model import setup_mpl
1312
import pySDC.helpers.plot_helper as plt_helper
@@ -80,6 +79,8 @@ def main(use_switch_estimator=True):
8079
level_params['restol'] = -1
8180
level_params['dt'] = 1e-2
8281

82+
assert level_params['dt'] == 1e-2, 'Error! Do not use the time step dt != 1e-2!'
83+
8384
# initialize sweeper parameters
8485
sweeper_params = dict()
8586
sweeper_params['quad_type'] = 'LOBATTO'
@@ -122,7 +123,6 @@ def main(use_switch_estimator=True):
122123
description['sweeper_params'] = sweeper_params # pass sweeper parameters
123124
description['level_params'] = level_params # pass level parameters
124125
description['step_params'] = step_params
125-
description['space_transfer_class'] = mesh_to_mesh # pass spatial transfer class
126126

127127
if use_switch_estimator:
128128
description['convergence_controllers'] = convergence_controllers
@@ -176,6 +176,8 @@ def main(use_switch_estimator=True):
176176

177177
recomputed = False
178178

179+
check_solution(stats, use_switch_estimator)
180+
179181
plot_voltages(description, recomputed, use_switch_estimator)
180182

181183
return description
@@ -228,6 +230,68 @@ def plot_voltages(description, recomputed, use_switch_estimator, cwd='./'):
228230
plt_helper.plt.close(fig)
229231

230232

233+
def check_solution(stats, use_switch_estimator):
234+
"""
235+
Function that checks the solution based on a hardcoded reference solution. Based on check_solution function from @brownbaerchen.
236+
237+
Args:
238+
stats (dict): Raw statistics from a controller run
239+
use_switch_estimator (bool): flag if the switch estimator wants to be used or not
240+
"""
241+
242+
data = get_data_dict(stats, use_switch_estimator)
243+
244+
if use_switch_estimator:
245+
msg = 'Error when using the switch estimator for battery_2condensators:'
246+
expected = {
247+
'cL': 1.1597046304825833,
248+
'vC1': 1.000472118416924,
249+
'vC2': 1.000226101799117,
250+
'switch1': 1.6094379124373626,
251+
'switch2': 3.2184040405613974,
252+
'restarts': 2.0,
253+
}
254+
255+
got = {
256+
'cL': data['cL'][-1],
257+
'vC1': data['vC1'][-1],
258+
'vC2': data['vC2'][-1],
259+
'switch1': data['switch1'],
260+
'switch2': data['switch2'],
261+
'restarts': data['restarts'],
262+
}
263+
264+
for key in expected.keys():
265+
assert np.isclose(
266+
expected[key], got[key], rtol=1e-4
267+
), f'{msg} Expected {key}={expected[key]:.4e}, got {key}={got[key]:.4e}'
268+
269+
270+
def get_data_dict(stats, use_switch_estimator, recomputed=False):
271+
"""
272+
Converts the statistics in a useful data dictionary so that it can be easily checked in the check_solution function.
273+
Based on @brownbaerchen's get_data function.
274+
275+
Args:
276+
stats (dict): Raw statistics from a controller run
277+
use_switch_estimator (bool): flag if the switch estimator wants to be used or not
278+
recomputed (bool): flag if the values after a restart are used or before
279+
280+
Return:
281+
data (dict): contains all information as the statistics dict
282+
"""
283+
284+
data = dict()
285+
data['cL'] = np.array(get_sorted(stats, type='current L', recomputed=recomputed, sortby='time'))[:, 1]
286+
data['vC1'] = np.array(get_sorted(stats, type='voltage C1', recomputed=recomputed, sortby='time'))[:, 1]
287+
data['vC2'] = np.array(get_sorted(stats, type='voltage C2', recomputed=recomputed, sortby='time'))[:, 1]
288+
data['switch1'] = np.array(get_recomputed(stats, type='switch', sortby='time'))[0, 1]
289+
data['switch2'] = np.array(get_recomputed(stats, type='switch', sortby='time'))[-1, 1]
290+
data['restarts'] = np.sum(np.array(get_sorted(stats, type='restart', recomputed=None, sortby='time'))[:, 1])
291+
292+
return data
293+
294+
231295
def proof_assertions_description(description, use_switch_estimator):
232296
"""
233297
Function to proof the assertions (function to get cleaner code)

pySDC/projects/PinTSimE/battery_model.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -336,9 +336,10 @@ def check_solution(stats, problem, use_adaptivity, use_switch_estimator):
336336
), f'{msg} Expected {key}={expected[key]:.4e}, got {key}={got[key]:.4e}'
337337

338338

339-
def get_data_dict(stats, use_adaptivity, use_switch_estimator, recomputed=False):
339+
def get_data_dict(stats, use_adaptivity=True, use_switch_estimator=True, recomputed=False):
340340
"""
341341
Converts the statistics in a useful data dictionary so that it can be easily checked in the check_solution function.
342+
Based on @brownbaerchen's get_data function.
342343
343344
Args:
344345
stats (dict): Raw statistics from a controller run

pySDC/projects/PinTSimE/estimation_check_extended.py

Lines changed: 66 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from pySDC.implementations.sweeper_classes.imex_1st_order import imex_1st_order
99
from pySDC.implementations.controller_classes.controller_nonMPI import controller_nonMPI
1010
from pySDC.projects.PinTSimE.battery_model import get_recomputed
11+
from pySDC.projects.PinTSimE.battery_2condensators_model import get_data_dict
1112
from pySDC.projects.PinTSimE.piline_model import setup_mpl
1213
from pySDC.projects.PinTSimE.battery_2condensators_model import log_data, proof_assertions_description
1314
import pySDC.helpers.plot_helper as plt_helper
@@ -33,6 +34,10 @@ def run(dt, use_switch_estimator=True):
3334
level_params['restol'] = -1
3435
level_params['dt'] = dt
3536

37+
assert (
38+
dt == 4e-1 or dt == 4e-2 or dt == 4e-3
39+
), "Error! Do not use other time steps dt != 4e-1 or dt != 4e-2 or dt != 4e-3 due to hardcoded references!"
40+
3641
# initialize sweeper parameters
3742
sweeper_params = dict()
3843
sweeper_params['quad_type'] = 'LOBATTO'
@@ -143,9 +148,10 @@ def check(cwd='./'):
143148
stats, description = run(dt=dt_item, use_switch_estimator=use_SE)
144149

145150
if use_SE:
146-
assert (
147-
len(get_recomputed(stats, type='switch', sortby='time')) >= 1
148-
), 'No switches found for dt={}!'.format(dt_item)
151+
switches = get_recomputed(stats, type='switch', sortby='time')
152+
assert (len(switches) >= 2), f"Expected at least 2 switches for dt: {dt_item}, got {len(switches)}!"
153+
154+
check_solution(stats, dt_item, use_SE)
149155

150156
fname = 'data/battery_2condensators_dt{}_USE{}.dat'.format(dt_item, use_SE)
151157
f = open(fname, 'wb')
@@ -299,5 +305,62 @@ def check(cwd='./'):
299305
plt_helper.plt.close(fig2)
300306

301307

308+
def check_solution(stats, dt, use_switch_estimator):
309+
"""
310+
Function that checks the solution based on a hardcoded reference solution. Based on check_solution function from @brownbaerchen.
311+
312+
Args:
313+
stats (dict): Raw statistics from a controller run
314+
dt (float): initial time step
315+
use_switch_estimator (bool): flag if the switch estimator wants to be used or not
316+
"""
317+
318+
data = get_data_dict(stats, use_switch_estimator)
319+
320+
if use_switch_estimator:
321+
msg = f'Error when using the switch estimator for battery_2condensators for dt={dt:.1e}:'
322+
if dt == 4e-1:
323+
expected = {
324+
'cL': 1.1556732037544801,
325+
'vC1': 1.002239522400514,
326+
'vC2': 1.0000329223874842,
327+
'switch1': 1.607586793484041,
328+
'switch2': 3.216645438176962,
329+
'restarts': 2.0,
330+
}
331+
elif dt == 4e-2:
332+
expected = {
333+
'cL': 1.1492603893091364,
334+
'vC1': 1.0005011122241925,
335+
'vC2': 1.000015039670507,
336+
'switch1': 1.6094074085596919,
337+
'switch2': 3.2183750611596893,
338+
'restarts': 2.0,
339+
}
340+
elif dt == 4e-3:
341+
expected = {
342+
'cL': 1.1476283937778273,
343+
'vC1': 1.0001336962511904,
344+
'vC2': 1.0000182217245925,
345+
'switch1': 1.6093728710270498,
346+
'switch2': 3.218742142249058,
347+
'restarts': 2.0,
348+
}
349+
350+
got = {
351+
'cL': data['cL'][-1],
352+
'vC1': data['vC1'][-1],
353+
'vC2': data['vC2'][-1],
354+
'switch1': data['switch1'],
355+
'switch2': data['switch2'],
356+
'restarts': data['restarts'],
357+
}
358+
359+
for key in expected.keys():
360+
assert np.isclose(
361+
expected[key], got[key], rtol=1e-4
362+
), f'{msg} Expected {key}={expected[key]:.4e}, got {key}={got[key]:.4e}'
363+
364+
302365
if __name__ == "__main__":
303366
check()

0 commit comments

Comments
 (0)