@@ -3009,6 +3009,57 @@ export class T5ForConditionalGeneration extends T5PreTrainedModel { }
30093009//////////////////////////////////////////////////
30103010
30113011
3012+ //////////////////////////////////////////////////
3013+ // Chronos2 models
3014+ /**
3015+ * An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained models.
3016+ */
3017+ export class Chronos2PreTrainedModel extends PreTrainedModel {
3018+ forward_params = [
3019+ 'context' ,
3020+ 'group_ids' ,
3021+ 'attention_mask' ,
3022+ ] ;
3023+ } ;
3024+
3025+ /**
3026+ * The Chronos-2 Model for time series forecasting.
3027+ *
3028+ * Chronos-2 is a family of pretrained time series forecasting models based on T5.
3029+ * It uses a patching mechanism to convert time series into tokens and predicts
3030+ * multiple quantiles for probabilistic forecasting.
3031+ *
3032+ * **Example:** Load and run a Chronos-2 model for forecasting.
3033+ *
3034+ * ```javascript
3035+ * import { Chronos2ForForecasting } from '@huggingface/transformers';
3036+ *
3037+ * const model = await Chronos2ForForecasting.from_pretrained('amazon/chronos-2-small');
3038+ *
3039+ * // Prepare time series input
3040+ * const context = new Float32Array([1.0, 2.0, 3.0, 4.0, ...]); // Your historical data
3041+ * const inputs = {
3042+ * context: context,
3043+ * group_ids: new BigInt64Array([0]), // Group ID for cross-learning
3044+ * attention_mask: new Float32Array(context.length).fill(1.0),
3045+ * };
3046+ *
3047+ * // Generate forecasts
3048+ * const { quantile_preds } = await model(inputs);
3049+ * // Returns quantile predictions: [batch_size, num_quantiles, prediction_length]
3050+ * ```
3051+ */
3052+ export class Chronos2Model extends Chronos2PreTrainedModel { }
3053+
3054+ /**
3055+ * Chronos2 Model with a forecasting head for time series prediction.
3056+ *
3057+ * This model outputs quantile predictions for probabilistic forecasting.
3058+ */
3059+ export class Chronos2ForForecasting extends Chronos2PreTrainedModel { }
3060+ //////////////////////////////////////////////////
3061+
3062+
30123063//////////////////////////////////////////////////
30133064// LONGT5 models
30143065/**
@@ -7839,6 +7890,7 @@ const MODEL_MAPPING_NAMES_ENCODER_ONLY = new Map([
78397890
78407891const MODEL_MAPPING_NAMES_ENCODER_DECODER = new Map ( [
78417892 [ 't5' , [ 'T5Model' , T5Model ] ] ,
7893+ [ 'chronos2' , [ 'Chronos2Model' , Chronos2Model ] ] ,
78427894 [ 'longt5' , [ 'LongT5Model' , LongT5Model ] ] ,
78437895 [ 'mt5' , [ 'MT5Model' , MT5Model ] ] ,
78447896 [ 'bart' , [ 'BartModel' , BartModel ] ] ,
@@ -7957,6 +8009,7 @@ const MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = new Map([
79578009
79588010const MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES = new Map ( [
79598011 [ 't5' , [ 'T5ForConditionalGeneration' , T5ForConditionalGeneration ] ] ,
8012+ [ 'chronos2' , [ 'Chronos2ForForecasting' , Chronos2ForForecasting ] ] ,
79608013 [ 'longt5' , [ 'LongT5ForConditionalGeneration' , LongT5ForConditionalGeneration ] ] ,
79618014 [ 'mt5' , [ 'MT5ForConditionalGeneration' , MT5ForConditionalGeneration ] ] ,
79628015 [ 'bart' , [ 'BartForConditionalGeneration' , BartForConditionalGeneration ] ] ,
@@ -8226,6 +8279,10 @@ const MODEL_FOR_IMAGE_FEATURE_EXTRACTION_MAPPING_NAMES = new Map([
82268279 [ 'jina_clip' , [ 'JinaCLIPVisionModel' , JinaCLIPVisionModel ] ] ,
82278280] )
82288281
8282+ const MODEL_FOR_FORECASTING_MAPPING_NAMES = new Map ( [
8283+ [ 'chronos2' , [ 'Chronos2ForForecasting' , Chronos2ForForecasting ] ] ,
8284+ ] )
8285+
82298286const MODEL_CLASS_TYPE_MAPPING = [
82308287 // MODEL_MAPPING_NAMES:
82318288 [ MODEL_MAPPING_NAMES_ENCODER_ONLY , MODEL_TYPES . EncoderOnly ] ,
@@ -8263,6 +8320,7 @@ const MODEL_CLASS_TYPE_MAPPING = [
82638320 [ MODEL_FOR_TEXT_TO_WAVEFORM_MAPPING_NAMES , MODEL_TYPES . EncoderOnly ] ,
82648321 [ MODEL_FOR_AUDIO_XVECTOR_MAPPING_NAMES , MODEL_TYPES . EncoderOnly ] ,
82658322 [ MODEL_FOR_AUDIO_FRAME_CLASSIFICATION_MAPPING_NAMES , MODEL_TYPES . EncoderOnly ] ,
8323+ [ MODEL_FOR_FORECASTING_MAPPING_NAMES , MODEL_TYPES . Seq2Seq ] ,
82668324
82678325 // Custom:
82688326 [ MODEL_FOR_IMAGE_FEATURE_EXTRACTION_MAPPING_NAMES , MODEL_TYPES . EncoderOnly ] ,
@@ -8565,6 +8623,30 @@ export class AutoModelForAudioTextToText extends PretrainedMixin {
85658623 static MODEL_CLASS_MAPPINGS = [ MODEL_FOR_AUDIO_TEXT_TO_TEXT_MAPPING_NAMES ] ;
85668624}
85678625
8626+ /**
8627+ * Helper class which is used to instantiate time series forecasting models with the `from_pretrained` function.
8628+ *
8629+ * @example
8630+ * const model = await AutoModelForForecasting.from_pretrained('amazon/chronos-2-small');
8631+ */
8632+ export class AutoModelForForecasting extends PretrainedMixin {
8633+ static MODEL_CLASS_MAPPINGS = [ MODEL_FOR_FORECASTING_MAPPING_NAMES ] ;
8634+
8635+ /** @type {typeof PreTrainedModel.from_pretrained } */
8636+ static async from_pretrained ( pretrained_model_name_or_path , options = { } ) {
8637+ // First, load the config to check if it has chronos_config
8638+ const config = options . config || await AutoConfig . from_pretrained ( pretrained_model_name_or_path , options ) ;
8639+
8640+ // If model has chronos_config, route to Chronos2ForForecasting regardless of model_type
8641+ if ( config . chronos_config ) {
8642+ return await Chronos2ForForecasting . from_pretrained ( pretrained_model_name_or_path , { ...options , config } ) ;
8643+ }
8644+
8645+ // Otherwise, use the standard mapping-based routing
8646+ return await super . from_pretrained ( pretrained_model_name_or_path , { ...options , config } ) ;
8647+ }
8648+ }
8649+
85688650//////////////////////////////////////////////////
85698651
85708652//////////////////////////////////////////////////
0 commit comments