1
1
# -*- coding: utf-8 -*-
2
- from collections import defaultdict
2
+ from collections import defaultdict , Iterable
3
3
from contextlib import suppress
4
4
from functools import partial
5
5
from operator import itemgetter
@@ -323,8 +323,9 @@ def save(self, fname, compress=True):
323
323
324
324
Parameters
325
325
----------
326
- fname: callable
327
- Given a learner, returns a filename into which to save the data
326
+ fname: callable or sequence of strings
327
+ Given a learner, returns a filename into which to save the data.
328
+ Or a list (or iterable) with filenames.
328
329
compress : bool, default True
329
330
Compress the data upon saving using `gzip`. When saving
330
331
using compression, one must load it with compression too.
@@ -347,17 +348,22 @@ def save(self, fname, compress=True):
347
348
>>> # Then save
348
349
>>> learner.save(combo_fname) # use 'load' in the same way
349
350
"""
350
- for l in self .learners :
351
- l .save (fname (l ), compress = compress )
351
+ if isinstance (fname , Iterable ):
352
+ for l , _fname in zip (fname , self .learners ):
353
+ l .save (_fname , compress = compress )
354
+ else :
355
+ for l in self .learners :
356
+ l .save (fname (l ), compress = compress )
352
357
353
358
def load (self , fname , compress = True ):
354
359
"""Load the data of the child learners from pickle files
355
360
in a directory.
356
361
357
362
Parameters
358
363
----------
359
- fname: callable
360
- Given a learner, returns a filename into which to save the data
364
+ fname: callable or sequence of strings
365
+ Given a learner, returns a filename from which to load the data.
366
+ Or a list (or iterable) with filenames.
361
367
compress : bool, default True
362
368
If the data is compressed when saved, one must load it
363
369
with compression too.
@@ -366,8 +372,12 @@ def load(self, fname, compress=True):
366
372
-------
367
373
See the example in the `BalancingLearner.save` doc-string.
368
374
"""
369
- for l in self .learners :
370
- l .load (fname (l ), compress = compress )
375
+ if isinstance (fname , Iterable ):
376
+ for l , _fname in zip (fname , self .learners ):
377
+ l .load (_fname , compress = compress )
378
+ else :
379
+ for l in self .learners :
380
+ l .load (fname (l ), compress = compress )
371
381
372
382
def _get_data (self ):
373
383
return [l ._get_data () for l in learner .learners ]
0 commit comments