Skip to content

Commit 1a2498f

Browse files
authored
Backports for v2.0.1 and version bump (#369)
*Issue #, if available:* *Description of changes:* By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice.
1 parent 7a8427d commit 1a2498f

File tree

3 files changed: +39 −25 lines

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ license = { file = "LICENSE" }
1515
requires-python = ">=3.10"
1616
dependencies = [
1717
"torch>=2.0,<3",
18-
"transformers>=4.49,<5",
18+
"transformers>=4.41,<5",
1919
"accelerate>=0.34,<2",
2020
"numpy>=1.21,<3",
2121
"einops>=0.7.0,<1",

src/chronos/__about__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = "2.0.0"
1+
__version__ = "2.0.1"

src/chronos/chronos_bolt.py

Lines changed: 37 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -407,6 +407,8 @@ class ChronosBoltPipeline(BaseChronosPipeline):
407407
def __init__(self, model: ChronosBoltModelForForecasting):
408408
super().__init__(inner_model=model) # type: ignore
409409
self.model = model
410+
self.model_context_length: int = self.model.config.chronos_config["context_length"]
411+
self.model_prediction_length: int = self.model.config.chronos_config["prediction_length"]
410412

411413
@property
412414
def quantiles(self) -> List[float]:
@@ -487,14 +489,12 @@ def predict(
487489
"""
488490
context_tensor = self._prepare_and_validate_context(context=inputs)
489491

490-
model_context_length: int = self.model.config.chronos_config["context_length"]
491-
model_prediction_length: int = self.model.config.chronos_config["prediction_length"]
492492
if prediction_length is None:
493-
prediction_length = model_prediction_length
493+
prediction_length = self.model_prediction_length
494494

495-
if prediction_length > model_prediction_length:
495+
if prediction_length > self.model_prediction_length:
496496
msg = (
497-
f"We recommend keeping prediction length <= {model_prediction_length}. "
497+
f"We recommend keeping prediction length <= {self.model_prediction_length}. "
498498
"The quality of longer predictions may degrade since the model is not optimized for it. "
499499
)
500500
if limit_prediction_length:
@@ -507,33 +507,47 @@ def predict(
507507

508508
# We truncate the context here because otherwise batches with very long
509509
# context could take up large amounts of GPU memory unnecessarily.
510-
if context_tensor.shape[-1] > model_context_length:
511-
context_tensor = context_tensor[..., -model_context_length:]
510+
if context_tensor.shape[-1] > self.model_context_length:
511+
context_tensor = context_tensor[..., -self.model_context_length :]
512512

513-
# TODO: We unroll the forecast of Chronos Bolt greedily with the full forecast
514-
# horizon that the model was trained with (i.e., 64). This results in variance collapsing
515-
# every 64 steps.
516-
context_tensor = context_tensor.to(
517-
device=self.model.device,
518-
dtype=torch.float32,
519-
)
513+
context_tensor = context_tensor.to(device=self.model.device, dtype=torch.float32)
514+
# First block prediction
515+
with torch.no_grad():
516+
prediction: torch.Tensor = self.model(context=context_tensor).quantile_preds.to(context_tensor)
517+
518+
predictions.append(prediction)
519+
remaining -= prediction.shape[-1]
520+
521+
# NOTE: The following heuristic for better prediction intervals with long-horizon forecasts
522+
# uses all quantiles generated by the model for the first `model_prediction_length` steps,
523+
# concatenating each quantile with the context and generating the next `model_prediction_length` steps.
524+
# The `num_quantiles * num_quantiles` "samples" thus generated are then reduced to `num_quantiles`
525+
# by computing empirical quantiles. Note that this option scales the batch size by `num_quantiles`
526+
# when the `prediction_length` is greater than `model_prediction_length`.
527+
528+
if remaining > 0:
529+
# Expand the context along quantile axis
530+
context_tensor = context_tensor.unsqueeze(1).repeat(1, len(self.quantiles), 1)
531+
532+
quantile_tensor = torch.tensor(self.quantiles, device=context_tensor.device)
520533
while remaining > 0:
534+
# Append the prediction to context
535+
context_tensor = torch.cat([context_tensor, prediction], dim=-1)[..., -self.model_context_length :]
536+
(batch_size, n_quantiles, context_length) = context_tensor.shape
537+
521538
with torch.no_grad():
539+
# Reshape (batch, n_quantiles, context_length) -> (batch * n_quantiles, context_length)
522540
prediction = self.model(
523-
context=context_tensor,
541+
context=context_tensor.reshape(batch_size * n_quantiles, context_length)
524542
).quantile_preds.to(context_tensor)
543+
# Reshape predictions from (batch * n_quantiles, n_quantiles, model_prediction_length) to (batch, n_quantiles * n_quantiles, model_prediction_length)
544+
prediction = prediction.reshape(batch_size, n_quantiles * n_quantiles, -1)
545+
# Reduce `n_quantiles * n_quantiles` to n_quantiles and transpose back to (batch_size, n_quantiles, model_prediction_length)
546+
prediction = torch.quantile(prediction, q=quantile_tensor, dim=1).transpose(0, 1)
525547

526548
predictions.append(prediction)
527549
remaining -= prediction.shape[-1]
528550

529-
if remaining <= 0:
530-
break
531-
532-
central_idx = torch.abs(torch.tensor(self.quantiles) - 0.5).argmin()
533-
central_prediction = prediction[:, central_idx]
534-
535-
context_tensor = torch.cat([context_tensor, central_prediction], dim=-1)
536-
537551
return torch.cat(predictions, dim=-1)[..., :prediction_length].to(dtype=torch.float32, device="cpu")
538552

539553
def predict_quantiles(

0 commit comments

Comments (0)