Skip to content
This repository was archived by the owner on Dec 4, 2019. It is now read-only.

Commit 3dc69d9

Browse files
authored
[53] Update to sklearn 0.18.1
Currently, spark-sklearn gives deprecation warnings when used with sklearn version .18 because several classes in grid_search and cross_validation were refactored into a new module called model_selection. The changes here make spark-sklearn compatible with the changes introduced in sklearn .18. The most critical changes reflected in the new version of sklearn is that: sklearn.model_selection.GridSearchCV now has: cv_results_ : dict of numpy (masked) ndarrays - A dict with keys as column headers and values as columns, that can be imported into a pandas DataFrame. best_estimator_ : estimator - Estimator that was chosen by the search, i.e. estimator which gave highest score (or smallest loss if specified) on the left out data. Not available if refit=False. best_score_ : float - Score of best_estimator on the left out data. best_params_ : dict - Parameter setting that gave the best results on the hold out data. best_index_ : int - The index (of the cv_results_ arrays) which corresponds to the best candidate parameter setting. scorer_ : function - Scorer function used on the held out data to choose the best parameters for the model. n_splits_ : int - The number of cross-validation splits (folds/iterations). While spark-sklearn.GridSearchCV has: grid_scores_ : list of named tuples best_estimator_ : estimator - Estimator that was chosen by the search, i.e. estimator which gave highest score (or smallest loss if specified) on the left out data. Not available if refit=False. best_score_ : float - Score of best_estimator on the left out data. best_params_ : dict - Parameter setting that gave the best results on the hold out data. scorer_ : function - Scorer function used on the held out data to choose the best parameters for the model. The biggest is that sklearn added the more comprehensive cv_results_ which adds data that the formerly compatible grid_scores_ is lacking. Note: This version of spark-sklearn is not compatible with sklearn <= .17.
2 parents d6f6f56 + eeac339 commit 3dc69d9

File tree

7 files changed

+275
-175
lines changed

7 files changed

+275
-175
lines changed

.travis.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ install:
3232
- conda info -a
3333

3434
# Replace dep1 dep2 ... with your dependencies
35-
- conda create -q -n test-environment python=$TRAVIS_PYTHON_VERSION scikit-learn==0.17.0 nose=1.3.7 pandas=0.18
35+
- conda create -q -n test-environment python=$TRAVIS_PYTHON_VERSION scikit-learn==0.18.1 nose=1.3.7 pandas=0.18
3636
- source activate test-environment
3737

3838
script:

python/README.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,8 @@ This package is released under the Apache 2.0 license. See the LICENSE file.
2020
## Installation
2121

2222
This package has the following requirements:
23-
- a recent version of scikit-learn. Version 0.17 has been tested, older versions may work too.
23+
This package has the following requirements:
24+
- Sklearn version >= 0.18.1
2425
- Spark >= 2.1.1 Spark may be downloaded from the
2526
[Spark official website](http://spark.apache.org/). In order to use spark-sklearn, you need to use the pyspark interpreter or another Spark-compliant python interpreter. See the [Spark guide](https://spark.apache.org/docs/latest/programming-guide.html#overview) for more details.
2627
- [nose](https://nose.readthedocs.org) (testing dependency only)

python/requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
11
# This file should list any python package dependencies.
2-
scikit-learn==0.17.1
2+
scikit-learn==0.18.1

python/setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
"Programming Language :: Python",
2020
"Topic :: Scientific/Engineering",
2121
]
22-
INSTALL_REQUIRES = ["scikit-learn >= 0.17"]
22+
INSTALL_REQUIRES = ["scikit-learn >= 0.18.1"]
2323

2424
# Project root
2525
ROOT = os.path.abspath(os.getcwd() + "/")

0 commit comments

Comments
 (0)