Skip to content
Merged
6 changes: 6 additions & 0 deletions src/chronos/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,7 @@ def predict_df(
prediction_length: int | None = None,
quantile_levels: list[float] = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9],
validate_inputs: bool = True,
freq: str | None = None,
**predict_kwargs,
) -> "pd.DataFrame":
"""
Expand All @@ -166,6 +167,10 @@ def predict_df(
validate_inputs
When True, the dataframe(s) will be validated before prediction, ensuring that timestamps have a
regular frequency, and item IDs match between past and future data. Setting to False disables these checks.
freq
Frequency string for timestamp generation (e.g., "h", "D", "W"). If provided, this frequency is used
instead of inferring it from the data. This is useful when you already know the frequency and want to
skip the inference overhead.
**predict_kwargs
Additional arguments passed to predict_quantiles

Expand Down Expand Up @@ -200,6 +205,7 @@ def predict_df(
timestamp_column=timestamp_column,
target_columns=[target],
prediction_length=prediction_length,
freq=freq,
validate_inputs=validate_inputs,
)

Expand Down
7 changes: 7 additions & 0 deletions src/chronos/chronos2/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -825,6 +825,7 @@ def predict_df(
context_length: int | None = None,
cross_learning: bool = False,
validate_inputs: bool = True,
freq: str | None = None,
**predict_kwargs,
) -> "pd.DataFrame":
"""
Expand Down Expand Up @@ -866,6 +867,11 @@ def predict_df(
validate_inputs
When True, the dataframe(s) will be validated before prediction, ensuring that timestamps have a
regular frequency, and item IDs match between past and future data. Setting to False disables these checks.
freq
Frequency string for timestamp generation (e.g., "h", "D", "W"). If provided, this frequency is used
instead of inferring it from the data. This is useful when you already know the frequency and want to
skip the inference overhead. Only used when future_df is not provided, since timestamps are extracted
from future_df when it's available.
**predict_kwargs
Additional arguments passed to predict_quantiles

Expand Down Expand Up @@ -896,6 +902,7 @@ def predict_df(
timestamp_column=timestamp_column,
target_columns=target,
prediction_length=prediction_length,
freq=freq,
validate_inputs=validate_inputs,
)

Expand Down
68 changes: 38 additions & 30 deletions src/chronos/df_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,7 @@ def convert_df_input_to_list_of_dicts_input(
prediction_length: int,
id_column: str = "item_id",
timestamp_column: str = "timestamp",
freq: str | None = None,
validate_inputs: bool = True,
) -> tuple[list[dict[str, np.ndarray | dict[str, np.ndarray]]], np.ndarray, dict[str, "pd.DatetimeIndex"]]:
"""
Expand All @@ -229,8 +230,12 @@ def convert_df_input_to_list_of_dicts_input(
Name of column containing time series identifiers
timestamp_column
Name of column containing timestamps
freq
Frequency string for timestamp generation. If provided, this frequency is used
instead of inferring it from the data. Only used when future_df is not provided,
since timestamps are extracted from future_df when it's available.
validate_inputs
When True, the dataframe(s) will be validated be conversion
When True, the dataframe(s) will be validated before conversion

Returns
-------
Expand All @@ -243,71 +248,74 @@ def convert_df_input_to_list_of_dicts_input(
import pandas as pd

if validate_inputs:
df, future_df, freq, series_lengths, original_order = validate_df_inputs(
df, future_df, inferred_freq, series_lengths, original_order = validate_df_inputs(
df,
future_df=future_df,
id_column=id_column,
timestamp_column=timestamp_column,
target_columns=target_columns,
prediction_length=prediction_length,
)
# Use provided freq if available, otherwise use inferred freq
if freq is None:
freq = inferred_freq
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If freq is not None, should we verify that inferred_freq matches it.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As I understand, the whole point is to skip running the frequency inference when freq is provided. Should we still do it?

else:
# Get the original order of time series IDs
original_order = df[id_column].unique()

# Get series lengths
series_lengths = df[id_column].value_counts(sort=False).to_list()

# If validation is skipped, the first freq in the dataframe is used
timestamp_index = pd.DatetimeIndex(df[timestamp_column])
start_idx = 0
freq = None
for length in series_lengths:
if length < 3:
start_idx += length
continue
timestamps = timestamp_index[start_idx : start_idx + length]
freq = pd.infer_freq(timestamps)
break
# If freq is not provided, infer from the first series with >= 3 points
if freq is None:
timestamp_index = pd.DatetimeIndex(df[timestamp_column])
start_idx = 0
for length in series_lengths:
if length < 3:
start_idx += length
continue
timestamps = timestamp_index[start_idx : start_idx + length]
freq = pd.infer_freq(timestamps)
break

assert freq is not None, "validate is False, but could not infer frequency from the dataframe"
assert freq is not None, "validate_inputs is False, but could not infer frequency from the dataframe"

# Convert to list of dicts format
inputs: list[dict[str, np.ndarray | dict[str, np.ndarray]]] = []
prediction_timestamps: dict[str, pd.DatetimeIndex] = {}

indptr = np.concatenate([[0], np.cumsum(series_lengths)]).astype("int64")
target_array = df[target_columns].to_numpy().T # Shape: (n_targets, len(df))
last_ts = pd.DatetimeIndex(df[timestamp_column].iloc[indptr[1:] - 1]) # Shape: (n_series,)
offset = pd.tseries.frequencies.to_offset(freq)
with warnings.catch_warnings():
# Silence PerformanceWarning for non-vectorized offsets https://github.com/pandas-dev/pandas/blob/95624ca2e99b0/pandas/core/arrays/datetimes.py#L822
warnings.simplefilter("ignore", category=pd.errors.PerformanceWarning)
# Generate all prediction timestamps at once by stacking offsets into shape (n_series * prediction_length)
prediction_timestamps_array = pd.DatetimeIndex(
np.dstack([last_ts + step * offset for step in range(1, prediction_length + 1)]).ravel()
)

past_covariates_dict = {
col: df[col].to_numpy() for col in df.columns if col not in [id_column, timestamp_column] + target_columns
}
future_covariates_dict = {}

if future_df is not None:
# Use timestamps from future_df
prediction_timestamps_flat = pd.DatetimeIndex(future_df[timestamp_column])
for col in future_df.columns.drop([id_column, timestamp_column]):
future_covariates_dict[col] = future_df[col].to_numpy()
if validate_inputs:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not sure about removing this. This was nice check for incorrect slicing of future data. I am wondering if we should only allow freq with validate_inputs=False? I feel like freq in intended really for cases where you know what you are doing. For general use, automatic inference is probably the way to go.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Makes sense, I've updated the PR to only allow freq: str when validate_inputs=False. The logic should also be much simpler in this case

if (pd.DatetimeIndex(future_df[timestamp_column]) != pd.DatetimeIndex(prediction_timestamps_array)).any():
raise ValueError(
"future_df timestamps do not match the expected prediction timestamps. "
"You can disable this check by setting `validate_inputs=False`"
)
else:
# Generate timestamps from freq
assert freq is not None, "freq must be provided or inferred when future_df is not provided"
last_ts = pd.DatetimeIndex(df[timestamp_column].iloc[indptr[1:] - 1]) # Shape: (n_series,)
offset = pd.tseries.frequencies.to_offset(freq)
with warnings.catch_warnings():
# Silence PerformanceWarning for non-vectorized offsets https://github.com/pandas-dev/pandas/blob/95624ca2e99b0/pandas/core/arrays/datetimes.py#L822
warnings.simplefilter("ignore", category=pd.errors.PerformanceWarning)
# Generate all prediction timestamps at once by stacking offsets into shape (n_series * prediction_length)
prediction_timestamps_flat = pd.DatetimeIndex(
np.dstack([last_ts + step * offset for step in range(1, prediction_length + 1)]).ravel()
)

for i in range(len(series_lengths)):
start_idx, end_idx = indptr[i], indptr[i + 1]
future_start_idx, future_end_idx = i * prediction_length, (i + 1) * prediction_length

series_id = df[id_column].iloc[start_idx]
prediction_timestamps[series_id] = prediction_timestamps_array[future_start_idx:future_end_idx]
prediction_timestamps[series_id] = prediction_timestamps_flat[future_start_idx:future_end_idx]
task: dict[str, np.ndarray | dict[str, np.ndarray]] = {"target": target_array[:, start_idx:end_idx]}

if len(past_covariates_dict) > 0:
Expand Down
65 changes: 65 additions & 0 deletions test/test_df_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,3 +359,68 @@ def test_convert_df_preserves_all_values_with_random_inputs():
assert len(inputs) == n_series
assert list(original_order) == series_ids
assert len(prediction_timestamps) == n_series


# Tests for freq parameter


@pytest.mark.parametrize("validate_inputs", [True, False])
@pytest.mark.parametrize("use_future_df", [True, False])
def test_convert_df_with_provided_freq(validate_inputs, use_future_df):
    """An explicitly provided freq should work for every combination of validate_inputs and future_df."""
    series_ids = ["A", "B"]
    prediction_length = 5
    df = create_df(series_ids=series_ids, n_points=[10, 12], target_cols=["target"], covariates=["cov1"], freq="h")

    if use_future_df:
        starts = get_forecast_start_times(df, freq="h")
        future_df = create_future_df(
            forecast_start_times=starts,
            series_ids=series_ids,
            n_points=[prediction_length] * len(series_ids),
            covariates=["cov1"],
            freq="h",
        )
    else:
        future_df = None

    inputs, original_order, prediction_timestamps = convert_df_input_to_list_of_dicts_input(
        df=df,
        future_df=future_df,
        target_columns=["target"],
        prediction_length=prediction_length,
        freq="h",
        validate_inputs=validate_inputs,
    )

    # One task per input series, and a full forecast horizon of timestamps for each.
    assert len(inputs) == len(series_ids)
    assert len(prediction_timestamps) == len(series_ids)
    for sid in series_ids:
        assert len(prediction_timestamps[sid]) == prediction_length


def test_convert_df_with_future_df_uses_future_df_timestamps():
    """When future_df is supplied, its timestamps (not freq-generated ones) become the prediction index."""
    df = create_df(series_ids=["A", "B"], n_points=[10, 12], target_cols=["target"], covariates=["cov1"], freq="h")

    # future_df deliberately uses a 2h spacing, different from df's hourly spacing,
    # so any freq-based timestamp generation would disagree with it.
    starts = get_forecast_start_times(df, freq="2h")
    future_df = create_future_df(
        forecast_start_times=starts,
        series_ids=["A", "B"],
        n_points=[5, 5],
        covariates=["cov1"],
        freq="2h",
    )

    _, _, prediction_timestamps = convert_df_input_to_list_of_dicts_input(
        df=df,
        future_df=future_df,
        target_columns=["target"],
        prediction_length=5,
        validate_inputs=False,
    )

    # Per series, the returned index must match future_df's 2h-spaced timestamps exactly.
    ordered = future_df.sort_values(["item_id", "timestamp"])
    for sid in ["A", "B"]:
        expected = pd.DatetimeIndex(ordered.loc[ordered["item_id"] == sid, "timestamp"])
        pd.testing.assert_index_equal(prediction_timestamps[sid], expected)
Loading