Skip to content

Commit 340a860

Browse files
authored
[AINode] Integrate Chronos2 as builtin forecasting model (#16903)
1 parent 029fbed commit 340a860

File tree

12 files changed

+3962
-1
lines changed

12 files changed

+3962
-1
lines changed

LICENSE

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -339,4 +339,14 @@ LMax Disruptor is open source software licensed under the Apache License 2.0 and
339339
Project page: https://github.com/LMAX-Exchange/disruptor
340340
License: https://github.com/LMAX-Exchange/disruptor/blob/master/LICENCE.txt
341341

342+
--------------------------------------------------------------------------------
343+
344+
The following files include code modified from chronos-forecasting project.
345+
346+
./iotdb-core/ainode/iotdb/ainode/core/model/chronos2/*
347+
348+
The chronos-forecasting is open source software licensed under the Apache License 2.0
349+
Project page: https://github.com/amazon-science/chronos-forecasting
350+
License: https://github.com/amazon-science/chronos-forecasting/blob/main/LICENSE
351+
342352
--------------------------------------------------------------------------------

integration-test/src/test/java/org/apache/iotdb/ainode/utils/AINodeTestUtils.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,9 @@ public class AINodeTestUtils {
4949
new AbstractMap.SimpleEntry<>(
5050
"sundial", new FakeModelInfo("sundial", "sundial", "builtin", "active")),
5151
new AbstractMap.SimpleEntry<>(
52-
"timer_xl", new FakeModelInfo("timer_xl", "timer", "builtin", "active")))
52+
"timer_xl", new FakeModelInfo("timer_xl", "timer", "builtin", "active")),
53+
new AbstractMap.SimpleEntry<>(
54+
"chronos2", new FakeModelInfo("chronos2", "t5", "builtin", "active")))
5355
.collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
5456

5557
public static final Map<String, FakeModelInfo> BUILTIN_MODEL_MAP;
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
#
Lines changed: 300 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,300 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
#
18+
19+
from enum import Enum
20+
from pathlib import Path
21+
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
22+
23+
import torch
24+
25+
if TYPE_CHECKING:
26+
import pandas as pd
27+
from transformers import PreTrainedModel
28+
29+
30+
from iotdb.ainode.core.model.chronos2.utils import left_pad_and_stack_1D
31+
32+
33+
class ForecastType(Enum):
34+
SAMPLES = "samples"
35+
QUANTILES = "quantiles"
36+
37+
38+
class PipelineRegistry(type):
39+
REGISTRY: Dict[str, "PipelineRegistry"] = {}
40+
41+
def __new__(cls, name, bases, attrs):
42+
"""See, https://github.com/faif/python-patterns."""
43+
new_cls = type.__new__(cls, name, bases, attrs)
44+
if name is not None:
45+
cls.REGISTRY[name] = new_cls
46+
47+
return new_cls
48+
49+
50+
class BaseChronosPipeline(metaclass=PipelineRegistry):
51+
forecast_type: ForecastType
52+
dtypes = {"bfloat16": torch.bfloat16, "float32": torch.float32}
53+
54+
def __init__(self, inner_model: "PreTrainedModel"):
55+
"""
56+
Parameters
57+
----------
58+
inner_model : PreTrainedModel
59+
A hugging-face transformers PreTrainedModel, e.g., T5ForConditionalGeneration
60+
"""
61+
# for easy access to the inner HF-style model
62+
self.inner_model = inner_model
63+
64+
@property
65+
def model_context_length(self) -> int:
66+
raise NotImplementedError()
67+
68+
@property
69+
def model_prediction_length(self) -> int:
70+
raise NotImplementedError()
71+
72+
def _prepare_and_validate_context(
73+
self, context: Union[torch.Tensor, List[torch.Tensor]]
74+
):
75+
if isinstance(context, list):
76+
context = left_pad_and_stack_1D(context)
77+
assert isinstance(context, torch.Tensor)
78+
if context.ndim == 1:
79+
context = context.unsqueeze(0)
80+
assert context.ndim == 2
81+
82+
return context
83+
84+
def predict(
85+
self,
86+
inputs: Union[torch.Tensor, List[torch.Tensor]],
87+
prediction_length: Optional[int] = None,
88+
):
89+
"""
90+
Get forecasts for the given time series. Predictions will be
91+
returned in fp32 on the cpu.
92+
93+
Parameters
94+
----------
95+
inputs
96+
Input series. This is either a 1D tensor, or a list
97+
of 1D tensors, or a 2D tensor whose first dimension
98+
is batch. In the latter case, use left-padding with
99+
``torch.nan`` to align series of different lengths.
100+
prediction_length
101+
Time steps to predict. Defaults to a model-dependent
102+
value if not given.
103+
104+
Returns
105+
-------
106+
forecasts
107+
Tensor containing forecasts. The layout and meaning
108+
of the forecasts values depends on ``self.forecast_type``.
109+
"""
110+
raise NotImplementedError()
111+
112+
def predict_quantiles(
113+
self,
114+
inputs: Union[torch.Tensor, List[torch.Tensor]],
115+
prediction_length: Optional[int] = None,
116+
quantile_levels: List[float] = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9],
117+
**kwargs,
118+
) -> Tuple[torch.Tensor, torch.Tensor]:
119+
"""
120+
Get quantile and mean forecasts for given time series.
121+
Predictions will be returned in fp32 on the cpu.
122+
123+
Parameters
124+
----------
125+
inputs : Union[torch.Tensor, List[torch.Tensor]]
126+
Input series. This is either a 1D tensor, or a list
127+
of 1D tensors, or a 2D tensor whose first dimension
128+
is batch. In the latter case, use left-padding with
129+
``torch.nan`` to align series of different lengths.
130+
prediction_length : Optional[int], optional
131+
Time steps to predict. Defaults to a model-dependent
132+
value if not given.
133+
quantile_levels : List[float], optional
134+
Quantile levels to compute, by default [0.1, 0.2, ..., 0.9]
135+
136+
Returns
137+
-------
138+
quantiles
139+
Tensor containing quantile forecasts. Shape
140+
(batch_size, prediction_length, num_quantiles)
141+
mean
142+
Tensor containing mean (point) forecasts. Shape
143+
(batch_size, prediction_length)
144+
"""
145+
raise NotImplementedError()
146+
147+
def predict_df(
148+
self,
149+
df: "pd.DataFrame",
150+
*,
151+
id_column: str = "item_id",
152+
timestamp_column: str = "timestamp",
153+
target: str = "target",
154+
prediction_length: int | None = None,
155+
quantile_levels: list[float] = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9],
156+
validate_inputs: bool = True,
157+
**predict_kwargs,
158+
) -> "pd.DataFrame":
159+
"""
160+
Perform forecasting on time series data in a long-format pandas DataFrame.
161+
162+
Parameters
163+
----------
164+
df
165+
Time series data in long format with an id column, a timestamp, and one target column.
166+
Any other columns, if present, will be ignored
167+
id_column
168+
The name of the column which contains the unique time series identifiers, by default "item_id"
169+
timestamp_column
170+
The name of the column which contains timestamps, by default "timestamp"
171+
All time series in the dataframe must have regular timestamps with the same frequency (no gaps)
172+
target
173+
The name of the column which contains the target variables to be forecasted, by default "target"
174+
prediction_length
175+
Number of steps to predict for each time series
176+
quantile_levels
177+
Quantile levels to compute
178+
validate_inputs
179+
When True, the dataframe(s) will be validated before prediction, ensuring that timestamps have a
180+
regular frequency, and item IDs match between past and future data. Setting to False disables these checks.
181+
**predict_kwargs
182+
Additional arguments passed to predict_quantiles
183+
184+
Returns
185+
-------
186+
The forecasts dataframe generated by the model with the following columns
187+
- `id_column`: The time series ID
188+
- `timestamp_column`: Future timestamps
189+
- "target_name": The name of the target column
190+
- "predictions": The point predictions generated by the model
191+
- One column for predictions at each quantile level in `quantile_levels`
192+
"""
193+
try:
194+
import pandas as pd
195+
196+
from .df_utils import convert_df_input_to_list_of_dicts_input
197+
except ImportError:
198+
raise ImportError(
199+
"pandas is required for predict_df. Please install it with `pip install pandas`."
200+
)
201+
202+
if not isinstance(target, str):
203+
raise ValueError(
204+
f"Expected `target` to be str, but found {type(target)}. {self.__class__.__name__} only supports univariate forecasting."
205+
)
206+
207+
if prediction_length is None:
208+
prediction_length = self.model_prediction_length
209+
210+
inputs, original_order, prediction_timestamps = (
211+
convert_df_input_to_list_of_dicts_input(
212+
df=df,
213+
future_df=None,
214+
id_column=id_column,
215+
timestamp_column=timestamp_column,
216+
target_columns=[target],
217+
prediction_length=prediction_length,
218+
validate_inputs=validate_inputs,
219+
)
220+
)
221+
222+
# NOTE: any covariates, if present, are ignored here
223+
context = [
224+
torch.tensor(item["target"]).squeeze(0) for item in inputs
225+
] # squeeze the extra variate dim
226+
227+
# Generate forecasts
228+
quantiles, mean = self.predict_quantiles(
229+
inputs=context,
230+
prediction_length=prediction_length,
231+
quantile_levels=quantile_levels,
232+
limit_prediction_length=False,
233+
**predict_kwargs,
234+
)
235+
236+
quantiles_np = quantiles.numpy() # [n_series, horizon, num_quantiles]
237+
mean_np = mean.numpy() # [n_series, horizon]
238+
239+
results_dfs = []
240+
for i, (series_id, future_ts) in enumerate(prediction_timestamps.items()):
241+
q_pred = quantiles_np[i] # (horizon, num_quantiles)
242+
point_pred = mean_np[i] # (horizon)
243+
244+
series_forecast_data = {
245+
id_column: series_id,
246+
timestamp_column: future_ts,
247+
"target_name": target,
248+
}
249+
series_forecast_data["predictions"] = point_pred
250+
for q_idx, q_level in enumerate(quantile_levels):
251+
series_forecast_data[str(q_level)] = q_pred[:, q_idx]
252+
253+
results_dfs.append(pd.DataFrame(series_forecast_data))
254+
255+
predictions_df = pd.concat(results_dfs, ignore_index=True)
256+
predictions_df.set_index(id_column, inplace=True)
257+
predictions_df = predictions_df.loc[original_order]
258+
predictions_df.reset_index(inplace=True)
259+
260+
return predictions_df
261+
262+
@classmethod
263+
def from_pretrained(
264+
cls,
265+
pretrained_model_name_or_path: Union[str, Path],
266+
*model_args,
267+
force_s3_download=False,
268+
**kwargs,
269+
):
270+
"""
271+
Load the model, either from a local path, S3 prefix, or from the HuggingFace Hub.
272+
Supports the same arguments as ``AutoConfig`` and ``AutoModel`` from ``transformers``.
273+
"""
274+
275+
from transformers import AutoConfig
276+
277+
torch_dtype = kwargs.get("torch_dtype", "auto")
278+
if torch_dtype != "auto" and isinstance(torch_dtype, str):
279+
kwargs["torch_dtype"] = cls.dtypes[torch_dtype]
280+
281+
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
282+
is_valid_config = hasattr(config, "chronos_pipeline_class") or hasattr(
283+
config, "chronos_config"
284+
)
285+
286+
if not is_valid_config:
287+
raise ValueError("Not a Chronos config file")
288+
289+
pipeline_class_name = getattr(
290+
config, "chronos_pipeline_class", "ChronosPipeline"
291+
)
292+
class_ = PipelineRegistry.REGISTRY.get(pipeline_class_name)
293+
if class_ is None:
294+
raise ValueError(
295+
f"Trying to load unknown pipeline class: {pipeline_class_name}"
296+
)
297+
298+
return class_.from_pretrained( # type: ignore[attr-defined]
299+
pretrained_model_name_or_path, *model_args, **kwargs
300+
)

0 commit comments

Comments
 (0)