Skip to content

Commit 2e45cb8

Browse files
committed
Generate_description()+controller_run()+renaming
1 parent 4e61c4c commit 2e45cb8

File tree

8 files changed

+258
-355
lines changed

8 files changed

+258
-355
lines changed

pySDC/implementations/problem_classes/Battery.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -26,13 +26,13 @@ def __init__(self, problem_params, dtype_u=mesh, dtype_f=imex_mesh):
2626
"""
2727

2828
# these parameters will be used later, so assert their existence
29-
essential_keys = ['ncondensators', 'Vs', 'Rs', 'C', 'R', 'L', 'alpha', 'V_ref']
29+
essential_keys = ['ncapacitors', 'Vs', 'Rs', 'C', 'R', 'L', 'alpha', 'V_ref']
3030
for key in essential_keys:
3131
if key not in problem_params:
3232
msg = 'need %s to instantiate problem, only got %s' % (key, str(problem_params.keys()))
3333
raise ParameterError(msg)
3434

35-
problem_params['nvars'] = problem_params['ncondensators'] + 1
35+
problem_params['nvars'] = problem_params['ncapacitors'] + 1
3636

3737
# invoke super init, passing number of dofs, dtype_u and dtype_f
3838
super(battery, self).__init__(
@@ -156,13 +156,13 @@ def count_switches(self):
156156
class battery_implicit(battery):
157157
def __init__(self, problem_params, dtype_u=mesh, dtype_f=mesh):
158158

159-
essential_keys = ['newton_maxiter', 'newton_tol', 'ncondensators', 'Vs', 'Rs', 'C', 'R', 'L', 'alpha', 'V_ref']
159+
essential_keys = ['newton_maxiter', 'newton_tol', 'ncapacitors', 'Vs', 'Rs', 'C', 'R', 'L', 'alpha', 'V_ref']
160160
for key in essential_keys:
161161
if key not in problem_params:
162162
msg = 'need %s to instantiate problem, only got %s' % (key, str(problem_params.keys()))
163163
raise ParameterError(msg)
164164

165-
problem_params['nvars'] = problem_params['ncondensators'] + 1
165+
problem_params['nvars'] = problem_params['ncapacitors'] + 1
166166

167167
# invoke super init, passing number of dofs, dtype_u and dtype_f
168168
super(battery_implicit, self).__init__(
@@ -289,13 +289,13 @@ def __init__(self, problem_params, dtype_u=mesh, dtype_f=imex_mesh):
289289
"""
290290

291291
# these parameters will be used later, so assert their existence
292-
essential_keys = ['ncondensators', 'Vs', 'Rs', 'C', 'R', 'L', 'alpha', 'V_ref']
292+
essential_keys = ['ncapacitors', 'Vs', 'Rs', 'C', 'R', 'L', 'alpha', 'V_ref']
293293
for key in essential_keys:
294294
if key not in problem_params:
295295
msg = 'need %s to instantiate problem, only got %s' % (key, str(problem_params.keys()))
296296
raise ParameterError(msg)
297297

298-
n = problem_params['ncondensators']
298+
n = problem_params['ncapacitors']
299299
problem_params['nvars'] = n + 1
300300

301301
# invoke super init, passing number of dofs, dtype_u and dtype_f

pySDC/projects/PinTSimE/battery_2capacitors_model.py

Lines changed: 57 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,12 @@
77
from pySDC.implementations.problem_classes.Battery import battery_n_capacitors
88
from pySDC.implementations.sweeper_classes.imex_1st_order import imex_1st_order
99
from pySDC.implementations.controller_classes.controller_nonMPI import controller_nonMPI
10-
from pySDC.projects.PinTSimE.battery_model import get_recomputed, proof_assertions_description
10+
from pySDC.projects.PinTSimE.battery_model import (
11+
controller_run,
12+
generate_description,
13+
get_recomputed,
14+
proof_assertions_description,
15+
)
1116
from pySDC.projects.PinTSimE.piline_model import setup_mpl
1217
import pySDC.helpers.plot_helper as plt_helper
1318
from pySDC.core.Hooks import hooks
@@ -63,112 +68,60 @@ def post_step(self, step, level_number):
6368
)
6469

6570

66-
def main(use_switch_estimator=True):
71+
def run():
6772
"""
68-
A simple test program to do SDC/PFASST runs for the battery drain model using 2 condensators
69-
70-
Args:
71-
use_switch_estimator (bool): flag if the switch estimator wants to be used or not
72-
73-
Returns:
74-
description (dict): contains all information for a controller run
73+
Executes the simulation for the battery model using the IMEX sweeper and plot the results
74+
as <problem_class>_model_solution_<sweeper_class>.png
7575
"""
7676

77-
# initialize level parameters
78-
level_params = dict()
79-
level_params['restol'] = -1
80-
level_params['dt'] = 1e-2
81-
82-
assert level_params['dt'] == 1e-2, 'Error! Do not use the time step dt != 1e-2!'
83-
84-
# initialize sweeper parameters
85-
sweeper_params = dict()
86-
sweeper_params['quad_type'] = 'LOBATTO'
87-
sweeper_params['num_nodes'] = 5
88-
# sweeper_params['QI'] = 'LU' # For the IMEX sweeper, the LU-trick can be activated for the implicit part
89-
sweeper_params['initial_guess'] = 'spread'
90-
91-
# initialize problem parameters
92-
problem_params = dict()
93-
problem_params['ncondensators'] = 2
94-
problem_params['Vs'] = 5.0
95-
problem_params['Rs'] = 0.5
96-
problem_params['C'] = np.array([1.0, 1.0])
97-
problem_params['R'] = 1.0
98-
problem_params['L'] = 1.0
99-
problem_params['alpha'] = 5.0
100-
problem_params['V_ref'] = np.array([1.0, 1.0]) # [V_ref1, V_ref2]
101-
102-
# initialize step parameters
103-
step_params = dict()
104-
step_params['maxiter'] = 4
105-
106-
# initialize controller parameters
107-
controller_params = dict()
108-
controller_params['logger_level'] = 30
109-
controller_params['hook_class'] = log_data
110-
111-
# convergence controllers
112-
convergence_controllers = dict()
113-
if use_switch_estimator:
114-
switch_estimator_params = {}
115-
convergence_controllers[SwitchEstimator] = switch_estimator_params
116-
117-
# fill description dictionary for easy step instantiation
118-
description = dict()
119-
description['problem_class'] = battery_n_capacitors # pass problem class
120-
description['problem_params'] = problem_params # pass problem parameters
121-
description['sweeper_class'] = imex_1st_order # pass sweeper
122-
description['sweeper_params'] = sweeper_params # pass sweeper parameters
123-
description['level_params'] = level_params # pass level parameters
124-
description['step_params'] = step_params
125-
126-
if use_switch_estimator:
127-
description['convergence_controllers'] = convergence_controllers
128-
129-
proof_assertions_description(description, False, use_switch_estimator)
130-
131-
# set time parameters
77+
dt = 1e-2
13278
t0 = 0.0
13379
Tend = 3.5
13480

135-
# instantiate controller
136-
controller = controller_nonMPI(num_procs=1, controller_params=controller_params, description=description)
81+
problem_classes = [battery_n_capacitors]
82+
sweeper_classes = [imex_1st_order]
13783

138-
# get initial values on finest level
139-
P = controller.MS[0].levels[0].prob
140-
uinit = P.u_exact(t0)
84+
ncapacitors = 2
85+
alpha = 5.0
86+
V_ref = np.array([1.0, 1.0])
87+
C = np.array([1.0, 1.0])
14188

142-
# call main function to get things done...
143-
uend, stats = controller.run(u0=uinit, t0=t0, Tend=Tend)
89+
recomputed = False
90+
use_switch_estimator = [True]
14491

145-
Path("data").mkdir(parents=True, exist_ok=True)
146-
fname = 'data/battery_2condensators.dat'
147-
f = open(fname, 'wb')
148-
dill.dump(stats, f)
149-
f.close()
92+
for problem, sweeper in zip(problem_classes, sweeper_classes):
93+
for use_SE in use_switch_estimator:
94+
description, controller_params = generate_description(
95+
dt, problem, sweeper, log_data, False, use_SE, ncapacitors, alpha, V_ref, C
96+
)
15097

151-
recomputed = False
98+
# Assertions
99+
proof_assertions_description(description, False, use_SE)
100+
101+
proof_assertions_time(dt, Tend, V_ref, alpha)
152102

153-
check_solution(stats, level_params['dt'], use_switch_estimator)
103+
stats = controller_run(description, controller_params, False, use_SE, t0, Tend)
154104

155-
plot_voltages(description, recomputed, use_switch_estimator)
105+
check_solution(stats, dt, use_SE)
156106

157-
return description
107+
plot_voltages(description, problem.__name__, sweeper.__name__, recomputed, use_SE, False)
158108

159109

160-
def plot_voltages(description, recomputed, use_switch_estimator, cwd='./'):
110+
def plot_voltages(description, problem, sweeper, recomputed, use_switch_estimator, use_adaptivity, cwd='./'):
161111
"""
162112
Routine to plot the numerical solution of the model
163113
164114
Args:
165115
description(dict): contains all information for a controller run
116+
problem (problem_class.__name__): problem class that wants to be simulated
117+
sweeper (sweeper_class.__name__): sweeper class for solving the problem class numerically
166118
recomputed (bool): flag if the values after a restart are used or before
167119
use_switch_estimator (bool): flag if the switch estimator wants to be used or not
120+
use_adaptivity (bool): flag if adaptivity wants to be used or not
168121
cwd: current working directory
169122
"""
170123

171-
f = open(cwd + 'data/battery_2condensators.dat', 'rb')
124+
f = open(cwd + 'data/{}_{}_USE{}_USA{}.dat'.format(problem, sweeper, use_switch_estimator, use_adaptivity), 'rb')
172125
stats = dill.load(f)
173126
f.close()
174127

@@ -199,7 +152,7 @@ def plot_voltages(description, recomputed, use_switch_estimator, cwd='./'):
199152
ax.set_xlabel('Time')
200153
ax.set_ylabel('Energy')
201154

202-
fig.savefig('data/battery_2condensators_model_solution.png', dpi=300, bbox_inches='tight')
155+
fig.savefig('data/battery_2capacitors_model_solution.png', dpi=300, bbox_inches='tight')
203156
plt_helper.plt.close(fig)
204157

205158

@@ -300,5 +253,25 @@ def get_data_dict(stats, use_switch_estimator, recomputed=False):
300253
return data
301254

302255

256+
def proof_assertions_time(dt, Tend, V_ref, alpha):
257+
"""
258+
Function to proof the assertions regarding the time domain (in combination with the specific problem):
259+
260+
Args:
261+
dt (float): time step for computation
262+
Tend (float): end time
263+
V_ref (np.ndarray): Reference values (problem parameter)
264+
alpha (np.float): Multiple used for initial conditions (problem_parameter)
265+
"""
266+
267+
assert (
268+
Tend == 3.5 and V_ref[0] == 1.0 and V_ref[1] == 1.0 and alpha == 5.0
269+
), "Error! Do not use other parameters for V_ref[:] != 1.0, alpha != 1.2, Tend != 0.3 due to hardcoded reference!"
270+
271+
assert (
272+
dt == 1e-2 or dt == 4e-1 or dt == 4e-2 or dt == 4e-3
273+
), "Error! Do not use other time steps dt != 4e-1 or dt != 4e-2 or dt != 4e-3 due to hardcoded references!"
274+
275+
303276
if __name__ == "__main__":
304-
main()
277+
run()

0 commit comments

Comments
 (0)