Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions flaml/automl/automl.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,6 +323,25 @@ def custom_metric(
mlflow_logging: boolean, default=True | Whether to log the training results to mlflow.
This requires mlflow to be installed and to have an active mlflow run.
FLAML will create nested runs.
multioutput_train_size: int, default=None | For multi-output tasks, when `eval_method` is set to
"holdout" and a validation set is manually specified, set this parameter to the length of
the training set. When calling the `fit` method, pass the training set and the validation set concatenated together.
e.g.,

```python
model = MultiOutputRegressor(
AutoML(
task="regression",
time_budget=1,
eval_method="holdout",
multioutput_train_size=len(X_train)
)
)
model.fit(
pd.concat([X_train, X_val]),
pd.concat([y_train, y_val])
)
```

"""
if ERROR:
Expand Down Expand Up @@ -375,6 +394,7 @@ def custom_metric(
settings["custom_hp"] = settings.get("custom_hp", {})
settings["skip_transform"] = settings.get("skip_transform", False)
settings["mlflow_logging"] = settings.get("mlflow_logging", True)
settings["multioutput_train_size"] = settings.get("multioutput_train_size", None)

self._estimator_type = "classifier" if settings["task"] in CLASSIFICATION else "regressor"

Expand Down Expand Up @@ -1148,6 +1168,9 @@ def _prepare_data(self, eval_method, split_ratio, n_splits):
)
self.data_size_full = self._state.data_size_full

def _train_val_split(self, train_val_concat, multioutput_train_size):
return train_val_concat[:multioutput_train_size], train_val_concat[multioutput_train_size:]

def fit(
self,
X_train=None,
Expand Down Expand Up @@ -1524,6 +1547,10 @@ def cv_score_agg_func(val_loss_folds, log_metrics_folds):

self._state._start_time_flag = self._start_time_flag = time.time()
task = task or self._settings.get("task")
multioutput_train_size = self._settings.get("multioutput_train_size")
if multioutput_train_size is not None:
X_train, X_val = self._train_val_split(X_train, multioutput_train_size)
y_train, y_val = self._train_val_split(y_train, multioutput_train_size)
if isinstance(task, str):
task = task_factory(task, X_train, y_train)
self._state.task = task
Expand Down
23 changes: 23 additions & 0 deletions test/automl/test_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,5 +230,28 @@ def test_multioutput():
print(model.predict(X_test))


def test_multioutput_train_size():
    """Verify AutoML's multioutput_train_size holdout split works end to end.

    Trains a MultiOutputRegressor wrapping AutoML on concatenated train+val
    data, with multioutput_train_size marking where the validation rows begin,
    then asserts the fitted model produces predictions of the expected shape.
    """
    import numpy as np
    from sklearn.datasets import make_regression
    from sklearn.model_selection import train_test_split
    from sklearn.multioutput import MultiOutputRegressor

    # create regression data with 3 targets
    X, y = make_regression(n_targets=3)

    # split into train / validation / test data
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.20, random_state=42)
    X_train, X_val, y_train, y_val = train_test_split(X_train, y_train, test_size=0.1, random_state=42)

    # train the model; multioutput_train_size tells AutoML where the training
    # rows end and the manually supplied validation rows begin
    model = MultiOutputRegressor(
        AutoML(task="regression", time_budget=1, eval_method="holdout", multioutput_train_size=len(X_train))
    )
    model.fit(np.concatenate([X_train, X_val], axis=0), np.concatenate([y_train, y_val], axis=0))

    # predict and assert the output covers every test row and every target
    y_pred = model.predict(X_test)
    assert y_pred.shape == y_test.shape
    assert np.all(np.isfinite(y_pred))
Comment on lines 233 to 253
Copy link

Copilot AI Jan 19, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The test function lacks assertions to verify the new multioutput_train_size feature works as expected. Consider adding assertions to validate that the model was trained successfully and that the validation split was performed correctly. For example, you could check that the model produces reasonable predictions or verify internal state that confirms the train/validation split occurred.

Copilot uses AI. Check for mistakes.


if __name__ == "__main__":
unittest.main()
Loading