Skip to content

Commit 329f447

Browse files
committed
add an option to use a list of filenames when saving a BalancingLearner
1 parent fcdb053 commit 329f447

File tree

2 files changed

+22
-15
lines changed

2 files changed

+22
-15
lines changed

adaptive/learner/balancing_learner.py

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
# -*- coding: utf-8 -*-
2-
from collections import defaultdict
2+
from collections import defaultdict, Iterable
33
from contextlib import suppress
44
from functools import partial
55
from operator import itemgetter
@@ -323,8 +323,9 @@ def save(self, fname, compress=True):
323323
324324
Parameters
325325
----------
326-
fname: callable
327-
Given a learner, returns a filename into which to save the data
326+
fname: callable or sequence of strings
327+
Given a learner, returns a filename into which to save the data.
328+
Or a list (or iterable) with filenames.
328329
compress : bool, default True
329330
Compress the data upon saving using `gzip`. When saving
330331
using compression, one must load it with compression too.
@@ -347,17 +348,22 @@ def save(self, fname, compress=True):
347348
>>> # Then save
348349
>>> learner.save(combo_fname) # use 'load' in the same way
349350
"""
350-
for l in self.learners:
351-
l.save(fname(l), compress=compress)
351+
if isinstance(fname, Iterable):
352+
for l, _fname in zip(fname, self.learners):
353+
l.save(_fname, compress=compress)
354+
else:
355+
for l in self.learners:
356+
l.save(fname(l), compress=compress)
352357

353358
def load(self, fname, compress=True):
354359
"""Load the data of the child learners from pickle files
355360
in a directory.
356361
357362
Parameters
358363
----------
359-
fname: callable
360-
Given a learner, returns a filename into which to save the data
364+
fname: callable or sequence of strings
365+
Given a learner, returns a filename from which to load the data.
366+
Or a list (or iterable) with filenames.
361367
compress : bool, default True
362368
If the data is compressed when saved, one must load it
363369
with compression too.
@@ -366,8 +372,12 @@ def load(self, fname, compress=True):
366372
-------
367373
See the example in the `BalancingLearner.save` doc-string.
368374
"""
369-
for l in self.learners:
370-
l.load(fname(l), compress=compress)
375+
if isinstance(fname, Iterable):
376+
for l, _fname in zip(fname, self.learners):
377+
l.load(_fname, compress=compress)
378+
else:
379+
for l in self.learners:
380+
l.load(fname(l), compress=compress)
371381

372382
def _get_data(self):
373383
return [l._get_data() for l in learner.learners]

docs/source/tutorial/tutorial.advanced-topics.rst

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -33,13 +33,10 @@ Saving and loading learners
3333
Every learner has a `~adaptive.BaseLearner.save` and `~adaptive.BaseLearner.load`
3434
method that can be used to save and load **only** the data of a learner.
3535

36-
There are **two ways** of naming the files: 1. Using the ``fname``
37-
argument in ``learner.save(fname=...)`` 2. Setting the ``fname``
38-
attribute, like ``learner.fname = 'data/example.p`` and then
39-
``learner.save()``
36+
Use the ``fname`` argument in ``learner.save(fname=...)``.
4037

41-
The second way *must be used* when saving the ``learner``\s of a
42-
`~adaptive.BalancingLearner`.
38+
Or, when using a `~adaptive.BalancingLearner` one can use either a callable
39+
that takes the child learner and returns a filename **or** a list of filenames.
4340

4441
By default the resulting pickle files are compressed, to turn this off
4542
use ``learner.save(fname=..., compress=False)``

0 commit comments

Comments
 (0)