Skip to content

Commit 8f0f223

Browse files
committed
Address PR comments
1 parent 635b7cc commit 8f0f223

File tree

1 file changed

+6
-0
lines changed

1 file changed

+6
-0
lines changed

examples/mlforecast/evaluate_model.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,9 @@ def _get_lags(self, freq: str, median_series_len: int, seasonality: int = 1) ->
106106

107107
from autogluon.timeseries.utils.datetime import get_lags_for_frequency
108108

109+
# Limit max lag so that we have enough training samples even for short series.
110+
# After differencing, the effective length decreases, and we need some rows left
111+
# for training features; hence we reserve at least 10 rows for feature construction.
109112
diff_cost = max(self.differences) if self.differences else seasonality
110113
effective_len = median_series_len - diff_cost
111114
max_lag = min(effective_len - 1, max(1, effective_len - 10))
@@ -317,6 +320,9 @@ def _run_hpo(
317320

318321
optuna.logging.set_verbosity(optuna.logging.ERROR)
319322

323+
# MLForecast doesn't allow passing kwargs to model.fit(), so we use custom model wrappers
324+
# to inject time limit callbacks and specify categorical features. We also construct custom
325+
# search spaces since the default ones in MLForecast can lead to catastrophically bad performance.
320326
forecaster = AutoMLForecast(
321327
models={self.regressor: AutoModel(model=self._create_model(), config=lambda t: {})},
322328
freq=task.freq,

0 commit comments

Comments
 (0)