@@ -305,15 +305,6 @@ def score(self, *args, **kwargs):
         """
         raise NotImplementedError("Abstract method")
 
-    @property
-    @abc.abstractmethod
-    def needs_fit(self):
-        """
-        Whether the model selector needs to be fit before it can be used for prediction or scoring;
-        in many cases this is equivalent to whether the selector is choosing between multiple models
-        """
-        raise NotImplementedError("Abstract method")
-
 
 class SingleModelSelector(ModelSelector):
     """
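Note: with the abstract needs_fit property removed above, a concrete selector only has to provide train, best_model, and best_score. A minimal sketch, assuming that remaining interface (the AlwaysLinearSelector name and the absence of any real selection logic are illustrative only, not part of this commit):

from sklearn.linear_model import LinearRegression

class AlwaysLinearSelector(SingleModelSelector):
    """Toy selector that always 'selects' a plain linear regression."""

    def __init__(self):
        self.model = LinearRegression()

    def train(self, is_selecting, folds, X, y, groups=None, **kwargs):
        # there is no choice to make, so just fit on all of the data
        self.model.fit(X, y)
        return self

    @property
    def best_model(self):
        return self.model

    @property
    def best_score(self):
        raise ValueError("No score was computed during selection")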
@@ -392,24 +383,25 @@ class FixedModelSelector(SingleModelSelector):
     Model selection class that always selects the given sklearn-compatible model
     """
 
-    def __init__(self, model):
+    def __init__(self, model, score_during_selection):
         self.model = clone(model, safe=False)
+        self.score_during_selection = score_during_selection
 
     def train(self, is_selecting, folds: Optional[List], X, y, groups=None, **kwargs):
         if is_selecting:
-            # since needs_fit is False, is_selecting will only be true if
-            # the score needs to be compared to another model's
-            # so we don't need to fit the model itself, just get the out-of-sample score
-            assert hasattr(self.model, 'score'), (f"Can't select between a fixed {type(self.model)} model and others "
-                                                  "because it doesn't have a score method")
-            scores = []
-            for train, test in folds:
-                # use _fit_with_groups instead of just fit to handle nested grouping
-                _fit_with_groups(self.model, X[train], y[train],
-                                 groups=None if groups is None else groups[train],
-                                 **{key: val[train] for key, val in kwargs.items()})
-                scores.append(self.model.score(X[test], y[test]))
-            self._score = np.mean(scores)
+            if self.score_during_selection:
+                # the score needs to be compared to another model's
+                # so we don't need to fit the model itself on all of the data, just get the out-of-sample score
+                assert hasattr(self.model, 'score'), (f"Can't select between a fixed {type(self.model)} model "
+                                                      "and others because it doesn't have a score method")
+                scores = []
+                for train, test in folds:
+                    # use _fit_with_groups instead of just fit to handle nested grouping
+                    _fit_with_groups(self.model, X[train], y[train],
+                                     groups=None if groups is None else groups[train],
+                                     **{key: val[train] for key, val in kwargs.items()})
+                    scores.append(self.model.score(X[test], y[test]))
+                self._score = np.mean(scores)
         else:
             # we need to train the model on the data
             _fit_with_groups(self.model, X, y, groups=groups, **kwargs)
@@ -422,11 +414,10 @@ def best_model(self):
 
     @property
     def best_score(self):
-        return self._score
-
-    @property
-    def needs_fit(self):
-        return False  # We have only a single model so we can skip the selection process
+        if hasattr(self, '_score'):
+            return self._score
+        else:
+            raise ValueError("No score was computed during selection")
 
 
 def _copy_to(m1, m2, attrs, insert_underscore=False):
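Note: a rough usage sketch of the new FixedModelSelector behavior; the data arrays and fold construction below are illustrative, not part of the diff.

import numpy as np
from sklearn.linear_model import Lasso
from sklearn.model_selection import KFold

X, y = np.random.normal(size=(100, 3)), np.random.normal(size=100)

# when comparing against other candidates, ask for an out-of-sample score
scored = FixedModelSelector(Lasso(alpha=0.1), score_during_selection=True)
folds = list(KFold(n_splits=3).split(X, y))
scored.train(True, folds, X, y)      # cross-fits and records the mean test score
print(scored.best_score)

# when it is the only candidate, skip scoring entirely
unscored = FixedModelSelector(Lasso(alpha=0.1), score_during_selection=False)
unscored.train(False, None, X, y)    # just fits the model on all of the data
# unscored.best_score would now raise ValueError("No score was computed during selection")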
@@ -579,11 +570,6 @@ def best_model(self):
     def best_score(self):
         return self._best_score
 
-    @property
-    def needs_fit(self):
-        return True  # strictly speaking, could be false if the hyperparameters are fixed
-                     # but it would be complicated to check that
-
 
 class ListSelector(SingleModelSelector):
     """
@@ -627,14 +613,8 @@ def best_model(self):
     def best_score(self):
         return self._best_score
 
-    @property
-    def needs_fit(self):
-        # technically, if there is just one model and it doesn't need to be fit we don't need to fit it,
-        # but that complicates the training logic so we don't bother with that case
-        return True
-
 
-def get_selector(input, is_discrete, *, random_state=None, cv=None, wrapper=GridSearchCV):
+def get_selector(input, is_discrete, *, random_state=None, cv=None, wrapper=GridSearchCV, needs_scoring=False):
     named_models = {
         'linear': (LogisticRegressionCV(random_state=random_state, cv=cv) if is_discrete
                    else WeightedLassoCVWrapper(random_state=random_state, cv=cv)),
@@ -657,19 +637,21 @@ def get_selector(input, is_discrete, *, random_state=None, cv=None, wrapper=Grid
         return input
     elif isinstance(input, list):  # we've got a list; call get_selector on each element, then wrap in a ListSelector
         models = [get_selector(model, is_discrete,
-                               random_state=random_state, cv=cv, wrapper=wrapper)
+                               random_state=random_state, cv=cv, wrapper=wrapper,
+                               needs_scoring=True)  # we need to score to compare outputs to each other
                   for model in input]
         return ListSelector(models)
     elif isinstance(input, str):  # we've got a string; look it up
         if input in named_models:
             return get_selector(named_models[input], is_discrete,
-                                random_state=random_state, cv=cv, wrapper=wrapper)
+                                random_state=random_state, cv=cv, wrapper=wrapper,
+                                needs_scoring=needs_scoring)
         else:
             raise ValueError(f"Unknown model type: {input}, must be one of {named_models.keys()}")
     elif SklearnCVSelector.can_wrap(input):
         return SklearnCVSelector(input)
     else:  # assume this is an sklearn-compatible model
-        return FixedModelSelector(input)
+        return FixedModelSelector(input, needs_scoring)
 
 
 class GridSearchCVList(BaseEstimator):
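Note: an illustrative sketch (not part of the diff) of how the needs_scoring flag threads through get_selector; RandomForestRegressor here is just a stand-in for any sklearn-compatible estimator without built-in cross-validation.

from sklearn.ensemble import RandomForestRegressor

# a single fixed estimator: nothing to compare it against, so with the default
# needs_scoring=False it is wrapped without cross-validated scoring during selection
single = get_selector(RandomForestRegressor(), is_discrete=False)

# a list of candidates: each element is wrapped with needs_scoring=True so that
# their out-of-sample scores can be compared, including the fixed estimator
several = get_selector(['linear', RandomForestRegressor()], is_discrete=False)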