Skip to content
This repository was archived by the owner on Dec 4, 2019. It is now read-only.

Commit f5365be

Browse files
committed
Added depreciation warnings
1 parent 1386bae commit f5365be

File tree

1 file changed

+19
-4
lines changed

1 file changed

+19
-4
lines changed

python/spark_sklearn/grid_search.py

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from itertools import product
88
from collections import Sized, Mapping, namedtuple, defaultdict, Sequence
99
from functools import partial
10+
import warnings
1011
import numpy as np
1112
from scipy.stats import rankdata
1213

@@ -201,8 +202,22 @@ def fit_old(self, X, y=None):
201202

202203
#ef _fit(self, X, y, parameter_iterable, groups=None):
203204
def fit(self, X, y=None, groups=None, **fit_params):
205+
206+
if self.fit_params is not None:
207+
warnings.warn('"fit_params" as a constructor argument was '
208+
'deprecated in version 0.19 and will be removed '
209+
'in version 0.21. Pass fit parameters to the '
210+
'"fit" method instead.', DeprecationWarning)
211+
if fit_params:
212+
warnings.warn('Ignoring fit_params passed as a constructor '
213+
'argument in favor of keyword arguments to '
214+
'the "fit" method.', RuntimeWarning)
215+
else:
216+
fit_params = self.fit_params
217+
204218
estimator = self.estimator
205219
cv = check_cv(self.cv, y, classifier=is_classifier(estimator))
220+
206221
self.scorer_ = check_scoring(self.estimator, scoring=self.scoring)
207222

208223
X, y, groups = indexable(X, y, groups)
@@ -212,10 +227,10 @@ def fit(self, X, y=None, groups=None, **fit_params):
212227
#candidate_params = parameter_iterable # change later
213228
candidate_params = ParameterGrid(self.param_grid)
214229
n_candidates = len(candidate_params)
215-
# if self.verbose > 0:
216-
# print("Fitting {0} folds for each of {1} candidates, totalling"
217-
# " {2} fits".format(n_splits, n_candidates,
218-
# n_candidates * n_splits))
230+
if self.verbose > 0:
231+
print("Fitting {0} folds for each of {1} candidates, totalling"
232+
" {2} fits".format(n_splits, n_candidates,
233+
n_candidates * n_splits))
219234

220235
base_estimator = clone(self.estimator)
221236

0 commit comments

Comments
 (0)