55import sys
66
77from itertools import product
8- from collections import Sized , Mapping , namedtuple , defaultdict , Sequence
8+ from collections import defaultdict
99from functools import partial
1010import warnings
1111
2323
2424
2525class GridSearchCV (BaseSearchCV ):
26- """Exhaustive search over specified parameter values for an estimator, using Spark to
27- distribute the computations.
28-
26+ """Exhaustive search over specified parameter values for an estimator.
27+
2928 Important members are fit, predict.
30-
31- GridSearchCV implements a "fit" method and a "predict" method like
32- any classifier except that the parameters of the classifier
33- used to predict is optimized by cross-validation.
34-
29+
30+ GridSearchCV implements a "fit" and a "score" method.
31+ It also implements "predict", "predict_proba", "decision_function",
32+ "transform" and "inverse_transform" if they are implemented in the
33+ estimator used.
34+
35+ The parameters of the estimator used to apply these methods are optimized
36+ by cross-validated grid-search over a parameter grid.
37+ Read more in the :ref:`User Guide <grid_search>`.
38+
3539 Parameters
3640 ----------
37- sc: the spark context
38-
39- estimator : object type that implements the "fit" and "predict" methods
40- A object of that type is instantiated for each grid point .
41-
41+ estimator : estimator object.
42+ This is assumed to implement the scikit-learn estimator interface.
43+ Either estimator needs to provide a ``score`` function,
44+ or ``scoring`` must be passed .
45+
4246 param_grid : dict or list of dictionaries
4347 Dictionary with parameters names (string) as keys and lists of
4448 parameter settings to try as values, or a list of such
4549 dictionaries, in which case the grids spanned by each dictionary
4650 in the list are explored. This enables searching over any sequence
4751 of parameter settings.
48-
49- scoring : string, callable or None, optional, default: None
52+
53+ scoring : string, callable or None, default= None
5054 A string (see model evaluation documentation) or
5155 a scorer callable object / function with signature
5256 ``scorer(estimator, X, y)``.
53-
57+ If ``None``, the ``score`` method of the estimator is used.
58+
5459 fit_params : dict, optional
5560 Parameters to pass to the fit method.
56-
57- .. deprecated:: 0.19
58- ``fit_params`` as a constructor argument was deprecated in version
59- 0.19 and will be removed in version 0.21. Pass fit parameters to
60- the ``fit`` method instead..
61-
62- n_jobs : int, default 1
63- This parameter is not used and kept for compatibility.
64-
61+
62+ n_jobs : int, default=1
63+ Number of jobs to run in parallel.
64+
6565 pre_dispatch : int, or string, optional
66- This parameter is not used and kept for compatibility.
67-
66+ Controls the number of jobs that get dispatched during parallel
67+ execution. Reducing this number can be useful to avoid an
68+ explosion of memory consumption when more jobs get dispatched
69+ than CPUs can process. This parameter can be:
70+ - None, in which case all the jobs are immediately
71+ created and spawned. Use this for lightweight and
72+ fast-running jobs, to avoid delays due to on-demand
73+ spawning of the jobs
74+ - An int, giving the exact number of total jobs that are
75+ spawned
76+ - A string, giving an expression as a function of n_jobs,
77+ as in '2*n_jobs'
78+
6879 iid : boolean, default=True
6980 If True, the data is assumed to be identically distributed across
7081 the folds, and the loss minimized is the total loss per sample,
7182 and not the mean loss across the folds.
72-
73- cv : integer or cross-validation generator, default=3
74- A cross-validation generator to use. If int, determines
75- the number of folds in StratifiedKFold if estimator is a classifier
76- and the target y is binary or multiclass, or the number
77- of folds in KFold otherwise.
78- Specific cross-validation objects can be passed, see
79- sklearn.cross_validation module for the list of possible objects.
80-
83+
84+ cv : int, cross-validation generator or an iterable, optional
85+ Determines the cross-validation splitting strategy.
86+ Possible inputs for cv are:
87+ - None, to use the default 3-fold cross validation,
88+ - integer, to specify the number of folds in a `(Stratified)KFold`,
89+ - An object to be used as a cross-validation generator.
90+ - An iterable yielding train, test splits.
91+ For integer/None inputs, if the estimator is a classifier and ``y`` is
92+ either binary or multiclass, :class:`StratifiedKFold` is used. In all
93+ other cases, :class:`KFold` is used.
94+ Refer :ref:`User Guide <cross_validation>` for the various
95+ cross-validation strategies that can be used here.
96+
8197 refit : boolean, default=True
8298 Refit the best estimator with the entire dataset.
8399 If "False", it is impossible to make predictions using
84100 this GridSearchCV instance after fitting.
85-
86- The refitting step, if any, happens on the local machine.
87-
101+
88102 verbose : integer
89103 Controls the verbosity: the higher, the more messages.
90-
104+
91105 error_score : 'raise' (default) or numeric
92106 Value to assign to the score if an error occurs in estimator fitting.
93107 If set to 'raise', the error is raised. If a numeric value is given,
94108 FitFailedWarning is raised. This parameter does not affect the refit
95109 step, which will always raise the error.
96-
97-
110+
111+ return_train_score : boolean, default=True
112+ If ``'False'``, the ``cv_results_`` attribute will not include training
113+ scores.
114+
98115 Examples
99116 --------
100117 >>> from sklearn import svm, datasets
101- >>> from spark_sklearn import GridSearchCV
102- >>> from pyspark.sql import SparkSession
103- >>> from spark_sklearn.util import createLocalSparkSession
104- >>> spark = createLocalSparkSession()
118+ >>> from sklearn.model_selection import GridSearchCV
105119 >>> iris = datasets.load_iris()
106120 >>> parameters = {'kernel':('linear', 'rbf'), 'C':[1, 10]}
107121 >>> svr = svm.SVC()
108- >>> clf = GridSearchCV(spark.sparkContext, svr, parameters)
122+ >>> clf = GridSearchCV(svr, parameters)
109123 >>> clf.fit(iris.data, iris.target)
110124 ... # doctest: +NORMALIZE_WHITESPACE +ELLIPSIS
111125 GridSearchCV(cv=None, error_score=...,
@@ -115,10 +129,17 @@ class GridSearchCV(BaseSearchCV):
115129 random_state=None, shrinking=True, tol=...,
116130 verbose=False),
117131 fit_params={}, iid=..., n_jobs=1,
118- param_grid=..., pre_dispatch=..., refit=...,
132+ param_grid=..., pre_dispatch=..., refit=..., return_train_score=...,
119133 scoring=..., verbose=...)
120- >>> spark.stop(); SparkSession._instantiatedContext = None
121-
134+ >>> sorted(clf.cv_results_.keys())
135+ ... # doctest: +NORMALIZE_WHITESPACE +ELLIPSIS
136+ ['mean_fit_time', 'mean_score_time', 'mean_test_score',...
137+ 'mean_train_score', 'param_C', 'param_kernel', 'params',...
138+ 'rank_test_score', 'split0_test_score',...
139+ 'split0_train_score', 'split1_test_score', 'split1_train_score',...
140+ 'split2_test_score', 'split2_train_score',...
141+ 'std_fit_time', 'std_score_time', 'std_test_score', 'std_train_score'...]
142+
122143 Attributes
123144 ----------
124145 cv_results_ : dict of numpy (masked) ndarrays
@@ -188,29 +209,35 @@ class GridSearchCV(BaseSearchCV):
188209
189210 n_splits_ : int
190211 The number of cross-validation splits (folds/iterations).
191-
212+
192213 Notes
193214 ------
194215 The parameters selected are those that maximize the score of the left out
195216 data, unless an explicit score is passed in which case it is used instead.
196-
197- The parameters n_jobs and pre_dispatch are accepted but not used.
198-
217+ If `n_jobs` was set to a value higher than one, the data is copied for each
218+ point in the grid (and not `n_jobs` times). This is done for efficiency
219+ reasons if individual jobs take very little time, but may raise errors if
220+ the dataset is large and not enough memory is available. A workaround in
221+ this case is to set `pre_dispatch`. Then, the memory is copied only
222+ `pre_dispatch` many times. A reasonable value for `pre_dispatch` is `2 *
223+ n_jobs`.
224+
199225 See Also
200226 ---------
201227 :class:`ParameterGrid`:
202- generates all the combinations of a an hyperparameter grid.
203-
204- :func:`sklearn.cross_validation .train_test_split`:
228+ generates all the combinations of a hyperparameter grid.
229+
230+ :func:`sklearn.model_selection .train_test_split`:
205231 utility function to split the data into a development set usable
206232 for fitting a GridSearchCV instance and an evaluation set for
207233 its final evaluation.
208-
234+
209235 :func:`sklearn.metrics.make_scorer`:
210236 Make a scorer from a performance metric or loss function.
211-
237+
212238 """
213239
240+
214241 def __init__ (self , sc , estimator , param_grid , scoring = None , fit_params = None ,
215242 n_jobs = 1 , iid = True , refit = True , cv = None , verbose = 0 ,
216243 pre_dispatch = '2*n_jobs' , error_score = 'raise' , return_train_score = True ):
@@ -222,38 +249,28 @@ def __init__(self, sc, estimator, param_grid, scoring=None, fit_params=None,
222249
223250 self .cv_results_ = None
224251 _check_param_grid (param_grid )
225-
226- def fit_old (self , X , y = None ):
252+
253+ def fit (self , X , y = None , groups = None ):
227254 """Run fit with all sets of parameters.
228-
255+
229256 Parameters
230257 ----------
231-
258+
232259 X : array-like, shape = [n_samples, n_features]
233260 Training vector, where n_samples is the number of samples and
234261 n_features is the number of features.
235-
262+
236263 y : array-like, shape = [n_samples] or [n_samples, n_output], optional
237264 Target relative to X for classification or regression;
238265 None for unsupervised learning.
239-
266+
267+ groups : array-like, with shape (n_samples,), optional
268+ Group labels for the samples used while splitting the dataset into
269+ train/test set.
240270 """
241- return self ._fit (X , y , ParameterGrid (self .param_grid ))
271+ return self ._fit (X , y , groups , ParameterGrid (self .param_grid ))
242272
243-
244- def fit (self , X , y = None , groups = None , ** fit_params ):
245-
246- if self .fit_params is not None :
247- warnings .warn ('"fit_params" as a constructor argument was '
248- 'deprecated in version 0.19 and will be removed '
249- 'in version 0.21. Pass fit parameters to the '
250- '"fit" method instead.' , DeprecationWarning )
251- if fit_params :
252- warnings .warn ('Ignoring fit_params passed as a constructor '
253- 'argument in favor of keyword arguments to '
254- 'the "fit" method.' , RuntimeWarning )
255- else :
256- fit_params = self .fit_params
273+ def _fit (self , X , y , groups , parameter_iterable ):
257274
258275 estimator = self .estimator
259276 cv = check_cv (self .cv , y , classifier = is_classifier (estimator ))
@@ -262,18 +279,16 @@ def fit(self, X, y=None, groups=None, **fit_params):
262279
263280 X , y , groups = indexable (X , y , groups )
264281 n_splits = cv .get_n_splits (X , y , groups )
265- # Regenerate parameter iterable for each fit
266- candidate_params = ParameterGrid (self .param_grid )
267- n_candidates = len (candidate_params )
282+
268283 if self .verbose > 0 :
284+ n_candidates = len (parameter_iterable )
269285 print ("Fitting {0} folds for each of {1} candidates, totalling"
270286 " {2} fits" .format (n_splits , n_candidates ,
271287 n_candidates * n_splits ))
272288
273289 base_estimator = clone (self .estimator )
274290
275- param_grid = [(parameters , train , test ) for parameters , (train , test ) in product (candidate_params , cv .split (X , y , groups ))]
276-
291+ param_grid = [(parameters , train , test ) for parameters in parameter_iterable for train , test in list (cv .split (X , y , groups ))]
277292 # Because the original python code expects a certain order for the elements, we need to
278293 # respect it.
279294 indexed_param_grid = list (zip (range (len (param_grid )), param_grid ))
@@ -284,6 +299,7 @@ def fit(self, X, y=None, groups=None, **fit_params):
284299 scorer = self .scorer_
285300 verbose = self .verbose
286301 error_score = self .error_score
302+ fit_params = self .fit_params
287303 return_train_score = self .return_train_score
288304 fas = _fit_and_score
289305
@@ -296,18 +312,21 @@ def fun(tup):
296312 parameters , fit_params ,
297313 return_train_score = return_train_score ,
298314 return_n_test_samples = True , return_times = True ,
299- return_parameters = False , error_score = error_score )
315+ return_parameters = True , error_score = error_score )
300316 return (index , res )
301317 indexed_out0 = dict (par_param_grid .map (fun ).collect ())
302318 out = [indexed_out0 [idx ] for idx in range (len (param_grid ))]
303319 if return_train_score :
304320 (train_scores , test_scores , test_sample_counts , fit_time ,
305- score_time ) = zip (* out )
321+ score_time , parameters ) = zip (* out )
306322 else :
307- (test_scores , test_sample_counts , fit_time , score_time ) = zip (* out )
323+ (test_scores , test_sample_counts , fit_time , score_time , parameters ) = zip (* out )
308324 X_bc .unpersist ()
309325 y_bc .unpersist ()
310326
327+ candidate_params = parameters [::n_splits ]
328+ n_candidates = len (candidate_params )
329+
311330 results = dict ()
312331
313332 def _store (key_name , array , weights = None , splits = False , rank = False ):
0 commit comments