 )
 from autosklearn.evaluation import ExecuteTaFuncWithQueue, get_cost_of_crash
 from autosklearn.evaluation.abstract_evaluator import _fit_and_suppress_warnings
-from autosklearn.evaluation.train_evaluator import _fit_with_budget
+from autosklearn.evaluation.train_evaluator import TrainEvaluator, _fit_with_budget
 from autosklearn.metrics import calculate_metric
 from autosklearn.util.backend import Backend
 from autosklearn.util.stopwatch import StopWatch
@@ -139,13 +139,13 @@ def __init__(self,
                  smac_scenario_args=None,
                  logging_config=None,
                  metric=None,
-                 scoring_functions=None
+                 scoring_functions=None,
+                 get_trials_callback=None
                  ):
         super(AutoML, self).__init__()
         self.configuration_space = None
         self._backend = backend
         # self._tmp_dir = tmp_dir
-        # self._output_dir = output_dir
         self._time_for_task = time_left_for_this_task
         self._per_run_time_limit = per_run_time_limit
         self._initial_configurations_via_metalearning = \
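Note on the new `get_trials_callback` argument: it is stored on the `AutoML` object and, in a later hunk, forwarded to the SMAC runner as `trials_callback`. As a rough illustration of how such a callback could look, here is a minimal sketch assuming SMAC's `IncorporateRunResultCallback` interface; the exact import path, call signature, and stopping convention depend on the installed SMAC version, so treat the names below as assumptions rather than the definitive API.

```python
from smac.callbacks import IncorporateRunResultCallback  # assumed SMAC >= 0.13 API


class StopAfterTargetCost(IncorporateRunResultCallback):
    """Hypothetical callback: ask SMAC to stop once a run reaches a target cost."""

    def __init__(self, target_cost: float = 0.05) -> None:
        self.target_cost = target_cost

    def __call__(self, smbo, run_info, result, time_left):
        # SMAC invokes this after every finished run; returning False is the
        # conventional way to request that the optimization loop stop early.
        if result.cost is not None and result.cost <= self.target_cost:
            return False
```

An instance would then be passed as `get_trials_callback` when constructing the estimator (e.g. `AutoSklearnClassifier(get_trials_callback=StopAfterTargetCost())`), assuming the estimator forwards the argument to this `AutoML` class as shown in this diff.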
@@ -165,32 +165,6 @@ def __init__(self,
         self._scoring_functions = scoring_functions if scoring_functions is not None else []
         self._resampling_strategy_arguments = resampling_strategy_arguments \
             if resampling_strategy_arguments is not None else {}
-        if self._resampling_strategy not in ['holdout',
-                                             'holdout-iterative-fit',
-                                             'cv',
-                                             'cv-iterative-fit',
-                                             'partial-cv',
-                                             'partial-cv-iterative-fit',
-                                             ] \
-           and not issubclass(self._resampling_strategy, BaseCrossValidator)\
-           and not issubclass(self._resampling_strategy, _RepeatedSplits)\
-           and not issubclass(self._resampling_strategy, BaseShuffleSplit):
-            raise ValueError('Illegal resampling strategy: %s' %
-                             self._resampling_strategy)
-
-        if self._resampling_strategy in ['partial-cv',
-                                         'partial-cv-iterative-fit',
-                                         ] \
-                and self._ensemble_size != 0:
-            raise ValueError("Resampling strategy %s cannot be used "
-                             "together with ensembles." % self._resampling_strategy)
-        if self._resampling_strategy in ['partial-cv',
-                                         'cv',
-                                         'cv-iterative-fit',
-                                         'partial-cv-iterative-fit',
-                                         ]\
-                and 'folds' not in self._resampling_strategy_arguments:
-            self._resampling_strategy_arguments['folds'] = 5
         self._n_jobs = n_jobs
         self._dask_client = dask_client

@@ -208,6 +182,7 @@ def __init__(self,
208182 "'disable_evaluator_output' must be one "
209183 "of " + str (allowed_elements ))
210184 self ._get_smac_object_callback = get_smac_object_callback
185+ self ._get_trials_callback = get_trials_callback
211186 self ._smac_scenario_args = smac_scenario_args
212187 self .logging_config = logging_config
213188
@@ -254,9 +229,6 @@ def __init__(self,
         # By default try to use the TCP logging port or get a new port
         self._logger_port = logging.handlers.DEFAULT_TCP_LOGGING_PORT

-        # After assigning and checking variables...
-        # self._backend = Backend(self._output_dir, self._tmp_dir)
-
         # Num_run tell us how many runs have been launched
         # It can be seen as an identifier for each configuration
         # saved to disk
@@ -427,7 +399,7 @@ def _do_dummy_prediction(self, datamanager: XYDataManager, num_run: int) -> int:
             self._logger.error(
                 "Dummy prediction failed with run state %s. "
                 "The error suggests that the provided memory limits were too tight. Please "
-                "increase the 'ml_memory_limit' and try again. If this does not solve your "
+                "increase the 'memory_limit' and try again. If this does not solve your "
                 "problem, please open an issue and paste the additional output. "
                 "Additional output: %s.",
                 str(status), str(additional_info),
@@ -436,7 +408,7 @@ def _do_dummy_prediction(self, datamanager: XYDataManager, num_run: int) -> int:
             raise ValueError(
                 "Dummy prediction failed with run state %s. "
                 "The error suggests that the provided memory limits were too tight. Please "
-                "increase the 'ml_memory_limit' and try again. If this does not solve your "
+                "increase the 'memory_limit' and try again. If this does not solve your "
                 "problem, please open an issue and paste the additional output. "
                 "Additional output: %s." %
                 (str(status), str(additional_info)),
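Both messages now refer to `memory_limit`, the per-job memory cap (in MB) that replaced the older `ml_memory_limit` setting. A small usage sketch of the corresponding constructor argument follows; the keyword name matches current auto-sklearn releases, while the values are arbitrary examples.

```python
import autosklearn.classification

automl = autosklearn.classification.AutoSklearnClassifier(
    time_left_for_this_task=120,   # seconds for the whole search
    per_run_time_limit=30,         # seconds per candidate model
    memory_limit=3072,             # MB per evaluation job; replaces 'ml_memory_limit'
)
```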
@@ -510,6 +482,15 @@ def fit(
             task=self._task,
         )

+        # Check the re-sampling strategy
+        try:
+            self._check_resampling_strategy(
+                X=X, y=y, task=task,
+            )
+        except Exception as e:
+            self._fit_cleanup()
+            raise e
+
         # Reset learnt stuff
         self.models_ = None
         self.cv_models_ = None
@@ -537,10 +518,8 @@ def fit(
         self._dataset_name = dataset_name
         self._stopwatch.start_task(self._dataset_name)

-        if feat_type is None and self.InputValidator.feature_validator.feat_type:
-            self._feat_type = self.InputValidator.feature_validator.feat_type
-        elif feat_type is not None:
-            self._feat_type = feat_type
+        # Take the feature types from the validator
+        self._feat_type = self.InputValidator.feature_validator.feat_type

         # Produce debug information to the logfile
         self._logger.debug('Starting to print environment information')
@@ -573,7 +552,6 @@ def fit(
                 raise ValueError('Unable to read requirement: %s' % requirement)
         self._logger.debug('Done printing environment information')
         self._logger.debug('Starting to print arguments to auto-sklearn')
-        self._logger.debug('  output_folder: %s', self._backend.context._output_directory)
         self._logger.debug('  tmp_folder: %s', self._backend.context._temporary_directory)
         self._logger.debug('  time_left_for_this_task: %f', self._time_for_task)
         self._logger.debug('  per_run_time_limit: %f', self._per_run_time_limit)
@@ -782,6 +760,7 @@ def fit(
                 port=self._logger_port,
                 pynisher_context=self._multiprocessing_context,
                 ensemble_callback=proc_ensemble,
+                trials_callback=self._get_trials_callback
             )

             try:
@@ -839,6 +818,63 @@ def _fit_cleanup(self):
         self._clean_logger()
         return

+    def _check_resampling_strategy(
+        self,
+        X: SUPPORTED_FEAT_TYPES,
+        y: SUPPORTED_TARGET_TYPES,
+        task: int,
+    ) -> None:
+        """
+        This method centralizes the checks for resampling strategies
+
+        Parameters
+        ----------
+        X: (SUPPORTED_FEAT_TYPES)
+            Input features for the given task
+        y: (SUPPORTED_TARGET_TYPES)
+            Input targets for the given task
+        task: (task)
+            Integer describing a supported task type, like BINARY_CLASSIFICATION
+        """
+        is_split_object = isinstance(
+            self._resampling_strategy,
+            (BaseCrossValidator, _RepeatedSplits, BaseShuffleSplit)
+        )
+
+        if self._resampling_strategy not in [
+            'holdout',
+            'holdout-iterative-fit',
+            'cv',
+            'cv-iterative-fit',
+            'partial-cv',
+            'partial-cv-iterative-fit',
+        ] and not is_split_object:
+            raise ValueError('Illegal resampling strategy: %s' % self._resampling_strategy)
+
+        elif is_split_object:
+            TrainEvaluator.check_splitter_resampling_strategy(
+                X=X, y=y, task=task,
+                groups=self._resampling_strategy_arguments.get('groups', None),
+                resampling_strategy=self._resampling_strategy,
+            )
+
+        elif self._resampling_strategy in [
+            'partial-cv',
+            'partial-cv-iterative-fit',
+        ] and self._ensemble_size != 0:
+            raise ValueError("Resampling strategy %s cannot be used "
+                             "together with ensembles." % self._resampling_strategy)
+
+        elif self._resampling_strategy in [
+            'partial-cv',
+            'cv',
+            'cv-iterative-fit',
+            'partial-cv-iterative-fit',
+        ] and 'folds' not in self._resampling_strategy_arguments:
+            self._resampling_strategy_arguments['folds'] = 5
+
+        return
+
     @staticmethod
     def subsample_if_too_large(
         X: SUPPORTED_FEAT_TYPES,
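The new `_check_resampling_strategy` helper accepts either one of the built-in strategy strings or a scikit-learn splitter instance, delegating splitter validation (including optional `groups`) to `TrainEvaluator.check_splitter_resampling_strategy`. A minimal sketch of what the splitter path enables at the estimator level; the keyword names follow the public auto-sklearn API, while the data and splitter choice are purely illustrative.

```python
import numpy as np
from sklearn.model_selection import GroupShuffleSplit
import autosklearn.classification

# Toy data with a group structure: ten groups of ten samples each.
X = np.random.rand(100, 5)
y = np.random.randint(0, 2, size=100)
groups = np.repeat(np.arange(10), 10)

automl = autosklearn.classification.AutoSklearnClassifier(
    time_left_for_this_task=60,
    # A splitter object instead of a strategy string; it is validated by
    # _check_resampling_strategy via TrainEvaluator.check_splitter_resampling_strategy.
    resampling_strategy=GroupShuffleSplit(n_splits=1, test_size=0.3),
    resampling_strategy_arguments={'groups': groups},
)
automl.fit(X, y)
```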
@@ -1022,7 +1058,7 @@ def fit_pipeline(
             attributes will be automatically One-Hot encoded. The values
             used for a categorical attribute must be integers, obtained for
             example by `sklearn.preprocessing.LabelEncoder
-            <http://scikit-learn.org/stable/modules/generated/sklearn.preprocessing.LabelEncoder.html>`_.
+            <https://scikit-learn.org/stable/modules/generated/sklearn.preprocessing.LabelEncoder.html>`_.

         Returns
         -------