1515from sklearn .multioutput import MultiOutputRegressor
1616from sklearn .compose import TransformedTargetRegressor
1717from sklearn .base import BaseEstimator , RegressorMixin , ClassifierMixin
18- from tensorflow .keras .wrappers .scikit_learn import KerasClassifier
18+ from sklearn .tree ._tree import Tree
19+ from tensorflow .keras .wrappers .scikit_learn import KerasClassifier , KerasRegressor
1920from tensorflow .keras import Model , Sequential
2021from tensorflow .keras .layers import Input , Dense , Flatten , Reshape
2122from tensorflow .python .keras .engine .input_layer import InputLayer
@@ -31,18 +32,19 @@ def is_equal_estimator(v1, v2):
3132 if isinstance (v1 , np .ndarray ):
3233 if not np .array_equal (v1 , v2 ):
3334 is_equal = False
34- elif isinstance (v1 , BaseEstimator ):
35+ elif isinstance (v1 , ( BaseEstimator , KerasClassifier , KerasRegressor ) ):
3536 if not is_equal_estimator (v1 .__dict__ , v2 .__dict__ ):
3637 is_equal = False
3738 elif isinstance (v1 , Model ):
3839 if not is_equal_estimator (v1 .get_config (),
3940 v2 .get_config ()):
4041 is_equal = False
4142 elif isinstance (v1 , dict ):
43+ if set (v1 .keys ()) != set (v2 .keys ()):
44+ is_equal = False
4245 for k1_i , v1_i in v1 .items ():
43- if k1_i not in v2 :
44- is_equal = False
45- else :
46+ # Avoid exception due to new input layer name
47+ if k1_i != "name" :
4648 v2_i = v2 [k1_i ]
4749 if not is_equal_estimator (v1_i , v2_i ):
4850 is_equal = False
@@ -53,6 +55,8 @@ def is_equal_estimator(v1, v2):
5355 for v1_i , v2_i in zip (v1 , v2 ):
5456 if not is_equal_estimator (v1_i , v2_i ):
5557 is_equal = False
58+ elif isinstance (v1 , Tree ):
59+ pass # TODO create a function to check if two tree are equal
5660 else :
5761 if not v1 == v2 :
5862 is_equal = False
@@ -201,9 +205,9 @@ def test_check_network_network(net):
201205@pytest .mark .parametrize ("net" , networks )
202206def test_check_network_copy (net ):
203207 new_net = check_network (net , copy = True , compile_ = False )
204- hex (id (new_net )) != hex (id (net ))
208+ assert hex (id (new_net )) != hex (id (net ))
205209 new_net = check_network (net , copy = False , compile_ = False )
206- hex (id (new_net )) == hex (id (net ))
210+ assert hex (id (new_net )) == hex (id (net ))
207211
208212
209213no_networks = ["lala" , Ridge (), 123 , np .ones ((10 , 10 ))]
@@ -272,16 +276,24 @@ def test_check_network_compile():
272276 TransformedTargetRegressor (Ridge (alpha = 25 ), StandardScaler ()),
273277 MultiOutputRegressor (Ridge (alpha = 0.3 )),
274278 make_pipeline (StandardScaler (), Ridge (alpha = 0.2 )),
275- KerasClassifier (_get_model_Sequential ( input_shape = (1 ,) )),
279+ KerasClassifier (_get_model_Sequential , input_shape = (1 ,)),
276280 CustomEstimator ()
277281]
278282
279283@pytest .mark .parametrize ("est" , estimators )
280284def test_check_estimator_estimators (est ):
281- new_est = check_estimator (est )
285+ new_est = check_estimator (est , copy = True , force_copy = True )
282286 assert is_equal_estimator (est , new_est )
283- est .fit (np .linspace (0 , 1 , 10 ).reshape (- 1 , 1 ),
284- (np .linspace (0 , 1 , 10 )< 0.5 ).astype (float ))
287+ if isinstance (est , MultiOutputRegressor ):
288+ est .fit (np .linspace (0 , 1 , 10 ).reshape (- 1 , 1 ),
289+ np .stack ([np .linspace (0 , 1 , 10 )< 0.5 ]* 2 , - 1 ).astype (float ))
290+ else :
291+ est .fit (np .linspace (0 , 1 , 10 ).reshape (- 1 , 1 ),
292+ (np .linspace (0 , 1 , 10 )< 0.5 ).astype (float ))
293+ if isinstance (est , KerasClassifier ):
294+ new_est = check_estimator (est , copy = False )
295+ else :
296+ new_est = check_estimator (est , copy = True , force_copy = True )
285297 assert is_equal_estimator (est , new_est )
286298
287299
@@ -294,7 +306,7 @@ def test_check_estimator_networks(est):
294306no_estimators = ["lala" , 123 , np .ones ((10 , 10 ))]
295307
296308@pytest .mark .parametrize ("no_est" , no_estimators )
297- def test_check_estimator_estimators (no_est ):
309+ def test_check_estimator_no_estimators (no_est ):
298310 with pytest .raises (ValueError ) as excinfo :
299311 new_est = check_estimator (no_est )
300312 assert ("`estimator` argument is neither a sklearn `BaseEstimator` "
@@ -310,9 +322,9 @@ def test_check_estimator_estimators(no_est):
310322@pytest .mark .parametrize ("est" , estimators )
311323def test_check_estimator_copy (est ):
312324 new_est = check_estimator (est , copy = True )
313- hex (id (new_est )) != hex (id (est ))
325+ assert hex (id (new_est )) != hex (id (est ))
314326 new_est = check_estimator (est , copy = False )
315- hex (id (new_est )) == hex (id (est ))
327+ assert hex (id (new_est )) == hex (id (est ))
316328
317329
318330def test_check_estimator_force_copy ():
0 commit comments