@@ -499,6 +499,8 @@ def fit(
499499 if X_test is not None :
500500 X_test , y_test = self .InputValidator .transform (X_test , y_test )
501501
502+ self ._task = task
503+
502504 X , y = self .subsample_if_too_large (
503505 X = X ,
504506 y = y ,
@@ -625,8 +627,6 @@ def fit(
625627 )
626628
627629 self ._backend ._make_internals_directory ()
628-
629- self ._task = datamanager .info ['task' ]
630630 self ._label_num = datamanager .info ['label_num' ]
631631
632632 # == Pickle the data manager to speed up loading
@@ -840,7 +840,14 @@ def _fit_cleanup(self):
840840 return
841841
842842 @staticmethod
843- def subsample_if_too_large (X , y , logger , seed , memory_limit , task ):
843+ def subsample_if_too_large (
844+ X : SUPPORTED_FEAT_TYPES ,
845+ y : SUPPORTED_TARGET_TYPES ,
846+ logger ,
847+ seed : int ,
848+ memory_limit : int ,
849+ task : int ,
850+ ):
844851 if memory_limit and isinstance (X , np .ndarray ):
845852 if X .dtype == np .float32 :
846853 multiplier = 4
@@ -884,12 +891,14 @@ def subsample_if_too_large(X, y, logger, seed, memory_limit, task):
884891 train_size = new_num_samples ,
885892 random_state = seed ,
886893 )
887- else :
894+ elif task in REGRESSION_TASKS :
888895 X , _ , y , _ = sklearn .model_selection .train_test_split (
889896 X , y ,
890897 train_size = new_num_samples ,
891898 random_state = seed ,
892899 )
900+ else :
901+ raise ValueError (task )
893902 return X , y
894903
895904 def refit (self , X , y ):
0 commit comments