Skip to content

Commit 79627e1

Browse files
committed
FIX: correctly pass task type to data subsampling
1 parent 802c337 commit 79627e1

File tree

1 file changed

+13
-4
lines changed

1 file changed

+13
-4
lines changed

autosklearn/automl.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -499,6 +499,8 @@ def fit(
499499
if X_test is not None:
500500
X_test, y_test = self.InputValidator.transform(X_test, y_test)
501501

502+
self._task = task
503+
502504
X, y = self.subsample_if_too_large(
503505
X=X,
504506
y=y,
@@ -625,8 +627,6 @@ def fit(
625627
)
626628

627629
self._backend._make_internals_directory()
628-
629-
self._task = datamanager.info['task']
630630
self._label_num = datamanager.info['label_num']
631631

632632
# == Pickle the data manager to speed up loading
@@ -840,7 +840,14 @@ def _fit_cleanup(self):
840840
return
841841

842842
@staticmethod
843-
def subsample_if_too_large(X, y, logger, seed, memory_limit, task):
843+
def subsample_if_too_large(
844+
X: SUPPORTED_FEAT_TYPES,
845+
y: SUPPORTED_TARGET_TYPES,
846+
logger,
847+
seed: int,
848+
memory_limit: int,
849+
task: int,
850+
):
844851
if memory_limit and isinstance(X, np.ndarray):
845852
if X.dtype == np.float32:
846853
multiplier = 4
@@ -884,12 +891,14 @@ def subsample_if_too_large(X, y, logger, seed, memory_limit, task):
884891
train_size=new_num_samples,
885892
random_state=seed,
886893
)
887-
else:
894+
elif task in REGRESSION_TASKS:
888895
X, _, y, _ = sklearn.model_selection.train_test_split(
889896
X, y,
890897
train_size=new_num_samples,
891898
random_state=seed,
892899
)
900+
else:
901+
raise ValueError(task)
893902
return X, y
894903

895904
def refit(self, X, y):

0 commit comments

Comments
 (0)