Skip to content

Commit 0700ce9

Browse files
Test new backend methods
1 parent 85aaa8d commit 0700ce9

File tree

1 file changed

+56
-0
lines changed

1 file changed

+56
-0
lines changed

test/util/test_backend.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
# -*- encoding: utf-8 -*-
2+
from __future__ import print_function
3+
import unittest
4+
import mock
5+
from autosklearn.util.backend import Backend
6+
7+
8+
class BackendModelsTest(unittest.TestCase):
9+
10+
class BackendStub(Backend):
11+
12+
def __init__(self, model_directory):
13+
self.__class__ = Backend
14+
self.get_model_dir = lambda: model_directory
15+
16+
def setUp(self):
17+
self.model_directory = '/model_directory/'
18+
self.backend = self.BackendStub(self.model_directory)
19+
self.BackendStub.get_model_dir = lambda x: 42
20+
21+
@mock.patch('six.moves.cPickle.load')
22+
@mock.patch('__builtin__.open')
23+
def test_loads_model_by_seed_and_id(self, openMock, pickleLoadMock):
24+
seed = 13
25+
idx = 17
26+
expected_model = self._setup_load_model_mocks(openMock, pickleLoadMock, seed, idx)
27+
28+
actual_model = self.backend.load_model_by_seed_and_id(seed, idx)
29+
30+
self.assertEqual(expected_model, actual_model)
31+
32+
@mock.patch('six.moves.cPickle.load')
33+
@mock.patch('__builtin__.open')
34+
def test_loads_models_by_identifiers(self, openMock, pickleLoadMock):
35+
seed = 13
36+
idx = 17
37+
expected_model = self._setup_load_model_mocks(openMock, pickleLoadMock, seed, idx)
38+
expected_dict = { (seed, idx): expected_model }
39+
40+
actual_dict = self.backend.load_models_by_identifiers([(seed, idx)])
41+
42+
self.assertIsInstance(actual_dict, dict)
43+
self.assertDictEqual(expected_dict, actual_dict)
44+
45+
def _setup_load_model_mocks(self, openMock, pickleLoadMock, seed, idx):
46+
model_path = '/model_directory/%s.%s.model' % (seed, idx)
47+
file_handler = 'file_handler'
48+
expected_model = 'model'
49+
50+
fileMock = mock.MagicMock()
51+
fileMock.__enter__.return_value = file_handler
52+
53+
openMock.side_effect = lambda path, flag: fileMock if path == model_path and flag == 'rb' else None
54+
pickleLoadMock.side_effect = lambda fh: expected_model if fh == file_handler else None
55+
56+
return expected_model

0 commit comments

Comments
 (0)