Skip to content

Commit a67832a

Browse files
Fix dataprocessing get params (#877)
* Fix dataprocessing get params * Add clone-test to regression pipeline
1 parent 47a3f12 commit a67832a

File tree

4 files changed

+60
-2
lines changed

4 files changed

+60
-2
lines changed

autosklearn/pipeline/components/data_preprocessing/balancing/balancing.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
class Balancing(AutoSklearnPreprocessingAlgorithm):
1212
def __init__(self, strategy='none', random_state=None):
1313
self.strategy = strategy
14+
self.random_state = random_state
1415

1516
def fit(self, X, y=None):
1617
return self

autosklearn/pipeline/components/data_preprocessing/data_preprocessing.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,15 @@ def __init__(self, config=None, pipeline=None, dataset_properties=None, include=
3232
if categorical_features.dtype != 'bool':
3333
raise ValueError('Parameter categorical_features must'
3434
' only contain booleans.')
35+
self.config = config
36+
self.pipeline = pipeline
37+
self.dataset_properties = dataset_properties
38+
self.include = include
39+
self.exclude = exclude
40+
self.random_state = random_state
41+
self.init_params = init_params
3542
self.categorical_features = categorical_features
43+
self.force_sparse_output = force_sparse_output
3644

3745
# The pipeline that will be applied to the categorical features (i.e. columns)
3846
# of the dataset
@@ -48,7 +56,6 @@ def __init__(self, config=None, pipeline=None, dataset_properties=None, include=
4856
["categorical_transformer", self.categ_ppl],
4957
["numerical_transformer", self.numer_ppl],
5058
]
51-
self.force_sparse = force_sparse_output
5259

5360
def fit(self, X, y=None):
5461
n_feats = X.shape[1]
@@ -73,7 +80,7 @@ def fit(self, X, y=None):
7380
["numerical_transformer", self.numer_ppl, num_feats]
7481
]
7582

76-
self.sparse_ = sparse.issparse(X) or self.force_sparse
83+
self.sparse_ = sparse.issparse(X) or self.force_sparse_output
7784
self.column_transformer = sklearn.compose.ColumnTransformer(
7885
transformers=sklearn_transf_spec,
7986
sparse_threshold=float(self.sparse_),

test/test_pipeline/test_classification.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from joblib import Memory
1010
import numpy as np
1111

12+
from sklearn.base import clone
1213
import sklearn.datasets
1314
import sklearn.decomposition
1415
import sklearn.model_selection
@@ -578,6 +579,30 @@ def test_predict_proba_batched_sparse(self):
578579
self.assertEqual(84, cls_predict.call_count)
579580
np.testing.assert_array_almost_equal(prediction_, prediction)
580581

582+
def test_pipeline_clonability(self):
583+
X_train, Y_train, X_test, Y_test = get_dataset(dataset='iris')
584+
auto = SimpleClassificationPipeline()
585+
auto = auto.fit(X_train, Y_train)
586+
auto_clone = clone(auto)
587+
auto_clone_params = auto_clone.get_params()
588+
589+
# Make sure all keys are copied properly
590+
for k, v in auto.get_params().items():
591+
self.assertIn(k, auto_clone_params)
592+
593+
# Make sure the params getter of estimator are honored
594+
klass = auto.__class__
595+
new_object_params = auto.get_params(deep=False)
596+
for name, param in new_object_params.items():
597+
new_object_params[name] = clone(param, safe=False)
598+
new_object = klass(**new_object_params)
599+
params_set = new_object.get_params(deep=False)
600+
601+
for name in new_object_params:
602+
param1 = new_object_params[name]
603+
param2 = params_set[name]
604+
self.assertEqual(param1, param2)
605+
581606
def test_set_params(self):
582607
pass
583608

test/test_pipeline/test_regression.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import numpy as np
99
import sklearn.datasets
1010
import sklearn.decomposition
11+
from sklearn.base import clone
1112
import sklearn.ensemble
1213
import sklearn.svm
1314

@@ -370,6 +371,30 @@ def test_validate_input_X(self):
370371
def test_validate_input_Y(self):
371372
raise NotImplementedError()
372373

374+
def test_pipeline_clonability(self):
375+
X_train, Y_train, X_test, Y_test = get_dataset(dataset='boston')
376+
auto = SimpleRegressionPipeline()
377+
auto = auto.fit(X_train, Y_train)
378+
auto_clone = clone(auto)
379+
auto_clone_params = auto_clone.get_params()
380+
381+
# Make sure all keys are copied properly
382+
for k, v in auto.get_params().items():
383+
self.assertIn(k, auto_clone_params)
384+
385+
# Make sure the params getter of estimator are honored
386+
klass = auto.__class__
387+
new_object_params = auto.get_params(deep=False)
388+
for name, param in new_object_params.items():
389+
new_object_params[name] = clone(param, safe=False)
390+
new_object = klass(**new_object_params)
391+
params_set = new_object.get_params(deep=False)
392+
393+
for name in new_object_params:
394+
param1 = new_object_params[name]
395+
param2 = params_set[name]
396+
self.assertEqual(param1, param2)
397+
373398
def test_set_params(self):
374399
pass
375400

0 commit comments

Comments
 (0)