4343from autosklearn .util .backend import Backend
4444from autosklearn .util .stopwatch import StopWatch
4545from autosklearn .util .logging_ import (
46- get_logger ,
4746 setup_logger ,
4847 start_log_server ,
48+ get_named_client_logger ,
4949)
5050from autosklearn .util import pipeline , RE_PATTERN
5151from autosklearn .ensemble_builder import EnsembleBuilderManager
5454from autosklearn .util .hash import hash_array_or_matrix
5555from autosklearn .metrics import f1_macro , accuracy , r2
5656from autosklearn .constants import MULTILABEL_CLASSIFICATION , MULTICLASS_CLASSIFICATION , \
57- REGRESSION_TASKS , REGRESSION , BINARY_CLASSIFICATION , MULTIOUTPUT_REGRESSION
57+ REGRESSION_TASKS , REGRESSION , BINARY_CLASSIFICATION , MULTIOUTPUT_REGRESSION , \
58+ CLASSIFICATION_TASKS
5859from autosklearn .pipeline .components .classification import ClassifierChoice
5960from autosklearn .pipeline .components .regression import RegressorChoice
6061from autosklearn .pipeline .components .feature_preprocessing import FeaturePreprocessorChoice
@@ -228,6 +229,9 @@ def __init__(self,
228229 raise ValueError ("per_run_time_limit not of type integer, but %s" %
229230 str (type (self ._per_run_time_limit )))
230231
232+ # By default try to use the TCP logging port or get a new port
233+ self ._logger_port = logging .handlers .DEFAULT_TCP_LOGGING_PORT
234+
231235 # After assigning and checking variables...
232236 # self._backend = Backend(self._output_dir, self._tmp_dir)
233237
@@ -313,7 +317,11 @@ def _get_logger(self, name):
313317
314318 self ._logger_port = int (port .value )
315319
316- return get_logger (logger_name )
320+ return get_named_client_logger (
321+ name = logger_name ,
322+ host = 'localhost' ,
323+ port = self ._logger_port ,
324+ )
317325
318326 def _clean_logger (self ):
319327 if not hasattr (self , 'stop_logging_server' ) or self .stop_logging_server is None :
@@ -380,6 +388,7 @@ def _do_dummy_prediction(self, datamanager, num_run):
380388 disable_file_output = self ._disable_evaluator_output ,
381389 abort_on_first_run_crash = False ,
382390 cost_for_crash = get_cost_of_crash (self ._metric ),
391+ port = self ._logger_port ,
383392 ** self ._resampling_strategy_arguments )
384393
385394 status , cost , runtime , additional_info = ta .run (num_run , cutoff = self ._time_for_task )
@@ -428,6 +437,12 @@ def fit(
428437 only_return_configuration_space : Optional [bool ] = False ,
429438 load_models : bool = True ,
430439 ):
440+ if dataset_name is None :
441+ dataset_name = hash_array_or_matrix (X )
442+ # The first thing we have to do is create the logger to update the backend
443+ self ._logger = self ._get_logger (dataset_name )
444+ self ._backend .setup_logger (self ._logger_port )
445+
431446 self ._backend .save_start_time (self ._seed )
432447 self ._stopwatch = StopWatch ()
433448
@@ -445,6 +460,15 @@ def fit(
445460 raise ValueError ('Target value shapes do not match: %s vs %s'
446461 % (y .shape , y_test .shape ))
447462
463+ X , y = self .subsample_if_too_large (
464+ X = X ,
465+ y = y ,
466+ logger = self ._logger ,
467+ seed = self ._seed ,
468+ memory_limit = self ._memory_limit ,
469+ task = self ._task ,
470+ )
471+
448472 # Reset learnt stuff
449473 self .models_ = None
450474 self .cv_models_ = None
@@ -459,12 +483,6 @@ def fit(
459483 raise ValueError ('Metric must be instance of '
460484 'autosklearn.metrics.Scorer.' )
461485
462- if dataset_name is None :
463- dataset_name = hash_array_or_matrix (X )
464- # By default try to use the TCP logging port or get a new port
465- self ._logger_port = logging .handlers .DEFAULT_TCP_LOGGING_PORT
466- self ._logger = self ._get_logger (dataset_name )
467-
468486 # If no dask client was provided, we create one, so that we can
469487 # start a ensemble process in parallel to smbo optimize
470488 if (
@@ -718,6 +736,7 @@ def fit(
718736 get_smac_object_callback = self ._get_smac_object_callback ,
719737 smac_scenario_args = self ._smac_scenario_args ,
720738 scoring_functions = self ._scoring_functions ,
739+ port = self ._logger_port ,
721740 ensemble_callback = proc_ensemble ,
722741 )
723742
@@ -770,6 +789,59 @@ def fit(
770789
771790 return self
772791
@staticmethod
def subsample_if_too_large(X, y, logger, seed, memory_limit, task):
    """Subsample ``X``/``y`` if the data is estimated not to fit in memory.

    The raw feature matrix is required to fit into a tenth of
    ``memory_limit`` (a heuristic safety margin for the rest of the
    fitting process); otherwise the number of rows is reduced so that it
    does, using stratified sampling for classification tasks where
    possible.

    Parameters
    ----------
    X : np.ndarray or other
        Feature matrix; anything that is not a ``np.ndarray`` is
        returned unchanged (its memory use cannot be estimated here).
    y : array-like
        Targets, subsampled in lockstep with ``X``.
    logger : logging.Logger
        Used to warn about unknown dtypes and about subsampling.
    seed : int
        Random state forwarded to ``train_test_split`` for reproducibility.
    memory_limit : int
        Memory limit in megabytes.
    task : int
        Task constant; membership in ``CLASSIFICATION_TASKS`` selects
        stratified sampling.

    Returns
    -------
    X, y
        The (possibly subsampled) data.
    """
    if isinstance(X, np.ndarray):
        if X.dtype.kind == 'f':
            # Bytes per value: 4 for float32, 8 for float64, 16 for
            # float128. Using ``itemsize`` avoids referencing
            # ``np.float`` (alias removed in NumPy 1.24) and
            # ``np.float128`` (absent on platforms such as Windows),
            # both of which would raise AttributeError here.
            multiplier = X.dtype.itemsize
        else:
            # Just assuming some value - very unlikely
            multiplier = 8
            logger.warning('Unknown dtype for X: %s, assuming it takes 8 bytes/number',
                           str(X.dtype))
        megabytes = X.shape[0] * X.shape[1] * multiplier / 1024 / 1024
        if memory_limit <= megabytes * 10:
            # Largest row count whose raw data fits in memory_limit / 10.
            new_num_samples = int(
                memory_limit / (10 * X.shape[1] * multiplier / 1024 / 1024)
            )
            logger.warning(
                'Dataset too large for memory limit %dMB, reducing number of samples from '
                '%d to %d.',
                memory_limit,
                X.shape[0],
                new_num_samples,
            )
            if task in CLASSIFICATION_TASKS:
                try:
                    X, _, y, _ = sklearn.model_selection.train_test_split(
                        X, y,
                        train_size=new_num_samples,
                        random_state=seed,
                        stratify=y,
                    )
                except Exception:
                    # Stratification fails e.g. when some class has fewer
                    # members than the number of splits; fall back to a
                    # plain random subsample instead of aborting the fit.
                    logger.warning(
                        'Could not sample dataset in stratified manner, resorting to random '
                        'sampling',
                        exc_info=True
                    )
                    X, _, y, _ = sklearn.model_selection.train_test_split(
                        X, y,
                        train_size=new_num_samples,
                        random_state=seed,
                    )
            else:
                X, _, y, _ = sklearn.model_selection.train_test_split(
                    X, y,
                    train_size=new_num_samples,
                    random_state=seed,
                )
    return X, y
844+
773845 def refit (self , X , y ):
774846
775847 # Make sure input data is valid
@@ -1118,9 +1190,9 @@ def cv_results_(self):
11181190 status .append ('Abort' )
11191191 elif s == StatusType .MEMOUT :
11201192 status .append ('Memout' )
1121- elif s == StatusType .RUNNING :
1122- continue
1123- elif s == StatusType .BUDGETEXHAUSTED :
1193+ # TODO remove StatusType.RUNNING at some point in the future when the new SMAC 0.13.2
1194+ # is the new minimum required version!
1195+ elif s in ( StatusType .STOP , StatusType . RUNNING ) :
11241196 continue
11251197 else :
11261198 raise NotImplementedError (s )
0 commit comments