1+ # -*- encoding: utf-8 -*-
2+ from __future__ import print_function
3+ import unittest
4+ import mock
5+ from autosklearn .automl import AutoML
6+ from autosklearn .util .backend import Backend
7+
8+
9+ class AutoMLStub (object ):
10+
11+ def __init__ (self ):
12+ self .__class__ = AutoML
13+
14+
15+ class AutoMlModelsTest (unittest .TestCase ):
16+
17+ def setUp (self ):
18+ self .automl = AutoMLStub ()
19+ self .automl ._shared_mode = False
20+ self .automl ._seed = 42
21+ self .automl ._backend = mock .Mock (spec = Backend )
22+ self .automl ._delete_output_directories = lambda : 0
23+
24+ def test_only_loads_ensemble_models (self ):
25+ identifiers = [(1 , 2 ), (3 , 4 )]
26+ models = [ 42 ]
27+ self .automl ._backend .load_ensemble .return_value .identifiers_ \
28+ = identifiers
29+ self .automl ._backend .load_models_by_identifiers .side_effect \
30+ = lambda ids : models if ids is identifiers else None
31+
32+ self .automl ._load_models ()
33+
34+ self .assertEqual (models , self .automl .models_ )
35+
36+ def test_loads_all_models_if_no_ensemble (self ):
37+ models = [ 42 ]
38+ self .automl ._backend .load_ensemble .return_value = None
39+ self .automl ._backend .load_all_models .return_value = models
40+
41+ self .automl ._load_models ()
42+
43+ self .assertEqual (models , self .automl .models_ )
44+
45+ def test_raises_if_no_models (self ):
46+ self .automl ._backend .load_ensemble .return_value = None
47+ self .automl ._backend .load_all_models .return_value = []
48+
49+ self .assertRaises (ValueError , self .automl ._load_models )
0 commit comments