Skip to content

Commit 1f099eb

Browse files
authored
Allow explicitly passing the frequency to predict_df (#449)
*Issue #, if available:* #425 *Description of changes:* - Add `freq: str | None` parameter to `predict_df` methods. This can only be set in combination with `validate_inputs=False`. If specified, the user-provided `freq` will be used instead of trying to infer the `freq` from the data. By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice.
1 parent f889ae6 commit 1f099eb

File tree

4 files changed

+214
-79
lines changed

4 files changed

+214
-79
lines changed

src/chronos/base.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,7 @@ def predict_df(
142142
prediction_length: int | None = None,
143143
quantile_levels: list[float] = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9],
144144
validate_inputs: bool = True,
145+
freq: str | None = None,
145146
**predict_kwargs,
146147
) -> "pd.DataFrame":
147148
"""
@@ -164,8 +165,14 @@ def predict_df(
164165
quantile_levels
165166
Quantile levels to compute
166167
validate_inputs
167-
When True, the dataframe(s) will be validated before prediction, ensuring that timestamps have a
168-
regular frequency, and item IDs match between past and future data. Setting to False disables these checks.
168+
[ADVANCED] When True (default), validates dataframes before prediction. Setting to False removes the
169+
validation overhead, but may silently lead to wrong predictions if data is misformatted. When False, you
170+
must ensure: (1) all dataframes are sorted by (id_column, timestamp_column); (2) future_df (if provided)
171+
has the same item IDs as df with exactly prediction_length rows of future timestamps per item; (3) all
172+
timestamps are regularly spaced (e.g., with hourly frequency).
173+
freq
174+
Frequency string for timestamp generation (e.g., "h", "D", "W"). Can only be used when
175+
validate_inputs=False. When provided, skips frequency inference from the data.
169176
**predict_kwargs
170177
Additional arguments passed to predict_quantiles
171178
@@ -200,6 +207,7 @@ def predict_df(
200207
timestamp_column=timestamp_column,
201208
target_columns=[target],
202209
prediction_length=prediction_length,
210+
freq=freq,
203211
validate_inputs=validate_inputs,
204212
)
205213

src/chronos/chronos2/pipeline.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -825,6 +825,7 @@ def predict_df(
825825
context_length: int | None = None,
826826
cross_learning: bool = False,
827827
validate_inputs: bool = True,
828+
freq: str | None = None,
828829
**predict_kwargs,
829830
) -> "pd.DataFrame":
830831
"""
@@ -864,8 +865,14 @@ def predict_df(
864865
For optimal results, consider using a batch size around 100 (as used in the Chronos-2 technical report).
865866
- Cross-learning is most helpful when individual time series have limited historical context, as the model can leverage patterns from related series in the batch.
866867
validate_inputs
867-
When True, the dataframe(s) will be validated before prediction, ensuring that timestamps have a
868-
regular frequency, and item IDs match between past and future data. Setting to False disables these checks.
868+
[ADVANCED] When True (default), validates dataframes before prediction. Setting to False removes the
869+
validation overhead, but may silently lead to wrong predictions if data is misformatted. When False, you
870+
must ensure: (1) all dataframes are sorted by (id_column, timestamp_column); (2) future_df (if provided)
871+
has the same item IDs as df with exactly prediction_length rows of future timestamps per item; (3) all
872+
timestamps are regularly spaced (e.g., with hourly frequency).
873+
freq
874+
Frequency string for timestamp generation (e.g., "h", "D", "W"). Can only be used when
875+
validate_inputs=False. When provided, skips frequency inference from the data.
869876
**predict_kwargs
870877
Additional arguments passed to predict_quantiles
871878
@@ -896,6 +903,7 @@ def predict_df(
896903
timestamp_column=timestamp_column,
897904
target_columns=target,
898905
prediction_length=prediction_length,
906+
freq=freq,
899907
validate_inputs=validate_inputs,
900908
)
901909

src/chronos/df_utils.py

Lines changed: 32 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -204,6 +204,7 @@ def convert_df_input_to_list_of_dicts_input(
204204
id_column: str = "item_id",
205205
timestamp_column: str = "timestamp",
206206
validate_inputs: bool = True,
207+
freq: str | None = None,
207208
) -> tuple[list[dict[str, np.ndarray | dict[str, np.ndarray]]], np.ndarray, dict[str, "pd.DatetimeIndex"]]:
208209
"""
209210
Convert from dataframe input format to a list of dictionaries input format.
@@ -230,7 +231,14 @@ def convert_df_input_to_list_of_dicts_input(
230231
timestamp_column
231232
Name of column containing timestamps
232233
validate_inputs
233-
When True, the dataframe(s) will be validated be conversion
234+
[ADVANCED] When True (default), validates dataframes before prediction. Setting to False removes the
235+
validation overhead, but may silently lead to wrong predictions if data is misformatted. When False, you
236+
must ensure: (1) all dataframes are sorted by (id_column, timestamp_column); (2) future_df (if provided)
237+
has the same item IDs as df with exactly prediction_length rows of future timestamps per item; (3) all
238+
timestamps are regularly spaced (e.g., with hourly frequency).
239+
freq
240+
Frequency string for timestamp generation (e.g., "h", "D", "W"). Can only be used
241+
when validate_inputs=False. When provided, skips frequency inference from the data.
234242
235243
Returns
236244
-------
@@ -242,6 +250,16 @@ def convert_df_input_to_list_of_dicts_input(
242250

243251
import pandas as pd
244252

253+
if freq is not None and validate_inputs:
254+
raise ValueError(
255+
"freq can only be provided when validate_inputs=False. "
256+
"When using freq with validate_inputs=False, you must ensure: "
257+
"(1) all dataframes are sorted by (id_column, timestamp_column); "
258+
"(2) future_df (if provided) has the same item IDs as df with exactly "
259+
"prediction_length rows of future timestamps per item; "
260+
"(3) all timestamps are regularly spaced."
261+
)
262+
245263
if validate_inputs:
246264
df, future_df, freq, series_lengths, original_order = validate_df_inputs(
247265
df,
@@ -258,19 +276,19 @@ def convert_df_input_to_list_of_dicts_input(
258276
# Get series lengths
259277
series_lengths = df[id_column].value_counts(sort=False).to_list()
260278

261-
# If validation is skipped, the first freq in the dataframe is used
262-
timestamp_index = pd.DatetimeIndex(df[timestamp_column])
263-
start_idx = 0
264-
freq = None
265-
for length in series_lengths:
266-
if length < 3:
267-
start_idx += length
268-
continue
269-
timestamps = timestamp_index[start_idx : start_idx + length]
270-
freq = pd.infer_freq(timestamps)
271-
break
272-
273-
assert freq is not None, "validate is False, but could not infer frequency from the dataframe"
279+
# If freq is not provided, infer from the first series with >= 3 points
280+
if freq is None:
281+
timestamp_index = pd.DatetimeIndex(df[timestamp_column])
282+
start_idx = 0
283+
for length in series_lengths:
284+
if length < 3:
285+
start_idx += length
286+
continue
287+
timestamps = timestamp_index[start_idx : start_idx + length]
288+
freq = pd.infer_freq(timestamps)
289+
break
290+
291+
assert freq is not None, "validate_inputs is False, but could not infer frequency from the dataframe"
274292

275293
# Convert to list of dicts format
276294
inputs: list[dict[str, np.ndarray | dict[str, np.ndarray]]] = []

0 commit comments

Comments
 (0)