Skip to content

Commit 7fa7708

Browse files
update utils
1 parent 16e148d commit 7fa7708

File tree

2 files changed

+33
-16
lines changed

2 files changed

+33
-16
lines changed

adapt/utils.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,7 @@ def check_estimator(estimator=None, copy=True,
124124
else:
125125
estimator = LinearRegression()
126126

127+
# TODO, add KerasWrappers in doc and error message
127128
if isinstance(estimator, (BaseEstimator, KerasClassifier, KerasRegressor)):
128129
if (isinstance(estimator, ClassifierMixin) and task=="reg"):
129130
raise ValueError("`%s` argument is a sklearn `ClassifierMixin` instance "
@@ -137,7 +138,11 @@ def check_estimator(estimator=None, copy=True,
137138
"tensorflow Model instance."%display_name)
138139
if copy:
139140
try:
140-
new_estimator = deepcopy(estimator)
141+
if isinstance(estimator, (KerasClassifier, KerasRegressor)):
142+
# TODO, copy fitted parameters and Model
143+
new_estimator = clone(estimator)
144+
else:
145+
new_estimator = deepcopy(estimator)
141146
except Exception as e:
142147
if force_copy:
143148
raise ValueError("`%s` argument can't be duplicated. "
@@ -151,7 +156,7 @@ def check_estimator(estimator=None, copy=True,
151156
(display_name, e))
152157
new_estimator = estimator
153158
else:
154-
new_estimator = estimator
159+
new_estimator = estimator
155160
elif isinstance(estimator, Model):
156161
new_estimator = check_network(network=estimator,
157162
copy=copy,

tests/test_utils.py

Lines changed: 26 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,8 @@
1515
from sklearn.multioutput import MultiOutputRegressor
1616
from sklearn.compose import TransformedTargetRegressor
1717
from sklearn.base import BaseEstimator, RegressorMixin, ClassifierMixin
18-
from tensorflow.keras.wrappers.scikit_learn import KerasClassifier
18+
from sklearn.tree._tree import Tree
19+
from tensorflow.keras.wrappers.scikit_learn import KerasClassifier, KerasRegressor
1920
from tensorflow.keras import Model, Sequential
2021
from tensorflow.keras.layers import Input, Dense, Flatten, Reshape
2122
from tensorflow.python.keras.engine.input_layer import InputLayer
@@ -31,18 +32,19 @@ def is_equal_estimator(v1, v2):
3132
if isinstance(v1, np.ndarray):
3233
if not np.array_equal(v1, v2):
3334
is_equal = False
34-
elif isinstance(v1, BaseEstimator):
35+
elif isinstance(v1, (BaseEstimator, KerasClassifier, KerasRegressor)):
3536
if not is_equal_estimator(v1.__dict__, v2.__dict__):
3637
is_equal = False
3738
elif isinstance(v1, Model):
3839
if not is_equal_estimator(v1.get_config(),
3940
v2.get_config()):
4041
is_equal = False
4142
elif isinstance(v1, dict):
43+
if set(v1.keys()) != set(v2.keys()):
44+
is_equal = False
4245
for k1_i, v1_i in v1.items():
43-
if k1_i not in v2:
44-
is_equal = False
45-
else:
46+
# Avoid exception due to new input layer name
47+
if k1_i != "name":
4648
v2_i = v2[k1_i]
4749
if not is_equal_estimator(v1_i, v2_i):
4850
is_equal = False
@@ -53,6 +55,8 @@ def is_equal_estimator(v1, v2):
5355
for v1_i, v2_i in zip(v1, v2):
5456
if not is_equal_estimator(v1_i, v2_i):
5557
is_equal = False
58+
elif isinstance(v1, Tree):
59+
pass # TODO create a function to check if two tree are equal
5660
else:
5761
if not v1 == v2:
5862
is_equal = False
@@ -201,9 +205,9 @@ def test_check_network_network(net):
201205
@pytest.mark.parametrize("net", networks)
202206
def test_check_network_copy(net):
203207
new_net = check_network(net, copy=True, compile_=False)
204-
hex(id(new_net)) != hex(id(net))
208+
assert hex(id(new_net)) != hex(id(net))
205209
new_net = check_network(net, copy=False, compile_=False)
206-
hex(id(new_net)) == hex(id(net))
210+
assert hex(id(new_net)) == hex(id(net))
207211

208212

209213
no_networks = ["lala", Ridge(), 123, np.ones((10, 10))]
@@ -272,16 +276,24 @@ def test_check_network_compile():
272276
TransformedTargetRegressor(Ridge(alpha=25), StandardScaler()),
273277
MultiOutputRegressor(Ridge(alpha=0.3)),
274278
make_pipeline(StandardScaler(), Ridge(alpha=0.2)),
275-
KerasClassifier(_get_model_Sequential(input_shape=(1,))),
279+
KerasClassifier(_get_model_Sequential, input_shape=(1,)),
276280
CustomEstimator()
277281
]
278282

279283
@pytest.mark.parametrize("est", estimators)
280284
def test_check_estimator_estimators(est):
281-
new_est = check_estimator(est)
285+
new_est = check_estimator(est, copy=True, force_copy=True)
282286
assert is_equal_estimator(est, new_est)
283-
est.fit(np.linspace(0, 1, 10).reshape(-1, 1),
284-
(np.linspace(0, 1, 10)<0.5).astype(float))
287+
if isinstance(est, MultiOutputRegressor):
288+
est.fit(np.linspace(0, 1, 10).reshape(-1, 1),
289+
np.stack([np.linspace(0, 1, 10)<0.5]*2, -1).astype(float))
290+
else:
291+
est.fit(np.linspace(0, 1, 10).reshape(-1, 1),
292+
(np.linspace(0, 1, 10)<0.5).astype(float))
293+
if isinstance(est, KerasClassifier):
294+
new_est = check_estimator(est, copy=False)
295+
else:
296+
new_est = check_estimator(est, copy=True, force_copy=True)
285297
assert is_equal_estimator(est, new_est)
286298

287299

@@ -294,7 +306,7 @@ def test_check_estimator_networks(est):
294306
no_estimators = ["lala", 123, np.ones((10, 10))]
295307

296308
@pytest.mark.parametrize("no_est", no_estimators)
297-
def test_check_estimator_estimators(no_est):
309+
def test_check_estimator_no_estimators(no_est):
298310
with pytest.raises(ValueError) as excinfo:
299311
new_est = check_estimator(no_est)
300312
assert ("`estimator` argument is neither a sklearn `BaseEstimator` "
@@ -310,9 +322,9 @@ def test_check_estimator_estimators(no_est):
310322
@pytest.mark.parametrize("est", estimators)
311323
def test_check_estimator_copy(est):
312324
new_est = check_estimator(est, copy=True)
313-
hex(id(new_est)) != hex(id(est))
325+
assert hex(id(new_est)) != hex(id(est))
314326
new_est = check_estimator(est, copy=False)
315-
hex(id(new_est)) == hex(id(est))
327+
assert hex(id(new_est)) == hex(id(est))
316328

317329

318330
def test_check_estimator_force_copy():

0 commit comments

Comments
 (0)