Skip to content

Commit dce7b10

Browse files
committed
Merge branch 'selective-model-load' of https://github.com/Ayaro/auto-sklearn into Ayaro-selective-model-load
2 parents de56d42 + a34d741 commit dce7b10

File tree

4 files changed

+147
-7
lines changed

4 files changed

+147
-7
lines changed

autosklearn/automl.py

Lines changed: 7 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -639,11 +639,16 @@ def _load_models(self):
639639
else:
640640
seed = self._seed
641641

642-
self.models_ = self._backend.load_all_models(seed)
642+
self.ensemble_ = self._backend.load_ensemble(seed)
643+
if self.ensemble_:
644+
identifiers = self.ensemble_.identifiers_
645+
self.models_ = self._backend.load_models_by_identifiers(identifiers)
646+
else:
647+
self.models_ = self._backend.load_all_models(seed)
648+
643649
if len(self.models_) == 0:
644650
raise ValueError('No models fitted!')
645651

646-
self.ensemble_ = self._backend.load_ensemble(seed)
647652

648653
def score(self, X, y):
649654
# fix: Consider only index 1 of second dimension

autosklearn/util/backend.py

Lines changed: 30 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -168,19 +168,44 @@ def load_all_models(self, seed):
168168
model_files = os.listdir(model_directory)
169169
model_files = [os.path.join(model_directory, mf) for mf in model_files]
170170

171+
models = self.load_models_by_file_names(model_files)
172+
173+
return models
174+
175+
def load_models_by_file_names(self, model_file_names):
    """Load every model named in ``model_file_names``.

    File names are expected to look like ``{seed}.{index}.model``. Only the
    seed and index parsed from the base name are used: the directory part
    of each entry is discarded and the model itself is fetched through
    ``self.load_model_by_seed_and_id``.

    Parameters
    ----------
    model_file_names : iterable of str
        Paths or bare file names of stored models. A single trailing ``/``
        is tolerated and stripped before the base name is taken.

    Returns
    -------
    dict
        Maps ``(seed, idx)`` integer tuples to the loaded model objects.

    Raises
    ------
    ValueError
        If the seed or index part of a file name is not an integer.
    IndexError
        If a file name has fewer than two dot-separated parts.
    """
    models = dict()

    for model_file in model_file_names:
        # File names are like: {seed}.{index}.model
        if model_file.endswith('/'):
            # os.path.basename() returns '' for a path ending in '/'.
            model_file = model_file[:-1]
        basename = os.path.basename(model_file)

        basename_parts = basename.split('.')
        seed = int(basename_parts[0])
        idx = int(basename_parts[1])

        models[(seed, idx)] = self.load_model_by_seed_and_id(seed, idx)

    return models
191+
192+
def load_models_by_identifiers(self, identifiers):
    """Load only the models named by ``identifiers``.

    Parameters
    ----------
    identifiers : iterable
        Each element must unpack into exactly ``(seed, idx)`` and is used
        unchanged as the key of the returned dictionary, so it has to be
        hashable (e.g. a tuple).

    Returns
    -------
    dict
        Maps each identifier to the object returned by
        ``self.load_model_by_seed_and_id(seed, idx)``.
    """
    models = dict()

    for identifier in identifiers:
        seed, idx = identifier
        models[identifier] = self.load_model_by_seed_and_id(seed, idx)

    return models
183200

201+
def load_model_by_seed_and_id(self, seed, idx):
    """Unpickle and return the model stored as ``{seed}.{idx}.model``.

    The file is looked up inside the directory returned by
    ``self.get_model_dir()``.

    Parameters
    ----------
    seed, idx
        Values formatted with ``%s`` into the file name
        ``'<seed>.<idx>.model'``.

    Returns
    -------
    object
        Whatever object was pickled into the model file.

    Raises
    ------
    IOError / OSError
        If the model file cannot be opened.
    """
    model_directory = self.get_model_dir()
    model_file_name = '%s.%s.model' % (seed, idx)
    model_file_path = os.path.join(model_directory, model_file_name)

    # NOTE(security): unpickling executes code embedded in the file, so the
    # model directory must only ever contain files from a trusted source.
    with open(model_file_path, 'rb') as fh:
        return pickle.load(fh)
208+
184209
def get_ensemble_dir(self):
    """Return the path of the ``ensembles`` folder inside ``self.internals_directory``."""
    return os.path.join(self.internals_directory, 'ensembles')
186211

test/automl/test_models.py

Lines changed: 49 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,49 @@
1+
# -*- encoding: utf-8 -*-
2+
from __future__ import print_function
3+
import unittest
4+
import mock
5+
from autosklearn.automl import AutoML
6+
from autosklearn.util.backend import Backend
7+
8+
9+
class AutoMLStub(AutoML):
    """Test helper that builds an ``AutoML`` object without running its constructor.

    ``__init__`` deliberately does not call ``AutoML.__init__``; it only
    rebinds the instance's class to ``AutoML`` so the object behaves like a
    plain ``AutoML`` whose attributes are then set by the tests.
    """

    def __init__(self):
        # Swap the class back to AutoML instead of calling the real __init__.
        self.__class__ = AutoML
13+
14+
15+
class AutoMlModelsTest(unittest.TestCase):
    """Tests for which backend loader ``AutoML._load_models`` uses."""

    def setUp(self):
        # Build an AutoML object with only the attributes the tests set
        # here, backed by a mock that follows the Backend interface.
        self.automl = AutoMLStub()
        self.automl._shared_mode = False
        self.automl._seed = 42
        self.automl._backend = mock.Mock(spec=Backend)
        self.automl._delete_output_directories = lambda: 0

    def test_only_loads_ensemble_models(self):
        """With an ensemble present, models are loaded by its identifiers."""
        identifiers = [(1, 2), (3, 4)]
        models = [42]
        backend = self.automl._backend
        backend.load_ensemble.return_value.identifiers_ = identifiers
        # Only hand back the models when called with the very same
        # identifiers object the ensemble exposes.
        backend.load_models_by_identifiers.side_effect = (
            lambda ids: models if ids is identifiers else None
        )

        self.automl._load_models()

        self.assertEqual(models, self.automl.models_)

    def test_loads_all_models_if_no_ensemble(self):
        """Without an ensemble, ``load_all_models`` supplies the models."""
        models = [42]
        self.automl._backend.load_ensemble.return_value = None
        self.automl._backend.load_all_models.return_value = models

        self.automl._load_models()

        self.assertEqual(models, self.automl.models_)

    def test_raises_if_no_models(self):
        """An empty model collection makes ``_load_models`` raise ValueError."""
        self.automl._backend.load_ensemble.return_value = None
        self.automl._backend.load_all_models.return_value = []

        self.assertRaises(ValueError, self.automl._load_models)

test/util/test_backend.py

Lines changed: 61 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,61 @@
1+
# -*- encoding: utf-8 -*-
2+
from __future__ import print_function
3+
import unittest
4+
import mock
5+
from autosklearn.util.backend import Backend
6+
7+
from sys import version_info
8+
if version_info.major == 2:
9+
import __builtin__ as builtins
10+
else:
11+
import builtins
12+
13+
14+
class BackendModelsTest(unittest.TestCase):
    """Tests for the model-loading helpers of ``Backend``."""

    class BackendStub(Backend):
        """``Backend`` instance created without running ``Backend.__init__``."""

        def __init__(self):
            # Swap the class back to Backend instead of calling the real
            # constructor.
            self.__class__ = Backend

    def setUp(self):
        self.model_directory = '/model_directory/'
        self.backend = self.BackendStub()
        self.backend.get_model_dir = lambda: self.model_directory

    # Decorators apply bottom-up, so the ``open`` mock is the first injected
    # argument and the pickle mock the second.
    @mock.patch('six.moves.cPickle.load')
    @mock.patch.object(builtins, 'open')
    def test_loads_model_by_seed_and_id(self, openMock, pickleLoadMock):
        seed = 13
        idx = 17
        expected_model = self._setup_load_model_mocks(
            openMock, pickleLoadMock, seed, idx)

        actual_model = self.backend.load_model_by_seed_and_id(seed, idx)

        self.assertEqual(expected_model, actual_model)

    @mock.patch('six.moves.cPickle.load')
    @mock.patch.object(builtins, 'open')
    def test_loads_models_by_identifiers(self, openMock, pickleLoadMock):
        seed = 13
        idx = 17
        expected_model = self._setup_load_model_mocks(
            openMock, pickleLoadMock, seed, idx)
        expected_dict = {(seed, idx): expected_model}

        actual_dict = self.backend.load_models_by_identifiers([(seed, idx)])

        self.assertIsInstance(actual_dict, dict)
        self.assertDictEqual(expected_dict, actual_dict)

    def _setup_load_model_mocks(self, openMock, pickleLoadMock, seed, idx):
        """Wire the ``open`` and pickle mocks for one model file.

        ``open`` yields a context manager only for the expected path opened
        with mode ``'rb'``; the pickle loader returns the expected model only
        for the handle that context manager produces. Any other call gets
        ``None``. Returns the expected model.
        """
        model_path = '/model_directory/%s.%s.model' % (seed, idx)
        file_handler = 'file_handler'
        expected_model = 'model'

        fileMock = mock.MagicMock()
        fileMock.__enter__.return_value = file_handler

        openMock.side_effect = (
            lambda path, flag:
            fileMock if path == model_path and flag == 'rb' else None
        )
        pickleLoadMock.side_effect = (
            lambda fh: expected_model if fh == file_handler else None
        )

        return expected_model

0 commit comments

Comments
 (0)