This repository was archived by the owner on Dec 4, 2019. It is now read-only.

Commit 8fead94

Responded to pull request comments. Modified to match sklearn 0.18.1

1 parent 5a2af5f · commit 8fead94

2 files changed: +110 −91 lines

python/spark_sklearn/grid_search.py
108 additions, 89 deletions
@@ -5,7 +5,7 @@
 import sys
 
 from itertools import product
-from collections import Sized, Mapping, namedtuple, defaultdict, Sequence
+from collections import defaultdict
 from functools import partial
 import warnings
 
@@ -23,89 +23,103 @@
 
 
 class GridSearchCV(BaseSearchCV):
-    """Exhaustive search over specified parameter values for an estimator, using Spark to
-    distribute the computations.
-
+    """Exhaustive search over specified parameter values for an estimator.
+
     Important members are fit, predict.
-
-    GridSearchCV implements a "fit" method and a "predict" method like
-    any classifier except that the parameters of the classifier
-    used to predict is optimized by cross-validation.
-
+
+    GridSearchCV implements a "fit" and a "score" method.
+    It also implements "predict", "predict_proba", "decision_function",
+    "transform" and "inverse_transform" if they are implemented in the
+    estimator used.
+
+    The parameters of the estimator used to apply these methods are optimized
+    by cross-validated grid-search over a parameter grid.
+    Read more in the :ref:`User Guide <grid_search>`.
+
     Parameters
     ----------
-    sc: the spark context
-
-    estimator : object type that implements the "fit" and "predict" methods
-        A object of that type is instantiated for each grid point.
-
+    estimator : estimator object.
+        This is assumed to implement the scikit-learn estimator interface.
+        Either estimator needs to provide a ``score`` function,
+        or ``scoring`` must be passed.
+
     param_grid : dict or list of dictionaries
         Dictionary with parameters names (string) as keys and lists of
         parameter settings to try as values, or a list of such
         dictionaries, in which case the grids spanned by each dictionary
         in the list are explored. This enables searching over any sequence
         of parameter settings.
-
-    scoring : string, callable or None, optional, default: None
+
+    scoring : string, callable or None, default=None
        A string (see model evaluation documentation) or
        a scorer callable object / function with signature
        ``scorer(estimator, X, y)``.
-
+        If ``None``, the ``score`` method of the estimator is used.
+
    fit_params : dict, optional
        Parameters to pass to the fit method.
-
-        .. deprecated:: 0.19
-           ``fit_params`` as a constructor argument was deprecated in version
-           0.19 and will be removed in version 0.21. Pass fit parameters to
-           the ``fit`` method instead..
-
-    n_jobs : int, default 1
-        This parameter is not used and kept for compatibility.
-
+
+    n_jobs : int, default=1
+        Number of jobs to run in parallel.
+
    pre_dispatch : int, or string, optional
-        This parameter is not used and kept for compatibility.
-
+        Controls the number of jobs that get dispatched during parallel
+        execution. Reducing this number can be useful to avoid an
+        explosion of memory consumption when more jobs get dispatched
+        than CPUs can process. This parameter can be:
+            - None, in which case all the jobs are immediately
+              created and spawned. Use this for lightweight and
+              fast-running jobs, to avoid delays due to on-demand
+              spawning of the jobs
+            - An int, giving the exact number of total jobs that are
+              spawned
+            - A string, giving an expression as a function of n_jobs,
+              as in '2*n_jobs'
+
    iid : boolean, default=True
        If True, the data is assumed to be identically distributed across
        the folds, and the loss minimized is the total loss per sample,
        and not the mean loss across the folds.
-
-    cv : integer or cross-validation generator, default=3
-        A cross-validation generator to use. If int, determines
-        the number of folds in StratifiedKFold if estimator is a classifier
-        and the target y is binary or multiclass, or the number
-        of folds in KFold otherwise.
-        Specific cross-validation objects can be passed, see
-        sklearn.cross_validation module for the list of possible objects.
-
+
+    cv : int, cross-validation generator or an iterable, optional
+        Determines the cross-validation splitting strategy.
+        Possible inputs for cv are:
+          - None, to use the default 3-fold cross validation,
+          - integer, to specify the number of folds in a `(Stratified)KFold`,
+          - An object to be used as a cross-validation generator.
+          - An iterable yielding train, test splits.
+        For integer/None inputs, if the estimator is a classifier and ``y`` is
+        either binary or multiclass, :class:`StratifiedKFold` is used. In all
+        other cases, :class:`KFold` is used.
+        Refer :ref:`User Guide <cross_validation>` for the various
+        cross-validation strategies that can be used here.
+
     refit : boolean, default=True
         Refit the best estimator with the entire dataset.
         If "False", it is impossible to make predictions using
         this GridSearchCV instance after fitting.
-
-        The refitting step, if any, happens on the local machine.
-
+
     verbose : integer
         Controls the verbosity: the higher, the more messages.
-
+
     error_score : 'raise' (default) or numeric
         Value to assign to the score if an error occurs in estimator fitting.
         If set to 'raise', the error is raised. If a numeric value is given,
         FitFailedWarning is raised. This parameter does not affect the refit
         step, which will always raise the error.
-
-
+
+    return_train_score : boolean, default=True
+        If ``'False'``, the ``cv_results_`` attribute will not include training
+        scores.
+
     Examples
     --------
     >>> from sklearn import svm, datasets
-    >>> from spark_sklearn import GridSearchCV
-    >>> from pyspark.sql import SparkSession
-    >>> from spark_sklearn.util import createLocalSparkSession
-    >>> spark = createLocalSparkSession()
+    >>> from sklearn.model_selection import GridSearchCV
     >>> iris = datasets.load_iris()
     >>> parameters = {'kernel':('linear', 'rbf'), 'C':[1, 10]}
     >>> svr = svm.SVC()
-    >>> clf = GridSearchCV(spark.sparkContext, svr, parameters)
+    >>> clf = GridSearchCV(svr, parameters)
     >>> clf.fit(iris.data, iris.target)
     ...                             # doctest: +NORMALIZE_WHITESPACE +ELLIPSIS
     GridSearchCV(cv=None, error_score=...,
@@ -115,10 +129,17 @@ class GridSearchCV(BaseSearchCV):
                random_state=None, shrinking=True, tol=...,
                verbose=False),
         fit_params={}, iid=..., n_jobs=1,
-        param_grid=..., pre_dispatch=..., refit=...,
+        param_grid=..., pre_dispatch=..., refit=..., return_train_score=...,
         scoring=..., verbose=...)
-    >>> spark.stop(); SparkSession._instantiatedContext = None
-
+    >>> sorted(clf.cv_results_.keys())
+    ...                             # doctest: +NORMALIZE_WHITESPACE +ELLIPSIS
+    ['mean_fit_time', 'mean_score_time', 'mean_test_score',...
+     'mean_train_score', 'param_C', 'param_kernel', 'params',...
+     'rank_test_score', 'split0_test_score',...
+     'split0_train_score', 'split1_test_score', 'split1_train_score',...
+     'split2_test_score', 'split2_train_score',...
+     'std_fit_time', 'std_score_time', 'std_test_score', 'std_train_score'...]
+
     Attributes
     ----------
     cv_results_ : dict of numpy (masked) ndarrays
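Note: the doctest added above shows that `cv_results_` is a dict of equal-length arrays keyed by metric, with one entry per parameter combination. A minimal inspection sketch, not part of this commit, assuming a fitted search named `clf` and pandas installed:

```python
import pandas as pd

# cv_results_ is a dict of parallel arrays, so it loads directly into a
# DataFrame with one row per parameter combination.
df = pd.DataFrame(clf.cv_results_)
print(df[['params', 'mean_test_score', 'rank_test_score']]
      .sort_values('rank_test_score'))
```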
@@ -188,29 +209,35 @@ class GridSearchCV(BaseSearchCV):
 
     n_splits_ : int
         The number of cross-validation splits (folds/iterations).
-
+
     Notes
     ------
     The parameters selected are those that maximize the score of the left out
     data, unless an explicit score is passed in which case it is used instead.
-
-    The parameters n_jobs and pre_dispatch are accepted but not used.
-
+    If `n_jobs` was set to a value higher than one, the data is copied for each
+    point in the grid (and not `n_jobs` times). This is done for efficiency
+    reasons if individual jobs take very little time, but may raise errors if
+    the dataset is large and not enough memory is available. A workaround in
+    this case is to set `pre_dispatch`. Then, the memory is copied only
+    `pre_dispatch` many times. A reasonable value for `pre_dispatch` is `2 *
+    n_jobs`.
+
     See Also
     ---------
     :class:`ParameterGrid`:
-        generates all the combinations of a an hyperparameter grid.
-
-    :func:`sklearn.cross_validation.train_test_split`:
+        generates all the combinations of a hyperparameter grid.
+
+    :func:`sklearn.model_selection.train_test_split`:
         utility function to split the data into a development set usable
         for fitting a GridSearchCV instance and an evaluation set for
         its final evaluation.
-
+
     :func:`sklearn.metrics.make_scorer`:
         Make a scorer from a performance metric or loss function.
-
+
     """
 
+
     def __init__(self, sc, estimator, param_grid, scoring=None, fit_params=None,
                  n_jobs=1, iid=True, refit=True, cv=None, verbose=0,
                  pre_dispatch='2*n_jobs', error_score='raise', return_train_score=True):
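Note: although the docstring above was aligned with sklearn 0.18.1 (whose example imports `GridSearchCV` from `sklearn.model_selection`), the spark-sklearn constructor shown here still takes a SparkContext as its first argument. A usage sketch reconstructed from the doctest this commit removed; it assumes a local Spark installation:

```python
from sklearn import svm, datasets
from spark_sklearn import GridSearchCV
from spark_sklearn.util import createLocalSparkSession

spark = createLocalSparkSession()
iris = datasets.load_iris()
parameters = {'kernel': ('linear', 'rbf'), 'C': [1, 10]}

# Same call pattern as sklearn's GridSearchCV, plus the SparkContext up front.
clf = GridSearchCV(spark.sparkContext, svm.SVC(), parameters)
clf.fit(iris.data, iris.target)

spark.stop()
```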
@@ -222,38 +249,28 @@ def __init__(self, sc, estimator, param_grid, scoring=None, fit_params=None,
 
         self.cv_results_ = None
         _check_param_grid(param_grid)
-
-    def fit_old(self, X, y=None):
+
+    def fit(self, X, y=None, groups=None):
         """Run fit with all sets of parameters.
-
+
         Parameters
         ----------
-
+
         X : array-like, shape = [n_samples, n_features]
             Training vector, where n_samples is the number of samples and
             n_features is the number of features.
-
+
         y : array-like, shape = [n_samples] or [n_samples, n_output], optional
             Target relative to X for classification or regression;
             None for unsupervised learning.
-
+
+        groups : array-like, with shape (n_samples,), optional
+            Group labels for the samples used while splitting the dataset into
+            train/test set.
         """
-        return self._fit(X, y, ParameterGrid(self.param_grid))
+        return self._fit(X, y, groups, ParameterGrid(self.param_grid))
 
-
-    def fit(self, X, y=None, groups=None, **fit_params):
-
-        if self.fit_params is not None:
-            warnings.warn('"fit_params" as a constructor argument was '
-                          'deprecated in version 0.19 and will be removed '
-                          'in version 0.21. Pass fit parameters to the '
-                          '"fit" method instead.', DeprecationWarning)
-            if fit_params:
-                warnings.warn('Ignoring fit_params passed as a constructor '
-                              'argument in favor of keyword arguments to '
-                              'the "fit" method.', RuntimeWarning)
-        else:
-            fit_params = self.fit_params
+    def _fit(self, X, y, groups, parameter_iterable):
 
         estimator = self.estimator
         cv = check_cv(self.cv, y, classifier=is_classifier(estimator))
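Note: after this hunk, `fit` simply expands `param_grid` with sklearn's `ParameterGrid` and delegates to `_fit`. For reference, `ParameterGrid` yields the full cross product of settings (keys iterated in sorted order):

```python
from sklearn.model_selection import ParameterGrid

# The grid from the doctest above expands to four candidate settings.
for params in ParameterGrid({'kernel': ('linear', 'rbf'), 'C': [1, 10]}):
    print(params)
# {'C': 1, 'kernel': 'linear'}
# {'C': 1, 'kernel': 'rbf'}
# {'C': 10, 'kernel': 'linear'}
# {'C': 10, 'kernel': 'rbf'}
```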
@@ -262,18 +279,16 @@ def fit(self, X, y=None, groups=None, **fit_params):
 
         X, y, groups = indexable(X, y, groups)
         n_splits = cv.get_n_splits(X, y, groups)
-        # Regenerate parameter iterable for each fit
-        candidate_params = ParameterGrid(self.param_grid)
-        n_candidates = len(candidate_params)
+
         if self.verbose > 0:
+            n_candidates = len(parameter_iterable)
             print("Fitting {0} folds for each of {1} candidates, totalling"
                   " {2} fits".format(n_splits, n_candidates,
                                      n_candidates * n_splits))
 
         base_estimator = clone(self.estimator)
 
-        param_grid = [(parameters, train, test) for parameters, (train, test) in product(candidate_params, cv.split(X, y, groups))]
-
+        param_grid = [(parameters, train, test) for parameters in parameter_iterable for train, test in list(cv.split(X, y, groups))]
         # Because the original python code expects a certain order for the elements, we need to
         # respect it.
         indexed_param_grid = list(zip(range(len(param_grid)), param_grid))
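Note: the rewritten comprehension enumerates tasks candidate-major (all folds of the first candidate, then all folds of the second, and so on), and the index attached by `zip` is what lets the driver restore that order after Spark returns results. A small sketch of the ordering, with stand-ins for `cv.split(...)` output:

```python
splits = [([0, 1], [2]), ([0, 2], [1])]   # n_splits = 2
candidates = [{'C': 1}, {'C': 10}]        # 2 candidates

param_grid = [(parameters, train, test)
              for parameters in candidates
              for train, test in splits]
# Candidate-major order:
# [({'C': 1}, [0, 1], [2]), ({'C': 1}, [0, 2], [1]),
#  ({'C': 10}, [0, 1], [2]), ({'C': 10}, [0, 2], [1])]

# Tagging each task with its index preserves this order across the shuffle.
indexed_param_grid = list(zip(range(len(param_grid)), param_grid))
```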
@@ -284,6 +299,7 @@ def fit(self, X, y=None, groups=None, **fit_params):
         scorer = self.scorer_
         verbose = self.verbose
         error_score = self.error_score
+        fit_params = self.fit_params
         return_train_score = self.return_train_score
         fas = _fit_and_score
 
@@ -296,18 +312,21 @@ def fun(tup):
                           parameters, fit_params,
                           return_train_score=return_train_score,
                           return_n_test_samples=True, return_times=True,
-                          return_parameters=False, error_score=error_score)
+                          return_parameters=True, error_score=error_score)
             return (index, res)
         indexed_out0 = dict(par_param_grid.map(fun).collect())
         out = [indexed_out0[idx] for idx in range(len(param_grid))]
         if return_train_score:
             (train_scores, test_scores, test_sample_counts, fit_time,
-             score_time) = zip(*out)
+             score_time, parameters) = zip(*out)
         else:
-            (test_scores, test_sample_counts, fit_time, score_time) = zip(*out)
+            (test_scores, test_sample_counts, fit_time, score_time, parameters) = zip(*out)
         X_bc.unpersist()
         y_bc.unpersist()
 
+        candidate_params = parameters[::n_splits]
+        n_candidates = len(candidate_params)
+
         results = dict()
 
         def _store(key_name, array, weights=None, splits=False, rank=False):
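Note: because `_fit_and_score` is now asked to return each task's parameters (`return_parameters=True`) and the collected output keeps the candidate-major order built above, slicing every `n_splits`-th entry recovers one parameter dict per candidate:

```python
# Each candidate contributes n_splits consecutive entries, so a stride of
# n_splits picks out one representative per candidate.
n_splits = 3
parameters = [{'C': 1}] * n_splits + [{'C': 10}] * n_splits

candidate_params = parameters[::n_splits]
print(candidate_params)        # [{'C': 1}, {'C': 10}]
print(len(candidate_params))   # 2
```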

python/spark_sklearn/tests/test_grid_search_1.py
2 additions, 2 deletions
@@ -32,7 +32,7 @@ def __init__(self, estimator, param_grid, scoring=None, fit_params=None,
     'test_grid_search_precomputed_kernel_error_kernel_function',
     'test_grid_search_precomputed_kernel',
     'test_grid_search_failing_classifier_raise',
-    'test_grid_search_score_method', # added this because the sklearn implementation of fit() fails it
+    'test_grid_search_score_method', # added this because the sklearn implementation of fit() fails it'
     'test_grid_search_failing_classifier']) # This one we should investigate
 
 def _create_method(method):
@@ -53,4 +53,4 @@ def _add_to_module():
         method_for_test.__name__ = name
         setattr (AllTests, method.__name__, method_for_test)
 
-_add_to_module()
\ No newline at end of file
+_add_to_module()
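Note: this test module reuses sklearn's own grid-search test suite by copying each test function onto a local test class, skipping the names in the exclusion set above. A hedged sketch of the pattern; the signature of `_add_to_module` and its arguments here are assumptions for illustration:

```python
import unittest

class AllTests(unittest.TestCase):
    pass

def _create_method(method):
    # Wrap a bare sklearn test function as an instance method.
    def method_for_test(self):
        method()
    return method_for_test

def _add_to_module(test_functions, excluded):
    for method in test_functions:
        if method.__name__ not in excluded:
            method_for_test = _create_method(method)
            method_for_test.__name__ = method.__name__
            setattr(AllTests, method.__name__, method_for_test)
```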
