55import sys
66
77from itertools import product
8- from collections import defaultdict
8+ from collections import defaultdict , Sized
99from functools import partial
1010import warnings
1111
@@ -242,8 +242,9 @@ def __init__(self, sc, estimator, param_grid, scoring=None, fit_params=None,
242242 n_jobs = 1 , iid = True , refit = True , cv = None , verbose = 0 ,
243243 pre_dispatch = '2*n_jobs' , error_score = 'raise' , return_train_score = True ):
244244 super (GridSearchCV , self ).__init__ (
245- estimator , scoring , fit_params , n_jobs , iid ,
246- refit , cv , verbose , pre_dispatch , error_score , return_train_score )
245+ estimator = estimator , scoring = scoring , fit_params = fit_params , n_jobs = n_jobs , iid = iid ,
246+ refit = refit , cv = cv , verbose = verbose , pre_dispatch = pre_dispatch , error_score = error_score ,
247+ return_train_score = return_train_score )
247248 self .sc = sc
248249 self .param_grid = param_grid
249250
@@ -280,15 +281,16 @@ def _fit(self, X, y, groups, parameter_iterable):
280281 X , y , groups = indexable (X , y , groups )
281282 n_splits = cv .get_n_splits (X , y , groups )
282283
283- if self .verbose > 0 :
284+ if self .verbose > 0 and isinstance ( parameter_iterable , Sized ) :
284285 n_candidates = len (parameter_iterable )
285286 print ("Fitting {0} folds for each of {1} candidates, totalling"
286287 " {2} fits" .format (n_splits , n_candidates ,
287288 n_candidates * n_splits ))
288289
289290 base_estimator = clone (self .estimator )
290291
291- param_grid = [(parameters , train , test ) for parameters in parameter_iterable for train , test in list (cv .split (X , y , groups ))]
292+ param_grid = [(parameters , train , test ) for parameters in parameter_iterable
293+ for train , test in list (cv .split (X , y , groups ))]
292294 # Because the original python code expects a certain order for the elements, we need to
293295 # respect it.
294296 indexed_param_grid = list (zip (range (len (param_grid )), param_grid ))
@@ -309,10 +311,10 @@ def fun(tup):
309311 local_X = X_bc .value
310312 local_y = y_bc .value
311313 res = fas (local_estimator , local_X , local_y , scorer , train , test , verbose ,
312- parameters , fit_params ,
313- return_train_score = return_train_score ,
314- return_n_test_samples = True , return_times = True ,
315- return_parameters = True , error_score = error_score )
314+ parameters , fit_params ,
315+ return_train_score = return_train_score ,
316+ return_n_test_samples = True , return_times = True ,
317+ return_parameters = True , error_score = error_score )
316318 return (index , res )
317319 indexed_out0 = dict (par_param_grid .map (fun ).collect ())
318320 out = [indexed_out0 [idx ] for idx in range (len (param_grid ))]
0 commit comments