Merged
Changes from 4 commits
37 changes: 36 additions & 1 deletion flaml/automl/automl.py
@@ -156,6 +156,10 @@ def custom_metric(
"pred_time": pred_time,
}
```
**Note:** When passing a custom metric function, pass the function itself
(e.g., `metric=custom_metric`), not the result of calling it
(e.g., `metric=custom_metric(...)`). FLAML will call your function
internally during the training process.
task: A string of the task type, e.g.,
'classification', 'regression', 'ts_forecast', 'rank',
'seq-classification', 'seq-regression', 'summarization',
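To make the note above concrete, here is a minimal sketch (illustrative only: the metric body and dataset are placeholders, and the signature follows the documented custom-metric interface whose tail is shown in the docstring hunk above) of the correct and incorrect ways to pass a custom metric:

```python
from flaml import AutoML
from sklearn.datasets import load_iris
from sklearn.metrics import log_loss


def custom_metric(X_val, y_val, estimator, labels, X_train, y_train, *args, **kwargs):
    # Return (loss to minimize, dict of metrics to log).
    y_pred = estimator.predict_proba(X_val)
    val_loss = log_loss(y_val, y_pred, labels=labels)
    return val_loss, {"val_loss": val_loss}


X, y = load_iris(return_X_y=True)
automl = AutoML()

# Correct: pass the function object; FLAML calls it during training.
automl.fit(X_train=X, y_train=y, task="classification", metric=custom_metric, time_budget=1)

# Incorrect: custom_metric(...) evaluates to a (float, dict) tuple, which this PR
# now rejects up front with a descriptive ValueError.
```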
@@ -370,6 +374,8 @@ def custom_metric(
settings["n_splits"] = settings.get("n_splits", N_SPLITS)
settings["auto_augment"] = settings.get("auto_augment", True)
settings["metric"] = settings.get("metric", "auto")
# Validate that custom metric is callable if not a string
self._validate_metric_parameter(settings["metric"], allow_auto=True)
settings["estimator_list"] = settings.get("estimator_list", "auto")
settings["log_file_name"] = settings.get("log_file_name", "")
settings["max_iter"] = settings.get("max_iter") # no budget by default
@@ -462,6 +468,28 @@ def __setstate__(self, state):
except Exception:
mi.mlflow_client = None

@staticmethod
def _validate_metric_parameter(metric, allow_auto=True):
"""Validate that the metric parameter is either a string or a callable function.

Args:
metric: The metric parameter to validate.
allow_auto: Whether to allow "auto" as a valid string value.

Raises:
ValueError: If metric is not a string or callable function.
"""
if allow_auto and metric == "auto":
return
if not isinstance(metric, str) and not callable(metric):
raise ValueError(
f"The 'metric' parameter must be either a string or a callable function, "
f"but got {type(metric).__name__}. "
f"If you defined a custom_metric function, make sure to pass the function itself "
f"(e.g., metric=custom_metric) and not the result of calling it "
f"(e.g., metric=custom_metric(...))."
)

def get_params(self, deep: bool = False) -> dict:
return self._settings.copy()

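As a quick illustration of the new validator's behavior (a sketch only; it exercises the private staticmethod added above directly, which ordinary users would never need to do):

```python
from flaml import AutoML

# Strings and callables pass validation silently.
AutoML._validate_metric_parameter("accuracy")
AutoML._validate_metric_parameter(lambda *args, **kwargs: (0.0, {}))
AutoML._validate_metric_parameter("auto", allow_auto=True)

# Anything else, e.g. the tuple produced by mistakenly calling custom_metric(...),
# raises a ValueError that explains the likely cause.
try:
    AutoML._validate_metric_parameter((0.5, {"val_loss": 0.5}))
except ValueError as err:
    print(err)
```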
@@ -1810,6 +1838,10 @@ def custom_metric(
"pred_time": pred_time,
}
```
**Note:** When passing a custom metric function, pass the function itself
(e.g., `metric=custom_metric`), not the result of calling it
(e.g., `metric=custom_metric(...)`). FLAML will call your function
internally during the training process.
task: A string of the task type, e.g.,
'classification', 'regression', 'ts_forecast_regression',
'ts_forecast_classification', 'rank', 'seq-classification',
@@ -2095,7 +2127,7 @@ def cv_score_agg_func(val_loss_folds, log_metrics_folds):
split_ratio = split_ratio or self._settings.get("split_ratio")
n_splits = n_splits or self._settings.get("n_splits")
auto_augment = self._settings.get("auto_augment") if auto_augment is None else auto_augment
metric = metric or self._settings.get("metric")
metric = self._settings.get("metric") if metric is None else metric
estimator_list = estimator_list or self._settings.get("estimator_list")
log_file_name = self._settings.get("log_file_name") if log_file_name is None else log_file_name
max_iter = self._settings.get("max_iter") if max_iter is None else max_iter
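The one-line change in the hunk above, replacing `metric = metric or self._settings.get("metric")` with an explicit `is None` check, matters for the new validation: with `or`, a falsy but invalid value such as `[]` or `0` would be silently swapped for the stored default and never reach the validator. A minimal, self-contained sketch of the difference:

```python
DEFAULT_METRIC = "auto"


def resolve_with_or(metric):
    # Old behavior: any falsy value (None, [], 0, "") falls back to the default.
    return metric or DEFAULT_METRIC


def resolve_with_is_none(metric):
    # New behavior: only an omitted metric (None) falls back to the default.
    return DEFAULT_METRIC if metric is None else metric


print(resolve_with_or([]))       # "auto"  -- the invalid [] is silently hidden
print(resolve_with_is_none([]))  # []      -- reaches validation and raises a clear error
```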
@@ -2334,6 +2366,9 @@ def cv_score_agg_func(val_loss_folds, log_metrics_folds):
and (self._min_sample_size * SAMPLE_MULTIPLY_FACTOR < self._state.data_size[0])
)

# Validate metric parameter before processing
self._validate_metric_parameter(metric, allow_auto=True)

metric = task.default_metric(metric)
self._state.metric = metric

26 changes: 26 additions & 0 deletions test/automl/test_multiclass.py
@@ -278,6 +278,32 @@ def test_custom_metric(self):
except ImportError:
pass

def test_invalid_custom_metric(self):
"""Test that proper error is raised when custom_metric is called instead of passed."""
from sklearn.datasets import load_iris

X_train, y_train = load_iris(return_X_y=True)

# Test with non-callable metric in __init__
with self.assertRaises(ValueError) as context:
automl = AutoML(metric=123) # passing an int instead of function
self.assertIn("must be either a string or a callable function", str(context.exception))
self.assertIn("but got int", str(context.exception))

# Test with non-callable metric in fit
automl = AutoML()
with self.assertRaises(ValueError) as context:
automl.fit(X_train=X_train, y_train=y_train, metric=[], task="classification", time_budget=1)
self.assertIn("must be either a string or a callable function", str(context.exception))
self.assertIn("but got list", str(context.exception))

# Test with tuple (simulating result of calling a function that returns tuple)
with self.assertRaises(ValueError) as context:
automl = AutoML()
automl.fit(X_train=X_train, y_train=y_train, metric=(0.5, {"loss": 0.5}), task="classification", time_budget=1)
self.assertIn("must be either a string or a callable function", str(context.exception))
self.assertIn("but got tuple", str(context.exception))

def test_classification(self, as_frame=False):
automl_experiment = AutoML()
automl_settings = {