Skip to content

Commit d77e431

Browse files
Added hook for logging to file (#410)
1 parent 745b027 commit d77e431

File tree

2 files changed

+164
-0
lines changed

2 files changed

+164
-0
lines changed

pySDC/implementations/hooks/log_solution.py

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
11
from pySDC.core.Hooks import hooks
2+
import pickle
3+
import os
4+
import numpy as np
25

36

47
class LogSolution(hooks):
@@ -63,3 +66,79 @@ def post_iteration(self, step, level_number):
6366
type='u',
6467
value=L.uend,
6568
)
69+
70+
71+
class LogToFile(hooks):
72+
r"""
73+
Hook for logging the solution to file after the step using pickle.
74+
75+
Please configure the hook to your liking by manipulating class attributes.
76+
You must set a custom path to a directory like so:
77+
78+
```
79+
LogToFile.path = '/my/directory/'
80+
```
81+
82+
Keep in mind that the hook will overwrite files without warning!
83+
You can give a custom file name by setting the ``file_name`` class attribute and give a custom way of rendering the
84+
index associated with individual files by giving a different lambda function ``format_index`` class attribute. This
85+
lambda should accept one index and return one string.
86+
87+
You can also give a custom ``logging_condition`` lambda, accepting the current level if you want to log selectively.
88+
89+
Importantly, you may need to change ``process_solution``. By default, this will return a numpy view of the solution.
90+
Of course, if you are not using numpy, you need to change this. Again, this is a lambda accepting the level.
91+
92+
After the fact, you can use the classmethod `get_path` to get the path to a certain data or the `load` function to
93+
directly load the solution at a given index. Just configure the hook like you did when you recorded the data
94+
beforehand.
95+
96+
Finally, be aware that using this hook with MPI parallel runs may lead to different tasks overwriting files. Make
97+
sure to give a different `file_name` for each task that writes files.
98+
"""
99+
100+
path = None
101+
file_name = 'solution'
102+
logging_condition = lambda L: True
103+
process_solution = lambda L: {'t': L.time + L.dt, 'u': L.uend.view(np.ndarray)}
104+
format_index = lambda index: f'{index:06d}'
105+
106+
def __init__(self):
107+
super().__init__()
108+
self.counter = 0
109+
110+
if self.path is None:
111+
raise ValueError('Please set a path for logging as the class attribute `LogToFile.path`!')
112+
113+
if os.path.isfile(self.path):
114+
raise ValueError(
115+
f'{self.path!r} is not a valid path to log to because a file of the same name exists. Please supply a directory'
116+
)
117+
118+
if not os.path.isdir(self.path):
119+
os.mkdir(self.path)
120+
121+
def post_step(self, step, level_number):
122+
if level_number > 0:
123+
return None
124+
125+
L = step.levels[level_number]
126+
127+
if type(self).logging_condition(L):
128+
path = self.get_path(self.counter)
129+
data = type(self).process_solution(L)
130+
131+
with open(path, 'wb') as file:
132+
pickle.dump(data, file)
133+
134+
self.counter += 1
135+
136+
@classmethod
137+
def get_path(cls, index):
138+
return f'{cls.path}/{cls.file_name}_{cls.format_index(index)}.pickle'
139+
140+
@classmethod
141+
def load(cls, index):
142+
path = cls.get_path(index)
143+
with open(path, 'rb') as file:
144+
return pickle.load(file)
Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
import pytest
2+
3+
4+
def run(hook, Tend=0):
5+
from pySDC.implementations.problem_classes.TestEquation_0D import testequation0d
6+
from pySDC.implementations.sweeper_classes.generic_implicit import generic_implicit
7+
from pySDC.implementations.controller_classes.controller_nonMPI import controller_nonMPI
8+
9+
level_params = {'dt': 1.0e-1}
10+
11+
sweeper_params = {
12+
'num_nodes': 1,
13+
'quad_type': 'GAUSS',
14+
}
15+
16+
description = {
17+
'level_params': level_params,
18+
'sweeper_class': generic_implicit,
19+
'problem_class': testequation0d,
20+
'sweeper_params': sweeper_params,
21+
'problem_params': {},
22+
'step_params': {'maxiter': 1},
23+
}
24+
25+
controller_params = {
26+
'hook_class': hook,
27+
'logger_level': 30,
28+
}
29+
controller = controller_nonMPI(1, controller_params, description)
30+
if Tend > 0:
31+
prob = controller.MS[0].levels[0].prob
32+
u0 = prob.u_exact(0)
33+
34+
_, stats = controller.run(u0, 0, Tend)
35+
return stats
36+
37+
38+
@pytest.mark.base
39+
def test_errors():
40+
from pySDC.implementations.hooks.log_solution import LogToFile
41+
import os
42+
43+
with pytest.raises(ValueError):
44+
run(LogToFile)
45+
46+
LogToFile.path = os.getcwd()
47+
run(LogToFile)
48+
49+
path = f'{os.getcwd()}/tmp'
50+
LogToFile.path = path
51+
run(LogToFile)
52+
os.path.isdir(path)
53+
54+
with pytest.raises(ValueError):
55+
LogToFile.path = __file__
56+
run(LogToFile)
57+
58+
59+
@pytest.mark.base
60+
def test_logging():
61+
from pySDC.implementations.hooks.log_solution import LogToFile, LogSolution
62+
from pySDC.helpers.stats_helper import get_sorted
63+
import os
64+
import pickle
65+
import numpy as np
66+
67+
path = f'{os.getcwd()}/tmp'
68+
LogToFile.path = path
69+
Tend = 2
70+
71+
stats = run([LogToFile, LogSolution], Tend=Tend)
72+
u = get_sorted(stats, type='u')
73+
74+
u_file = []
75+
for i in range(len(u)):
76+
data = LogToFile.load(i)
77+
u_file += [(data['t'], data['u'])]
78+
79+
for us, uf in zip(u, u_file):
80+
assert us[0] == uf[0]
81+
assert np.allclose(us[1], uf[1])
82+
83+
84+
if __name__ == '__main__':
85+
test_logging()

0 commit comments

Comments
 (0)