Skip to content

Commit 303c2f0

Browse files
committed
Implemented test for basic configurations
1 parent ef6c24c commit 303c2f0

File tree

3 files changed

+35
-85
lines changed

3 files changed

+35
-85
lines changed

pySDC/implementations/hooks/log_solution.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -266,12 +266,11 @@ def post_run(self, step, level_number):
266266
L = step.levels[level_number]
267267

268268
value_exists = True in [abs(me - (L.time + L.dt)) < np.finfo(float).eps * 1000 for me in self.outfile.times]
269-
if value_exists and not self.allow_overwriting:
270-
raise DataError(f'Already have recorded data for time {L.time + L.dt} in this file!')
271-
self.outfile.addField(time=L.time + L.dt, field=L.prob.processSolutionForOutput(L.uend))
272-
self.logger.info(f'Written solution at t={L.time+L.dt:.4f} to file')
273-
self.t_next_log = max([L.time + L.dt, self.t_next_log]) + self.time_increment
274-
type(self).counter = len(self.outfile.times)
269+
if not value_exists:
270+
self.outfile.addField(time=L.time + L.dt, field=L.prob.processSolutionForOutput(L.uend))
271+
self.logger.info(f'Written solution at t={L.time+L.dt:.4f} to file')
272+
self.t_next_log = max([L.time + L.dt, self.t_next_log]) + self.time_increment
273+
type(self).counter = len(self.outfile.times)
275274

276275
@classmethod
277276
def load(cls, index):

pySDC/projects/GPU/configs/base_config.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -192,13 +192,13 @@ def get_initial_condition(self, P, *args, restart_idx=0, **kwargs):
192192
class LogStats(ConvergenceController):
193193

194194
def get_stats_path(self, index=0):
195-
return f'{self.params.path}_{self.counter:06d}-stats.pickle'
195+
return f'{self.params.path}_{index:06d}-stats.pickle'
196196

197197
def merge_all_stats(self, controller):
198198
hook = self.params.hook
199199

200200
stats = {}
201-
for i in range(hook.counter):
201+
for i in range(hook.counter - 1):
202202
with open(self.get_stats_path(index=i), 'rb') as file:
203203
_stats = pickle.load(file)
204204
stats = {**stats, **_stats}
@@ -226,11 +226,8 @@ def post_step_processing(self, controller, S, **kwargs):
226226

227227
P = S.levels[0].prob
228228

229-
if self.counter == 0:
230-
self.counter = hook.counter - 1
231-
232229
while self.counter < hook.counter:
233-
path = self.get_stats_path(index=self.counter - 1)
230+
path = self.get_stats_path(index=hook.counter - 2)
234231
stats = controller.return_stats()
235232
store = True
236233
if hasattr(S.levels[0].sweep, 'comm') and S.levels[0].sweep.comm.rank > 0:
@@ -241,12 +238,15 @@ def post_step_processing(self, controller, S, **kwargs):
241238
with open(path, 'wb') as file:
242239
pickle.dump(stats, file)
243240
self.log(f'Stored stats in {path!r}', S)
241+
# print(stats)
244242
self.reset_stats(controller)
245243
self.counter = hook.counter
246244

247-
def post_run_processing(self, controller, *args, **kwargs):
245+
def post_run_processing(self, controller, S, **kwargs):
248246
stats = self.merge_all_stats(controller)
249247

248+
self.post_step_processing(controller, S, **kwargs)
249+
250250
def return_stats():
251251
return stats
252252

pySDC/projects/GPU/tests/test_configs.py

Lines changed: 23 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,10 @@ def create_directories():
3838
def test_run(tmpdir):
3939
from pySDC.projects.GPU.configs.base_config import get_config
4040
from pySDC.projects.GPU.run_experiment import run_experiment
41+
from pySDC.helpers.stats_helper import get_sorted
42+
from pySDC.helpers.fieldsIO import FieldsIO
43+
import pickle
44+
import numpy as np
4145

4246
args = {
4347
'config': 'RBC',
@@ -50,6 +54,7 @@ def test_run(tmpdir):
5054
'restart_idx': 0,
5155
}
5256
config = get_config(args)
57+
type(config).base_path = args['o']
5358

5459
def get_LogToFile(self, *args, **kwargs):
5560
if self.comms[1].rank > 0:
@@ -63,89 +68,39 @@ def get_LogToFile(self, *args, **kwargs):
6368
return LogToFile
6469

6570
type(config).get_LogToFile = get_LogToFile
71+
stats_path = f'{config.base_path}/data/{config.get_path()}-stats-whole-run.pickle'
72+
file_path = config.get_file_name()
6673

6774
# first run for a short time
6875
dt = config.get_description()['level_params']['dt']
6976
config.Tend = 2 * dt
7077
run_experiment(args, config)
7178

79+
# check data
80+
data = FieldsIO.fromFile(file_path)
81+
with open(stats_path, 'rb') as file:
82+
stats = pickle.load(file)
7283

73-
@pytest.mark.order(1)
74-
def test_run_experiment(restart_idx=0):
75-
from pySDC.projects.GPU.configs.base_config import Config
76-
from pySDC.projects.GPU.run_experiment import run_experiment, parse_args
77-
from pySDC.helpers.stats_helper import get_sorted
78-
import pickle
79-
import numpy as np
80-
81-
create_directories()
82-
83-
class VdPConfig(Config):
84-
sweeper_type = 'generic_implicit'
85-
Tend = 1
86-
87-
def get_description(self, *args, **kwargs):
88-
from pySDC.implementations.problem_classes.Van_der_Pol_implicit import vanderpol
89-
90-
desc = super().get_description(*args, **kwargs)
91-
desc['problem_class'] = vanderpol
92-
desc['problem_params'].pop('useGPU')
93-
desc['problem_params'].pop('comm')
94-
desc['sweeper_params']['num_nodes'] = 2
95-
desc['sweeper_params']['quad_type'] = 'RADAU-RIGHT'
96-
desc['sweeper_params']['QI'] = 'LU'
97-
desc['level_params']['dt'] = 0.1
98-
desc['step_params']['maxiter'] = 3
99-
return desc
100-
101-
def get_LogToFile(self, ranks=None):
102-
from pySDC.implementations.hooks.log_solution import LogToPickleFileAfterXS as LogToFile
103-
104-
LogToFile.path = './data/'
105-
LogToFile.file_name = f'{self.get_path(ranks=ranks)}-solution'
106-
LogToFile.time_increment = 2e-1
107-
108-
def logging_condition(L):
109-
sweep = L.sweep
110-
if hasattr(sweep, 'comm'):
111-
if sweep.comm.rank == sweep.comm.size - 1:
112-
return True
113-
else:
114-
return False
115-
else:
116-
return True
117-
118-
LogToFile.logging_condition = logging_condition
119-
return LogToFile
84+
dts = get_sorted(stats, type='dt')
85+
assert len(dts) == len(data.times) - 1
86+
assert np.allclose(data.times, 0.1 * np.arange(3)), 'Did not record solutions at expected times'
12087

121-
args = {
122-
'procs': [1, 1, 1],
123-
'useGPU': False,
124-
'res': -1,
125-
'logger_level': 15,
126-
'restart_idx': restart_idx,
127-
'mode': 'run',
128-
}
129-
config = VdPConfig(args)
88+
# restart run
89+
args['restart_idx'] = -1
90+
config.Tend = 4 * dt
13091
run_experiment(args, config)
13192

132-
with open(f'data/{config.get_path()}-stats-whole-run.pickle', 'rb') as file:
93+
# check data
94+
data = FieldsIO.fromFile(file_path)
95+
with open(stats_path, 'rb') as file:
13396
stats = pickle.load(file)
97+
dts = get_sorted(stats, type='dt')
13498

135-
k_Newton = get_sorted(stats, type='work_newton')
136-
assert len(k_Newton) == 10
137-
assert sum([me[1] for me in k_Newton]) == 91
138-
139-
140-
@pytest.mark.order(2)
141-
def test_restart():
142-
test_run_experiment(3)
99+
assert len(dts) == len(data.times) - 1
100+
assert np.allclose(data.times, 0.1 * np.arange(5)), 'Did not record solutions at expected times after restart'
143101

144102

145103
if __name__ == '__main__':
146-
test_run('.')
147-
exit()
148-
149104
import argparse
150105

151106
parser = argparse.ArgumentParser()
@@ -155,9 +110,5 @@ def test_restart():
155110

156111
if args.test == 'get_comms':
157112
test_get_comms(False)
158-
elif args.test == 'run_experiment':
159-
test_run_experiment()
160-
elif args.test == 'restart':
161-
test_restart()
162113
else:
163114
raise NotImplementedError

0 commit comments

Comments
 (0)