Skip to content

Commit ab100e1

Browse files
committed
Update model aliases
1 parent 731d3cd commit ab100e1

File tree

1 file changed

+9
-10
lines changed

1 file changed

+9
-10
lines changed

examples/mlforecast/evaluate_model.py

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -79,24 +79,24 @@ class MLForecastModel:
7979

8080
def __init__(
8181
self,
82-
regressor: Literal["gbm", "cat"] = "gbm",
82+
regressor: Literal["lightgbm", "catboost"] = "lightgbm",
8383
lags: list[int] | None = None,
8484
date_features: list | None = None,
8585
differences: list[int] | None = None,
8686
fit_time_limit: float | None = 600,
8787
model_kwargs: dict | None = None,
8888
):
89-
self.regressor = regressor.lower()
89+
self.regressor = regressor
9090
self.lags = lags
9191
self.date_features = date_features
9292
self.differences = differences
9393
self.fit_time_limit = fit_time_limit
9494
self.model_kwargs = model_kwargs or {}
9595

9696
def _create_model(self):
97-
if self.regressor == "gbm":
97+
if self.regressor == "lightgbm":
9898
return _create_lgbm(self.fit_time_limit, **self.model_kwargs)
99-
if self.regressor == "cat":
99+
if self.regressor == "catboost":
100100
return _create_catboost(self.fit_time_limit, **self.model_kwargs)
101101
raise ValueError(f"Unknown regressor: {self.regressor}")
102102

@@ -254,7 +254,7 @@ class MLForecastAutoModel(MLForecastModel):
254254

255255
def __init__(
256256
self,
257-
regressor: Literal["gbm", "cat"] = "gbm",
257+
regressor: Literal["lightgbm", "catboost"] = "lightgbm",
258258
num_samples: int = 20,
259259
n_windows: int = 3,
260260
hpo_time_limit: float | None = 1800,
@@ -403,19 +403,18 @@ def fit_predict(self, task: fev.Task) -> tuple[list[datasets.DatasetDict], float
403403
if __name__ == "__main__":
404404
# Configuration
405405
use_auto = True # Set to False for fixed preprocessing
406-
regressor = "gbm" # "gbm" or "cat"
406+
model_name = "lightgbm" # "lightgbm" or "catboost"
407407
num_tasks = None # Set to small number for testing, None for full benchmark
408408

409409
benchmark = fev.Benchmark.from_yaml(
410410
"https://raw.githubusercontent.com/autogluon/fev/refs/heads/main/benchmarks/fev_bench/tasks.yaml"
411411
)
412412

413413
if use_auto:
414-
model = MLForecastAutoModel(regressor=regressor)
415-
model_name = f"mlforecast-{regressor}-auto"
414+
model = MLForecastAutoModel(regressor=model_name)
415+
model_name = f"auto{model_name}"
416416
else:
417-
model = MLForecastModel(regressor=regressor)
418-
model_name = f"mlforecast-{regressor}"
417+
model = MLForecastModel(regressor=model_name)
419418

420419
summaries = []
421420
for task in tqdm(benchmark.tasks[:num_tasks]):

0 commit comments

Comments
 (0)