Skip to content

Commit 7ded866

Browse files
Added hook for storing solution to file after regular time increment (#486)
1 parent 2848895 commit 7ded866

File tree

2 files changed

+51
-11
lines changed

2 files changed

+51
-11
lines changed

pySDC/implementations/hooks/log_solution.py

Lines changed: 48 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -81,13 +81,13 @@ class LogToFile(Hooks):
8181
8282
Keep in mind that the hook will overwrite files without warning!
8383
You can give a custom file name by setting the ``file_name`` class attribute and give a custom way of rendering the
84-
index associated with individual files by giving a different lambda function ``format_index`` class attribute. This
85-
lambda should accept one index and return one string.
84+
index associated with individual files by giving a different function ``format_index`` class attribute. This should
85+
accept one index and return one string.
8686
8787
You can also give a custom ``logging_condition`` function, accepting the current level if you want to log selectively.
8888
8989
Importantly, you may need to change ``process_solution``. By default, this will return a numpy view of the solution.
90-
Of course, if you are not using numpy, you need to change this. Again, this is a lambda accepting the level.
90+
Of course, if you are not using numpy, you need to change this. Again, this is a function accepting the level.
9191
9292
After the fact, you can use the classmethod `get_path` to get the path to a certain data or the `load` function to
9393
directly load the solution at a given index. Just configure the hook like you did when you recorded the data
@@ -99,6 +99,7 @@ class LogToFile(Hooks):
9999

100100
path = None
101101
file_name = 'solution'
102+
counter = 0
102103

103104
def logging_condition(L):
104105
return True
@@ -111,7 +112,6 @@ def format_index(index):
111112

112113
def __init__(self):
113114
super().__init__()
114-
self.counter = 0
115115

116116
if self.path is None:
117117
raise ValueError('Please set a path for logging as the class attribute `LogToFile.path`!')
@@ -124,20 +124,41 @@ def __init__(self):
124124
if not os.path.isdir(self.path):
125125
os.mkdir(self.path)
126126

127-
def post_step(self, step, level_number):
127+
def log_to_file(self, step, level_number, condition, process_solution=None):
128128
if level_number > 0:
129129
return None
130130

131131
L = step.levels[level_number]
132132

133-
if type(self).logging_condition(L):
133+
if condition:
134134
path = self.get_path(self.counter)
135-
data = type(self).process_solution(L)
135+
136+
if process_solution:
137+
data = process_solution(L)
138+
else:
139+
data = type(self).process_solution(L)
136140

137141
with open(path, 'wb') as file:
138142
pickle.dump(data, file)
143+
self.logger.info(f'Stored file {path!r}')
144+
145+
type(self).counter += 1
146+
147+
def post_step(self, step, level_number):
148+
L = step.levels[level_number]
149+
self.log_to_file(step, level_number, type(self).logging_condition(L))
150+
151+
def pre_run(self, step, level_number):
152+
L = step.levels[level_number]
153+
L.uend = L.u[0]
154+
155+
def process_solution(L):
156+
return {
157+
**type(self).process_solution(L),
158+
't': L.time,
159+
}
139160

140-
self.counter += 1
161+
self.log_to_file(step, level_number, True, process_solution=process_solution)
141162

142163
@classmethod
143164
def get_path(cls, index):
@@ -148,3 +169,22 @@ def load(cls, index):
148169
path = cls.get_path(index)
149170
with open(path, 'rb') as file:
150171
return pickle.load(file)
172+
173+
174+
class LogToFileAfterXs(LogToFile):
175+
r'''
176+
Log to file after certain amount of time has passed instead of after every step
177+
'''
178+
179+
time_increment = 0
180+
t_next_log = 0
181+
182+
def post_step(self, step, level_number):
183+
L = step.levels[level_number]
184+
185+
if self.t_next_log == 0:
186+
self.t_next_log = self.time_increment
187+
188+
if L.time + L.dt >= self.t_next_log and not step.status.restart:
189+
super().post_step(step, level_number)
190+
self.t_next_log = max([L.time + L.dt, self.t_next_log]) + self.time_increment

pySDC/tests/test_hooks/test_log_to_file.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ def run(hook, Tend=0):
3232
u0 = prob.u_exact(0)
3333

3434
_, stats = controller.run(u0, 0, Tend)
35-
return stats
35+
return u0, stats
3636

3737

3838
@pytest.mark.base
@@ -68,8 +68,8 @@ def test_logging():
6868
LogToFile.path = path
6969
Tend = 2
7070

71-
stats = run([LogToFile, LogSolution], Tend=Tend)
72-
u = get_sorted(stats, type='u')
71+
u0, stats = run([LogToFile, LogSolution], Tend=Tend)
72+
u = [(0.0, u0)] + get_sorted(stats, type='u')
7373

7474
u_file = []
7575
for i in range(len(u)):

0 commit comments

Comments
 (0)