|
7 | 7 | from pySDC.implementations.problem_classes.Battery_2Condensators import battery_2condensators |
8 | 8 | from pySDC.implementations.sweeper_classes.imex_1st_order import imex_1st_order |
9 | 9 | from pySDC.implementations.controller_classes.controller_nonMPI import controller_nonMPI |
10 | | -from pySDC.implementations.transfer_classes.TransferMesh import mesh_to_mesh |
11 | 10 | from pySDC.projects.PinTSimE.battery_model import get_recomputed |
12 | 11 | from pySDC.projects.PinTSimE.piline_model import setup_mpl |
13 | 12 | import pySDC.helpers.plot_helper as plt_helper |
@@ -80,6 +79,8 @@ def main(use_switch_estimator=True): |
80 | 79 | level_params['restol'] = -1 |
81 | 80 | level_params['dt'] = 1e-2 |
82 | 81 |
|
| 82 | + assert level_params['dt'] == 1e-2, 'Error! Do not use the time step dt != 1e-2!' |
| 83 | + |
83 | 84 | # initialize sweeper parameters |
84 | 85 | sweeper_params = dict() |
85 | 86 | sweeper_params['quad_type'] = 'LOBATTO' |
@@ -122,7 +123,6 @@ def main(use_switch_estimator=True): |
122 | 123 | description['sweeper_params'] = sweeper_params # pass sweeper parameters |
123 | 124 | description['level_params'] = level_params # pass level parameters |
124 | 125 | description['step_params'] = step_params |
125 | | - description['space_transfer_class'] = mesh_to_mesh # pass spatial transfer class |
126 | 126 |
|
127 | 127 | if use_switch_estimator: |
128 | 128 | description['convergence_controllers'] = convergence_controllers |
@@ -176,6 +176,8 @@ def main(use_switch_estimator=True): |
176 | 176 |
|
177 | 177 | recomputed = False |
178 | 178 |
|
| 179 | + check_solution(stats, use_switch_estimator) |
| 180 | + |
179 | 181 | plot_voltages(description, recomputed, use_switch_estimator) |
180 | 182 |
|
181 | 183 | return description |
@@ -228,6 +230,68 @@ def plot_voltages(description, recomputed, use_switch_estimator, cwd='./'): |
228 | 230 | plt_helper.plt.close(fig) |
229 | 231 |
|
230 | 232 |
|
| 233 | +def check_solution(stats, use_switch_estimator): |
| 234 | + """ |
| 235 | + Function that checks the solution based on a hardcoded reference solution. Based on check_solution function from @brownbaerchen. |
| 236 | +
|
| 237 | + Args: |
| 238 | + stats (dict): Raw statistics from a controller run |
| 239 | + use_switch_estimator (bool): flag if the switch estimator wants to be used or not |
| 240 | + """ |
| 241 | + |
| 242 | + data = get_data_dict(stats, use_switch_estimator) |
| 243 | + |
| 244 | + if use_switch_estimator: |
| 245 | + msg = 'Error when using the switch estimator for battery_2condensators:' |
| 246 | + expected = { |
| 247 | + 'cL': 1.1597046304825833, |
| 248 | + 'vC1': 1.000472118416924, |
| 249 | + 'vC2': 1.000226101799117, |
| 250 | + 'switch1': 1.6094379124373626, |
| 251 | + 'switch2': 3.2184040405613974, |
| 252 | + 'restarts': 2.0, |
| 253 | + } |
| 254 | + |
| 255 | + got = { |
| 256 | + 'cL': data['cL'][-1], |
| 257 | + 'vC1': data['vC1'][-1], |
| 258 | + 'vC2': data['vC2'][-1], |
| 259 | + 'switch1': data['switch1'], |
| 260 | + 'switch2': data['switch2'], |
| 261 | + 'restarts': data['restarts'], |
| 262 | + } |
| 263 | + |
| 264 | + for key in expected.keys(): |
| 265 | + assert np.isclose( |
| 266 | + expected[key], got[key], rtol=1e-4 |
| 267 | + ), f'{msg} Expected {key}={expected[key]:.4e}, got {key}={got[key]:.4e}' |
| 268 | + |
| 269 | + |
| 270 | +def get_data_dict(stats, use_switch_estimator, recomputed=False): |
| 271 | + """ |
| 272 | + Converts the statistics in a useful data dictionary so that it can be easily checked in the check_solution function. |
| 273 | + Based on @brownbaerchen's get_data function. |
| 274 | +
|
| 275 | + Args: |
| 276 | + stats (dict): Raw statistics from a controller run |
| 277 | + use_switch_estimator (bool): flag if the switch estimator wants to be used or not |
| 278 | + recomputed (bool): flag if the values after a restart are used or before |
| 279 | +
|
| 280 | + Return: |
| 281 | + data (dict): contains all information as the statistics dict |
| 282 | + """ |
| 283 | + |
| 284 | + data = dict() |
| 285 | + data['cL'] = np.array(get_sorted(stats, type='current L', recomputed=recomputed, sortby='time'))[:, 1] |
| 286 | + data['vC1'] = np.array(get_sorted(stats, type='voltage C1', recomputed=recomputed, sortby='time'))[:, 1] |
| 287 | + data['vC2'] = np.array(get_sorted(stats, type='voltage C2', recomputed=recomputed, sortby='time'))[:, 1] |
| 288 | + data['switch1'] = np.array(get_recomputed(stats, type='switch', sortby='time'))[0, 1] |
| 289 | + data['switch2'] = np.array(get_recomputed(stats, type='switch', sortby='time'))[-1, 1] |
| 290 | + data['restarts'] = np.sum(np.array(get_sorted(stats, type='restart', recomputed=None, sortby='time'))[:, 1]) |
| 291 | + |
| 292 | + return data |
| 293 | + |
| 294 | + |
231 | 295 | def proof_assertions_description(description, use_switch_estimator): |
232 | 296 | """ |
233 | 297 | Function to proof the assertions (function to get cleaner code) |
|
0 commit comments