Skip to content
This repository was archived by the owner on Dec 4, 2019. It is now read-only.

Commit 1980bce

Browse files
committed
Made changes according to pull request comments.
1 parent eba5bb9 commit 1980bce

File tree

1 file changed

+11
-9
lines changed

1 file changed

+11
-9
lines changed

python/spark_sklearn/grid_search.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import sys
66

77
from itertools import product
8-
from collections import defaultdict
8+
from collections import defaultdict, Sized
99
from functools import partial
1010
import warnings
1111

@@ -242,8 +242,9 @@ def __init__(self, sc, estimator, param_grid, scoring=None, fit_params=None,
242242
n_jobs=1, iid=True, refit=True, cv=None, verbose=0,
243243
pre_dispatch='2*n_jobs', error_score='raise', return_train_score=True):
244244
super(GridSearchCV, self).__init__(
245-
estimator, scoring, fit_params, n_jobs, iid,
246-
refit, cv, verbose, pre_dispatch, error_score, return_train_score)
245+
estimator=estimator, scoring=scoring, fit_params=fit_params, n_jobs=n_jobs, iid=iid,
246+
refit=retfit, cv=cv, verbose=verbose, pre_dispatch=pre_dispatch, error_score=error_score,
247+
return_train_score=return_train_score)
247248
self.sc = sc
248249
self.param_grid = param_grid
249250

@@ -280,15 +281,16 @@ def _fit(self, X, y, groups, parameter_iterable):
280281
X, y, groups = indexable(X, y, groups)
281282
n_splits = cv.get_n_splits(X, y, groups)
282283

283-
if self.verbose > 0:
284+
if self.verbose > 0 and isinstance(parameter_iterable, Sized):
284285
n_candidates = len(parameter_iterable)
285286
print("Fitting {0} folds for each of {1} candidates, totalling"
286287
" {2} fits".format(n_splits, n_candidates,
287288
n_candidates * n_splits))
288289

289290
base_estimator = clone(self.estimator)
290291

291-
param_grid = [(parameters, train, test) for parameters in parameter_iterable for train, test in list(cv.split(X, y, groups))]
292+
param_grid = [(parameters, train, test) for parameters in parameter_iterable
293+
for train, test in list(cv.split(X, y, groups))]
292294
# Because the original python code expects a certain order for the elements, we need to
293295
# respect it.
294296
indexed_param_grid = list(zip(range(len(param_grid)), param_grid))
@@ -309,10 +311,10 @@ def fun(tup):
309311
local_X = X_bc.value
310312
local_y = y_bc.value
311313
res = fas(local_estimator, local_X, local_y, scorer, train, test, verbose,
312-
parameters, fit_params,
313-
return_train_score=return_train_score,
314-
return_n_test_samples=True, return_times=True,
315-
return_parameters=True, error_score=error_score)
314+
parameters, fit_params,
315+
return_train_score=return_train_score,
316+
return_n_test_samples=True, return_times=True,
317+
return_parameters=True, error_score=error_score)
316318
return (index, res)
317319
indexed_out0 = dict(par_param_grid.map(fun).collect())
318320
out = [indexed_out0[idx] for idx in range(len(param_grid))]

0 commit comments

Comments
 (0)