|
| 1 | +# Licensed to the Apache Software Foundation (ASF) under one |
| 2 | +# or more contributor license agreements. See the NOTICE file |
| 3 | +# distributed with this work for additional information |
| 4 | +# regarding copyright ownership. The ASF licenses this file |
| 5 | +# to you under the Apache License, Version 2.0 (the |
| 6 | +# "License"); you may not use this file except in compliance |
| 7 | +# with the License. You may obtain a copy of the License at |
| 8 | +# |
| 9 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 10 | +# |
| 11 | +# Unless required by applicable law or agreed to in writing, |
| 12 | +# software distributed under the License is distributed on an |
| 13 | +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| 14 | +# KIND, either express or implied. See the License for the |
| 15 | +# specific language governing permissions and limitations |
| 16 | +# under the License. |
| 17 | +# |
| 18 | + |
| 19 | +from enum import Enum |
| 20 | +from pathlib import Path |
| 21 | +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union |
| 22 | + |
| 23 | +import torch |
| 24 | + |
| 25 | +if TYPE_CHECKING: |
| 26 | + import pandas as pd |
| 27 | + from transformers import PreTrainedModel |
| 28 | + |
| 29 | + |
| 30 | +from iotdb.ainode.core.model.chronos2.utils import left_pad_and_stack_1D |
| 31 | + |
| 32 | + |
class ForecastType(Enum):
    """Kinds of forecast output a pipeline can declare via ``forecast_type``."""

    # Member values are the lowercase string form of each member name.
    SAMPLES = "samples"
    QUANTILES = "quantiles"
| 36 | + |
| 37 | + |
class PipelineRegistry(type):
    """Metaclass that records every class it creates in ``REGISTRY``, keyed by class name."""

    # Shared mapping from class name to the class object built by this metaclass.
    # A later class with the same name replaces the earlier entry.
    REGISTRY: Dict[str, "PipelineRegistry"] = {}

    def __new__(cls, name, bases, attrs):
        """See, https://github.com/faif/python-patterns."""
        created = super().__new__(cls, name, bases, attrs)
        if name is None:
            return created

        cls.REGISTRY[name] = created
        return created
| 48 | + |
| 49 | + |
class BaseChronosPipeline(metaclass=PipelineRegistry):
    """
    Base class for Chronos forecasting pipelines.

    Because the metaclass is ``PipelineRegistry``, this class and every subclass
    is recorded in ``PipelineRegistry.REGISTRY`` under its class name; that
    registry is what ``from_pretrained`` uses to find the concrete pipeline named
    in a model config. ``predict``, ``predict_quantiles`` and the two
    ``model_*_length`` properties raise ``NotImplementedError`` here.
    """

    # Kind of forecast produced by ``predict``; only annotated here, no value is assigned.
    forecast_type: ForecastType
    # String aliases accepted for the ``torch_dtype`` keyword of ``from_pretrained``.
    dtypes = {"bfloat16": torch.bfloat16, "float32": torch.float32}

    def __init__(self, inner_model: "PreTrainedModel"):
        """
        Parameters
        ----------
        inner_model : PreTrainedModel
            A hugging-face transformers PreTrainedModel, e.g., T5ForConditionalGeneration
        """
        # for easy access to the inner HF-style model
        self.inner_model = inner_model

    @property
    def model_context_length(self) -> int:
        """Context length of the model; not implemented in the base class."""
        raise NotImplementedError()

    @property
    def model_prediction_length(self) -> int:
        """Default prediction length of the model; not implemented in the base class."""
        raise NotImplementedError()

    def _prepare_and_validate_context(
        self, context: Union[torch.Tensor, List[torch.Tensor]]
    ):
        """
        Normalize ``context`` into a single 2D tensor.

        Parameters
        ----------
        context
            A list of tensors (combined with ``left_pad_and_stack_1D``), a 1D
            tensor (a leading dimension of size 1 is added), or a 2D tensor
            (returned unchanged).

        Returns
        -------
        torch.Tensor
            A 2D tensor.

        Raises
        ------
        AssertionError
            If the value is not a tensor after list handling, or does not end
            up with exactly two dimensions.
        """
        if isinstance(context, list):
            context = left_pad_and_stack_1D(context)
        # NOTE: the checks stay as ``assert`` so callers keep seeing AssertionError;
        # be aware that they are skipped when Python runs with ``-O``.
        assert isinstance(
            context, torch.Tensor
        ), f"Expected `context` to be a torch.Tensor or a list of tensors, but found {type(context)}."
        if context.ndim == 1:
            context = context.unsqueeze(0)
        assert (
            context.ndim == 2
        ), f"Expected `context` to have 1 or 2 dimensions, but found {context.ndim}."

        return context

    def predict(
        self,
        inputs: Union[torch.Tensor, List[torch.Tensor]],
        prediction_length: Optional[int] = None,
    ):
        """
        Get forecasts for the given time series. Predictions will be
        returned in fp32 on the cpu.

        Parameters
        ----------
        inputs
            Input series. This is either a 1D tensor, or a list
            of 1D tensors, or a 2D tensor whose first dimension
            is batch. In the latter case, use left-padding with
            ``torch.nan`` to align series of different lengths.
        prediction_length
            Time steps to predict. Defaults to a model-dependent
            value if not given.

        Returns
        -------
        forecasts
            Tensor containing forecasts. The layout and meaning
            of the forecasts values depends on ``self.forecast_type``.

        Raises
        ------
        NotImplementedError
            Always, in this base class.
        """
        raise NotImplementedError()

    def predict_quantiles(
        self,
        inputs: Union[torch.Tensor, List[torch.Tensor]],
        prediction_length: Optional[int] = None,
        quantile_levels: List[float] = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9],
        **kwargs,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Get quantile and mean forecasts for given time series.
        Predictions will be returned in fp32 on the cpu.

        Parameters
        ----------
        inputs : Union[torch.Tensor, List[torch.Tensor]]
            Input series. This is either a 1D tensor, or a list
            of 1D tensors, or a 2D tensor whose first dimension
            is batch. In the latter case, use left-padding with
            ``torch.nan`` to align series of different lengths.
        prediction_length : Optional[int], optional
            Time steps to predict. Defaults to a model-dependent
            value if not given.
        quantile_levels : List[float], optional
            Quantile levels to compute, by default [0.1, 0.2, ..., 0.9].
            The default list object is shared between calls, so
            implementations must not mutate it.

        Returns
        -------
        quantiles
            Tensor containing quantile forecasts. Shape
            (batch_size, prediction_length, num_quantiles)
        mean
            Tensor containing mean (point) forecasts. Shape
            (batch_size, prediction_length)

        Raises
        ------
        NotImplementedError
            Always, in this base class.
        """
        raise NotImplementedError()

    def predict_df(
        self,
        df: "pd.DataFrame",
        *,
        id_column: str = "item_id",
        timestamp_column: str = "timestamp",
        target: str = "target",
        prediction_length: Optional[int] = None,
        quantile_levels: List[float] = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9],
        validate_inputs: bool = True,
        **predict_kwargs,
    ) -> "pd.DataFrame":
        """
        Perform forecasting on time series data in a long-format pandas DataFrame.

        Parameters
        ----------
        df
            Time series data in long format with an id column, a timestamp, and one target column.
            Any other columns, if present, will be ignored
        id_column
            The name of the column which contains the unique time series identifiers, by default "item_id"
        timestamp_column
            The name of the column which contains timestamps, by default "timestamp"
            All time series in the dataframe must have regular timestamps with the same frequency (no gaps)
        target
            The name of the column which contains the target variables to be forecasted, by default "target"
        prediction_length
            Number of steps to predict for each time series. When None,
            ``self.model_prediction_length`` is used.
        quantile_levels
            Quantile levels to compute
        validate_inputs
            When True, the dataframe(s) will be validated before prediction, ensuring that timestamps have a
            regular frequency, and item IDs match between past and future data. Setting to False disables these checks.
        **predict_kwargs
            Additional arguments passed to predict_quantiles. ``limit_prediction_length`` is already
            passed explicitly by this method, so supplying it here raises ``TypeError``.

        Returns
        -------
        The forecasts dataframe generated by the model with the following columns
            - `id_column`: The time series ID
            - `timestamp_column`: Future timestamps
            - "target_name": The name of the target column
            - "predictions": The point predictions generated by the model
            - One column for predictions at each quantile level in `quantile_levels`

        Raises
        ------
        ImportError
            If pandas (or the sibling ``df_utils`` module) cannot be imported.
        ValueError
            If ``target`` is not a string.
        """
        try:
            import pandas as pd
        except ImportError as err:
            raise ImportError(
                "pandas is required for predict_df. Please install it with `pip install pandas`."
            ) from err

        # Imported outside the ``try`` above so that a failure in this module is
        # reported with its own message rather than as "pandas is required".
        from .df_utils import convert_df_input_to_list_of_dicts_input

        if not isinstance(target, str):
            raise ValueError(
                f"Expected `target` to be str, but found {type(target)}. {self.__class__.__name__} only supports univariate forecasting."
            )

        if prediction_length is None:
            prediction_length = self.model_prediction_length

        inputs, original_order, prediction_timestamps = (
            convert_df_input_to_list_of_dicts_input(
                df=df,
                future_df=None,
                id_column=id_column,
                timestamp_column=timestamp_column,
                target_columns=[target],
                prediction_length=prediction_length,
                validate_inputs=validate_inputs,
            )
        )

        # NOTE: any covariates, if present, are ignored here
        context = [
            torch.tensor(item["target"]).squeeze(0) for item in inputs
        ]  # squeeze the extra variate dim

        # Generate forecasts
        quantiles, mean = self.predict_quantiles(
            inputs=context,
            prediction_length=prediction_length,
            quantile_levels=quantile_levels,
            limit_prediction_length=False,
            **predict_kwargs,
        )

        quantiles_np = quantiles.numpy()  # [n_series, horizon, num_quantiles]
        mean_np = mean.numpy()  # [n_series, horizon]

        results_dfs = []
        # Row ``i`` of the forecasts is paired with the ``i``-th entry of
        # ``prediction_timestamps``; this assumes the helper returns both in the
        # same series order -- TODO confirm against df_utils.
        for i, (series_id, future_ts) in enumerate(prediction_timestamps.items()):
            q_pred = quantiles_np[i]  # (horizon, num_quantiles)
            point_pred = mean_np[i]  # (horizon)

            series_forecast_data = {
                id_column: series_id,
                timestamp_column: future_ts,
                "target_name": target,
            }
            series_forecast_data["predictions"] = point_pred
            # One column per quantile level, named by the level's string form.
            for q_idx, q_level in enumerate(quantile_levels):
                series_forecast_data[str(q_level)] = q_pred[:, q_idx]

            results_dfs.append(pd.DataFrame(series_forecast_data))

        # Re-order the stacked per-series frames by ``original_order`` and turn
        # the id back into a regular column.
        predictions_df = pd.concat(results_dfs, ignore_index=True)
        predictions_df.set_index(id_column, inplace=True)
        predictions_df = predictions_df.loc[original_order]
        predictions_df.reset_index(inplace=True)

        return predictions_df

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: Union[str, Path],
        *model_args,
        force_s3_download=False,
        **kwargs,
    ):
        """
        Load a pipeline by dispatching to the pipeline class named in the model config.

        The config is read with ``AutoConfig.from_pretrained``; its
        ``chronos_pipeline_class`` attribute (default ``"ChronosPipeline"``) is
        looked up in ``PipelineRegistry.REGISTRY`` and loading is delegated to
        that class's own ``from_pretrained``.

        Parameters
        ----------
        pretrained_model_name_or_path
            Model identifier or path, passed unchanged to ``AutoConfig.from_pretrained``
            and to the resolved pipeline class.
        *model_args
            Positional arguments forwarded to the resolved class's ``from_pretrained``.
        force_s3_download
            Accepted for signature compatibility. This method does not read it,
            does not forward it, and performs no S3 handling itself.
        **kwargs
            Keyword arguments forwarded to ``AutoConfig.from_pretrained`` and to the
            resolved class. A string ``torch_dtype`` other than ``"auto"`` is first
            replaced by the matching entry of ``cls.dtypes``.

        Returns
        -------
        Whatever the resolved pipeline class's ``from_pretrained`` returns.

        Raises
        ------
        KeyError
            If ``torch_dtype`` is a string that is neither ``"auto"`` nor a key of ``cls.dtypes``.
        ValueError
            If the config has neither ``chronos_pipeline_class`` nor ``chronos_config``,
            or if the named pipeline class is not registered.
        """

        from transformers import AutoConfig

        torch_dtype = kwargs.get("torch_dtype", "auto")
        if torch_dtype != "auto" and isinstance(torch_dtype, str):
            kwargs["torch_dtype"] = cls.dtypes[torch_dtype]

        config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
        is_valid_config = hasattr(config, "chronos_pipeline_class") or hasattr(
            config, "chronos_config"
        )

        if not is_valid_config:
            raise ValueError("Not a Chronos config file")

        pipeline_class_name = getattr(
            config, "chronos_pipeline_class", "ChronosPipeline"
        )
        class_ = PipelineRegistry.REGISTRY.get(pipeline_class_name)
        if class_ is None:
            raise ValueError(
                f"Trying to load unknown pipeline class: {pipeline_class_name}"
            )

        # NOTE(review): the resolved class is expected to override ``from_pretrained``;
        # if it inherits this implementation the call below would recurse.
        return class_.from_pretrained(  # type: ignore[attr-defined]
            pretrained_model_name_or_path, *model_args, **kwargs
        )
0 commit comments