@@ -407,6 +407,8 @@ class ChronosBoltPipeline(BaseChronosPipeline):
def __init__(self, model: ChronosBoltModelForForecasting):
    """Wrap a Chronos-Bolt forecasting model in a pipeline.

    Registers the inner model with the base pipeline and caches the
    context/prediction lengths from the model's config, so later calls
    (e.g. ``predict``) do not have to re-read the config dict each time.
    """
    super().__init__(inner_model=model)  # type: ignore
    self.model = model
    # Hoist the config dict lookup; both limits come from the same mapping.
    chronos_config = self.model.config.chronos_config
    self.model_context_length: int = chronos_config["context_length"]
    self.model_prediction_length: int = chronos_config["prediction_length"]
410412
411413 @property
412414 def quantiles (self ) -> List [float ]:
@@ -487,14 +489,12 @@ def predict(
487489 """
488490 context_tensor = self ._prepare_and_validate_context (context = inputs )
489491
490- model_context_length : int = self .model .config .chronos_config ["context_length" ]
491- model_prediction_length : int = self .model .config .chronos_config ["prediction_length" ]
492492 if prediction_length is None :
493- prediction_length = model_prediction_length
493+ prediction_length = self . model_prediction_length
494494
495- if prediction_length > model_prediction_length :
495+ if prediction_length > self . model_prediction_length :
496496 msg = (
497- f"We recommend keeping prediction length <= { model_prediction_length } . "
497+ f"We recommend keeping prediction length <= { self . model_prediction_length } . "
498498 "The quality of longer predictions may degrade since the model is not optimized for it. "
499499 )
500500 if limit_prediction_length :
@@ -507,33 +507,47 @@ def predict(
507507
508508 # We truncate the context here because otherwise batches with very long
509509 # context could take up large amounts of GPU memory unnecessarily.
510- if context_tensor .shape [- 1 ] > model_context_length :
511- context_tensor = context_tensor [..., - model_context_length :]
510+ if context_tensor .shape [- 1 ] > self . model_context_length :
511+ context_tensor = context_tensor [..., - self . model_context_length :]
512512
513- # TODO: We unroll the forecast of Chronos Bolt greedily with the full forecast
514- # horizon that the model was trained with (i.e., 64). This results in variance collapsing
515- # every 64 steps.
516- context_tensor = context_tensor .to (
517- device = self .model .device ,
518- dtype = torch .float32 ,
519- )
513+ context_tensor = context_tensor .to (device = self .model .device , dtype = torch .float32 )
514+ # First block prediction
515+ with torch .no_grad ():
516+ prediction : torch .Tensor = self .model (context = context_tensor ).quantile_preds .to (context_tensor )
517+
518+ predictions .append (prediction )
519+ remaining -= prediction .shape [- 1 ]
520+
521+ # NOTE: The following heuristic for better prediction intervals with long-horizon forecasts
522+ # uses all quantiles generated by the model for the first `model_prediction_length` steps,
523+ # concatenating each quantile with the context and generating the next `model_prediction_length` steps.
524+ # The `num_quantiles * num_quantiles` "samples" thus generated are then reduced to `num_quantiles`
525+ # by computing empirical quantiles. Note that this option scales the batch size by `num_quantiles`
526+ # when the `prediction_length` is greater than `model_prediction_length`.
527+
528+ if remaining > 0 :
529+ # Expand the context along quantile axis
530+ context_tensor = context_tensor .unsqueeze (1 ).repeat (1 , len (self .quantiles ), 1 )
531+
532+ quantile_tensor = torch .tensor (self .quantiles , device = context_tensor .device )
520533 while remaining > 0 :
534+ # Append the prediction to context
535+ context_tensor = torch .cat ([context_tensor , prediction ], dim = - 1 )[..., - self .model_context_length :]
536+ (batch_size , n_quantiles , context_length ) = context_tensor .shape
537+
521538 with torch .no_grad ():
539+ # Reshape (batch, n_quantiles, context_length) -> (batch * n_quantiles, context_length)
522540 prediction = self .model (
523- context = context_tensor ,
541+ context = context_tensor . reshape ( batch_size * n_quantiles , context_length )
524542 ).quantile_preds .to (context_tensor )
543+ # Reshape predictions from (batch * n_quantiles, n_quantiles, model_prediction_length) to (batch, n_quantiles * n_quantiles, model_prediction_length)
544+ prediction = prediction .reshape (batch_size , n_quantiles * n_quantiles , - 1 )
545+ # Reduce `n_quantiles * n_quantiles` to n_quantiles and transpose back to (batch_size, n_quantiles, model_prediction_length)
546+ prediction = torch .quantile (prediction , q = quantile_tensor , dim = 1 ).transpose (0 , 1 )
525547
526548 predictions .append (prediction )
527549 remaining -= prediction .shape [- 1 ]
528550
529- if remaining <= 0 :
530- break
531-
532- central_idx = torch .abs (torch .tensor (self .quantiles ) - 0.5 ).argmin ()
533- central_prediction = prediction [:, central_idx ]
534-
535- context_tensor = torch .cat ([context_tensor , central_prediction ], dim = - 1 )
536-
537551 return torch .cat (predictions , dim = - 1 )[..., :prediction_length ].to (dtype = torch .float32 , device = "cpu" )
538552
539553 def predict_quantiles (
0 commit comments