Skip to content

Commit a17c921

Browse files
committed
add 'save' and 'load' to the learners and periodic saving to the Runner
1 parent 1577a17 commit a17c921

File tree

13 files changed

+553
-20
lines changed

13 files changed

+553
-20
lines changed

adaptive/learner/average_learner.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,3 +125,9 @@ def plot(self):
125125
num_bins = int(max(5, sqrt(self.npoints)))
126126
vals = hv.Points(vals)
127127
return hv.operation.histogram(vals, num_bins=num_bins, dimension=1)
128+
129+
def _get_data(self):
130+
return (self.data, self.npoints, self.sum_f, self.sum_f_sq)
131+
132+
def _set_data(self, data):
133+
self.data, self.npoints, self.sum_f, self.sum_f_sq = data

adaptive/learner/balancing_learner.py

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from contextlib import suppress
44
from functools import partial
55
from operator import itemgetter
6+
import os.path
67

78
import numpy as np
89

@@ -302,3 +303,75 @@ def from_product(cls, f, learner_type, learner_kwargs, combos):
302303
learner = learner_type(function=partial(f, **combo), **learner_kwargs)
303304
learners.append(learner)
304305
return cls(learners, cdims=arguments)
306+
307+
def save(self, folder, compress=True):
308+
"""Save the data of the child learners into pickle files
309+
in a directory.
310+
311+
Parameters
312+
----------
313+
folder : str
314+
Directory in which the learners's data will be saved.
315+
compress : bool, default True
316+
Compress the data upon saving using 'gzip'. When saving
317+
using compression, one must load it with compression too.
318+
319+
Notes
320+
-----
321+
The child learners need to have a 'fname' attribute in order to use
322+
this method.
323+
324+
Example
325+
-------
326+
>>> def combo_fname(val):
327+
... return '__'.join([f'{k}_{v}.p' for k, v in val.items()])
328+
...
329+
... def f(x, a, b): return a * x**2 + b
330+
...
331+
>>> learners = []
332+
>>> for combo in adaptive.utils.named_product(a=[1, 2], b=[1]):
333+
... l = Learner1D(functools.partial(f, combo=combo))
334+
... l.fname = combo_fname(combo) # 'a_1__b_1.p', 'a_2__b_1.p' etc.
335+
... learners.append(l)
336+
... learner = BalancingLearner(learners)
337+
... # Run the learner
338+
... runner = adaptive.Runner(learner)
339+
... # Then save
340+
... learner.save('data_folder') # use 'load' in the same way
341+
"""
342+
if len(self.learners) != len(set(l.fname for l in self.learners)):
343+
raise RuntimeError("The 'learner.fname's are not all unique.")
344+
345+
for l in self.learners:
346+
l.save(os.path.join(folder, l.fname), compress=compress)
347+
348+
def load(self, folder, compress=True):
349+
"""Load the data of the child learners from pickle files
350+
in a directory.
351+
352+
Parameters
353+
----------
354+
folder : str
355+
Directory from which the learners's data will be loaded.
356+
compress : bool, default True
357+
If the data is compressed when saved, one must load it
358+
with compression too.
359+
360+
Notes
361+
-----
362+
The child learners need to have a 'fname' attribute in order to use
363+
this method.
364+
365+
Example
366+
-------
367+
See the example in the 'BalancingLearner.save' doc-string.
368+
"""
369+
for l in self.learners:
370+
l.load(os.path.join(folder, l.fname), compress=compress)
371+
372+
def _get_data(self):
373+
return [l._get_data() for l in learner.learners]
374+
375+
def _set_data(self, data):
376+
for l, _data in zip(self.learners, data):
377+
l._set_data(_data)

adaptive/learner/base_learner.py

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
11
# -*- coding: utf-8 -*-
22
import abc
3+
from contextlib import suppress
34
from copy import deepcopy
45

6+
from ..utils import save, load
7+
58

69
class BaseLearner(metaclass=abc.ABCMeta):
710
"""Base class for algorithms for learning a function 'f: X → Y'.
@@ -83,8 +86,84 @@ def ask(self, n, tell_pending=True):
8386
"""
8487
pass
8588

89+
@abc.abstractmethod
90+
def _get_data(self):
91+
pass
92+
93+
@abc.abstractmethod
94+
def _set_data(self):
95+
pass
96+
97+
def copy_from(self, other):
98+
"""Copy over the data from another learner.
99+
100+
Parameters
101+
----------
102+
other : BaseLearner object
103+
The learner from which the data is copied.
104+
"""
105+
self._set_data(other._get_data())
106+
107+
def save(self, fname=None, compress=True):
108+
"""Save the data of the learner into a pickle file.
109+
110+
Parameters
111+
----------
112+
fname : str, optional
113+
The filename of the learner's pickle data file. If None use
114+
the 'fname' attribute, like 'learner.fname = "example.p".
115+
compress : bool, default True
116+
Compress the data upon saving using 'gzip'. When saving
117+
using compression, one must load it with compression too.
118+
119+
Notes
120+
-----
121+
There are __two ways__ of naming the files:
122+
1. Using the 'fname' argument in 'learner.save(fname='example.p')
123+
2. Setting the 'fname' attribute, like
124+
'learner.fname = "data/example.p"' and then 'learner.save()'.
125+
"""
126+
fname = fname or self.fname
127+
data = self._get_data()
128+
save(fname, data, compress)
129+
130+
def load(self, fname=None, compress=True):
131+
"""Load the data of a learner from a pickle file.
132+
133+
Parameters
134+
----------
135+
fname : str, optional
136+
The filename of the saved learner's pickled data file.
137+
If None use the 'fname' attribute, like
138+
'learner.fname = "example.p".
139+
compress : bool, default True
140+
If the data is compressed when saved, one must load it
141+
with compression too.
142+
143+
Notes
144+
-----
145+
See the notes in the 'BaseLearner.save' doc-string.
146+
"""
147+
fname = fname or self.fname
148+
with suppress(FileNotFoundError, EOFError):
149+
data = load(fname, compress)
150+
self._set_data(data)
151+
86152
def __getstate__(self):
87153
return deepcopy(self.__dict__)
88154

89155
def __setstate__(self, state):
90156
self.__dict__ = state
157+
158+
@property
159+
def fname(self):
160+
# This is a property because then it will be availible in the DataSaver
161+
try:
162+
return self._fname
163+
except AttributeError:
164+
raise AttributeError("Set 'learner.fname' or use the 'fname'"
165+
" argument when using 'learner.save' or 'learner.load'.")
166+
167+
@fname.setter
168+
def fname(self, fname):
169+
self._fname = fname

adaptive/learner/data_saver.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,9 @@
22
from collections import OrderedDict
33
import functools
44

5+
from .base_learner import BaseLearner
6+
from ..utils import copy_docstring_from
7+
58

69
class DataSaver:
710
"""Save extra data associated with the values that need to be learned.
@@ -40,6 +43,25 @@ def tell(self, x, result):
4043
def tell_pending(self, x):
4144
self.learner.tell_pending(x)
4245

46+
def _get_data(self):
47+
return self.learner._get_data(), self.extra_data
48+
49+
def _set_data(self, data):
50+
learner_data, self.extra_data = data
51+
self.learner._set_data(learner_data)
52+
53+
@copy_docstring_from(BaseLearner.save)
54+
def save(self, fname=None, compress=True):
55+
# We copy this method because the 'DataSaver' is not a
56+
# subclass of the 'BaseLearner'.
57+
BaseLearner.save(self, fname, compress)
58+
59+
@copy_docstring_from(BaseLearner.load)
60+
def load(self, fname=None, compress=True):
61+
# We copy this method because the 'DataSaver' is not a
62+
# subclass of the 'BaseLearner'.
63+
BaseLearner.load(self, fname, compress)
64+
4365

4466
def _ds(learner_type, arg_picker, *args, **kwargs):
4567
args = args[2:] # functools.partial passes the first 2 arguments in 'args'!

adaptive/learner/integrator_learner.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -525,3 +525,30 @@ def plot(self):
525525
xs, ys = zip(*[(x, y) for ival in ivals
526526
for x, y in sorted(ival.done_points.items())])
527527
return hv.Path((xs, ys))
528+
529+
def _get_data(self):
530+
# Change the defaultdict of SortedSets to a normal dict of sets.
531+
x_mapping = {k: set(v) for k, v in self.x_mapping.items()}
532+
533+
return (self.priority_split,
534+
self.done_points,
535+
self.pending_points,
536+
self._stack,
537+
x_mapping,
538+
self.ivals,
539+
self.first_ival)
540+
541+
def _set_data(self, data):
542+
self.priority_split, self.done_points, self.pending_points, \
543+
self._stack, x_mapping, self.ivals, self.first_ival = data
544+
545+
# Add the pending_points to the _stack such that they are evaluated again
546+
for x in self.pending_points:
547+
if x not in self._stack:
548+
self._stack.append(x)
549+
550+
# x_mapping is a data structure that can't easily be saved
551+
# so we recreate it here
552+
self.x_mapping = defaultdict(lambda: SortedSet([], key=attrgetter('rdepth')))
553+
for k, _set in x_mapping.items():
554+
self.x_mapping[k].update(_set)

adaptive/learner/learner1D.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -485,3 +485,9 @@ def remove_unfinished(self):
485485
self.pending_points = set()
486486
self.losses_combined = deepcopy(self.losses)
487487
self.neighbors_combined = deepcopy(self.neighbors)
488+
489+
def _get_data(self):
490+
return self.data
491+
492+
def _set_data(self, data):
493+
self.tell_many(*zip(*data.items()))

adaptive/learner/learner2D.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# -*- coding: utf-8 -*-
22
from collections import OrderedDict
3+
from copy import copy
34
import itertools
45
from math import sqrt
56

@@ -522,3 +523,13 @@ def plot(self, n=None, tri_alpha=0):
522523
no_hover = dict(plot=dict(inspection_policy=None, tools=[]))
523524

524525
return im.opts(style=im_opts) * tris.opts(style=tri_opts, **no_hover)
526+
527+
def _get_data(self):
528+
return self.data
529+
530+
def _set_data(self, data):
531+
self.data = data
532+
# Remove points from stack if they already exist
533+
for point in copy(self._stack):
534+
if point in self.data:
535+
self._stack.pop(point)

adaptive/learner/learnerND.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -572,3 +572,9 @@ def plot_slice(self, cut_mapping, n=None):
572572
return im.opts(style=dict(cmap='viridis'))
573573
else:
574574
raise ValueError("Only 1 or 2-dimensional plots can be generated.")
575+
576+
def _get_data(self):
577+
return self.data
578+
579+
def _set_data(self, data):
580+
self.tell_many(*zip(*data.items()))

adaptive/runner.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -449,6 +449,7 @@ def goal(_):
449449
self.function)
450450

451451
self.task = self.ioloop.create_task(self._run())
452+
self.saving_task = None
452453
if in_ipynb() and not self.ioloop.is_running():
453454
warnings.warn("The runner has been scheduled, but the asyncio "
454455
"event loop is not running! If you are "
@@ -541,6 +542,31 @@ def elapsed_time(self):
541542
end_time = time.time()
542543
return end_time - self.start_time
543544

545+
def start_periodic_saving(self, save_kwargs, interval):
546+
"""Periodically save the learner's data.
547+
548+
Parameters
549+
----------
550+
save_kwargs : dict
551+
Key-word arguments for 'learner.save(**save_kwargs)'.
552+
interval : int
553+
Number of seconds between saving the learner.
554+
555+
Example
556+
-------
557+
>>> runner = Runner(learner)
558+
>>> runner.start_periodic_saving(
559+
... save_kwargs=dict(fname='data/test.pickle'),
560+
... interval=600)
561+
"""
562+
async def _saver(save_kwargs=save_kwargs, interval=interval):
563+
while self.status() == 'running':
564+
self.learner.save(**save_kwargs)
565+
await asyncio.sleep(interval)
566+
self.learner.save(**save_kwargs) # one last time
567+
self.saving_task = self.ioloop.create_task(_saver())
568+
return self.saving_task
569+
544570

545571
# Default runner
546572
Runner = AsyncRunner

0 commit comments

Comments
 (0)