Skip to content
This repository was archived by the owner on Dec 4, 2019. It is now read-only.

Commit 1a4cd4e

Browse files
authored
Merge pull request #64 from smurching/spark-2.2-patch
[#63] Remove usage of deprecated scikit-learn API in GridSearchCV
2 parents 770484a + 41cf622 commit 1a4cd4e

File tree

5 files changed

+21
-12
lines changed

5 files changed

+21
-12
lines changed

python/README.md

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,10 @@ More extensive documentation (generated with Sphinx) is available in the `python
6565
## Changelog
6666

6767
- 2015-12-10 First public release (0.1)
68-
- 2016-08-16 Minor release:
69-
1. the official Spark target is Spark 0.2
68+
- 2016-08-16 Minor release (0.2.0):
69+
1. the official Spark target is Spark 2.0
7070
2. support for keyed models
71+
- 2017-09-14 Minor release (0.2.2):
72+
1. The official Spark target is Spark >= 2.1
73+
7174

python/requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
11
# This file should list any python package dependencies.
2-
scikit-learn==0.18.1
2+
scikit-learn>=0.18.1, <=0.19

python/setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
"Programming Language :: Python",
2020
"Topic :: Scientific/Engineering",
2121
]
22-
INSTALL_REQUIRES = ["scikit-learn >= 0.18.1"]
22+
INSTALL_REQUIRES = ["scikit-learn >=0.18.1, <= 0.19"]
2323

2424
# Project root
2525
ROOT = os.path.abspath(os.getcwd() + "/")

python/spark_sklearn/grid_search.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,6 @@
22
Class for parallelizing GridSearchCV jobs in scikit-learn
33
"""
44

5-
import sys
6-
7-
from itertools import product
85
from collections import defaultdict, Sized
96
from functools import partial
107
import warnings
@@ -115,21 +112,23 @@ class GridSearchCV(BaseSearchCV):
115112
Examples
116113
--------
117114
>>> from sklearn import svm, datasets
118-
>>> from sklearn.model_selection import GridSearchCV
115+
>>> from spark_sklearn.grid_search import GridSearchCV
116+
>>> from spark_sklearn.util import createLocalSparkSession
117+
>>> sc = createLocalSparkSession().sparkContext
119118
>>> iris = datasets.load_iris()
120119
>>> parameters = {'kernel':('linear', 'rbf'), 'C':[1, 10]}
121120
>>> svr = svm.SVC()
122-
>>> clf = GridSearchCV(svr, parameters)
121+
>>> clf = GridSearchCV(sc, svr, parameters)
123122
>>> clf.fit(iris.data, iris.target)
124123
... # doctest: +NORMALIZE_WHITESPACE +ELLIPSIS
125124
GridSearchCV(cv=None, error_score=...,
126125
estimator=SVC(C=1.0, cache_size=..., class_weight=..., coef0=...,
127-
decision_function_shape=None, degree=..., gamma=...,
126+
decision_function_shape=..., degree=..., gamma=...,
128127
kernel='rbf', max_iter=-1, probability=False,
129128
random_state=None, shrinking=True, tol=...,
130129
verbose=False),
131130
fit_params={}, iid=..., n_jobs=1,
132-
param_grid=..., pre_dispatch=..., refit=..., return_train_score=...,
131+
param_grid=..., pre_dispatch=..., refit=...,
133132
scoring=..., verbose=...)
134133
>>> sorted(clf.cv_results_.keys())
135134
... # doctest: +NORMALIZE_WHITESPACE +ELLIPSIS
@@ -242,9 +241,12 @@ def __init__(self, sc, estimator, param_grid, scoring=None, fit_params=None,
242241
n_jobs=1, iid=True, refit=True, cv=None, verbose=0,
243242
pre_dispatch='2*n_jobs', error_score='raise', return_train_score=True):
244243
super(GridSearchCV, self).__init__(
245-
estimator=estimator, scoring=scoring, fit_params=fit_params, n_jobs=n_jobs, iid=iid,
244+
estimator=estimator, scoring=scoring, n_jobs=n_jobs, iid=iid,
246245
refit=refit, cv=cv, verbose=verbose, pre_dispatch=pre_dispatch, error_score=error_score,
247246
return_train_score=return_train_score)
247+
248+
self.fit_params = fit_params if fit_params is not None else {}
249+
248250
self.sc = sc
249251
self.param_grid = param_grid
250252

python/spark_sklearn/tests/test_grid_search_1.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,10 @@ def do_test_expected(*kwargs):
3232
return do_test_expected
3333

3434
def _add_to_module():
35+
# NOTE: This doesn't actually run scikit-learn tests against SPGridSearchWrapper
36+
# for scikit-learn >= 0.18, since the scikit-learn tests (in sklearn.model_selection.tests) use
37+
# sklearn.model_selection.GridSearchCV (not sklearn.grid_search.GridSearchCV)
38+
# TODO: Get scikit-learn tests to pass with spark-sklearn GridSearch implementation
3539
SKGridSearchCV = sklearn.grid_search.GridSearchCV
3640
sklearn.grid_search.GridSearchCV = SPGridSearchWrapper
3741
sklearn.grid_search.GridSearchCV_original = SKGridSearchCV

0 commit comments

Comments
 (0)