|
2 | 2 | Class for parallelizing GridSearchCV jobs in scikit-learn |
3 | 3 | """ |
4 | 4 |
|
5 | | -import sys |
6 | | - |
7 | | -from itertools import product |
8 | 5 | from collections import defaultdict, Sized |
9 | 6 | from functools import partial |
10 | 7 | import warnings |
@@ -115,21 +112,23 @@ class GridSearchCV(BaseSearchCV): |
115 | 112 | Examples |
116 | 113 | -------- |
117 | 114 | >>> from sklearn import svm, datasets |
118 | | - >>> from sklearn.model_selection import GridSearchCV |
| 115 | + >>> from spark_sklearn.grid_search import GridSearchCV |
| 116 | + >>> from spark_sklearn.util import createLocalSparkSession |
| 117 | + >>> sc = createLocalSparkSession().sparkContext |
119 | 118 | >>> iris = datasets.load_iris() |
120 | 119 | >>> parameters = {'kernel':('linear', 'rbf'), 'C':[1, 10]} |
121 | 120 | >>> svr = svm.SVC() |
122 | | - >>> clf = GridSearchCV(svr, parameters) |
| 121 | + >>> clf = GridSearchCV(sc, svr, parameters) |
123 | 122 | >>> clf.fit(iris.data, iris.target) |
124 | 123 | ... # doctest: +NORMALIZE_WHITESPACE +ELLIPSIS |
125 | 124 | GridSearchCV(cv=None, error_score=..., |
126 | 125 | estimator=SVC(C=1.0, cache_size=..., class_weight=..., coef0=..., |
127 | | - decision_function_shape=None, degree=..., gamma=..., |
| 126 | + decision_function_shape=..., degree=..., gamma=..., |
128 | 127 | kernel='rbf', max_iter=-1, probability=False, |
129 | 128 | random_state=None, shrinking=True, tol=..., |
130 | 129 | verbose=False), |
131 | 130 | fit_params={}, iid=..., n_jobs=1, |
132 | | - param_grid=..., pre_dispatch=..., refit=..., return_train_score=..., |
| 131 | + param_grid=..., pre_dispatch=..., refit=..., |
133 | 132 | scoring=..., verbose=...) |
134 | 133 | >>> sorted(clf.cv_results_.keys()) |
135 | 134 | ... # doctest: +NORMALIZE_WHITESPACE +ELLIPSIS |
@@ -242,9 +241,12 @@ def __init__(self, sc, estimator, param_grid, scoring=None, fit_params=None, |
242 | 241 | n_jobs=1, iid=True, refit=True, cv=None, verbose=0, |
243 | 242 | pre_dispatch='2*n_jobs', error_score='raise', return_train_score=True): |
244 | 243 | super(GridSearchCV, self).__init__( |
245 | | - estimator=estimator, scoring=scoring, fit_params=fit_params, n_jobs=n_jobs, iid=iid, |
| 244 | + estimator=estimator, scoring=scoring, n_jobs=n_jobs, iid=iid, |
246 | 245 | refit=refit, cv=cv, verbose=verbose, pre_dispatch=pre_dispatch, error_score=error_score, |
247 | 246 | return_train_score=return_train_score) |
| 247 | + |
| 248 | + self.fit_params = fit_params if fit_params is not None else {} |
| 249 | + |
248 | 250 | self.sc = sc |
249 | 251 | self.param_grid = param_grid |
250 | 252 |
|
|
0 commit comments