Skip to content

Commit bb1e5b1

Browse files
committed
add AutoModelForForecasting
1 parent 4c908ec commit bb1e5b1

File tree

3 files changed

+587
-0
lines changed

3 files changed

+587
-0
lines changed

src/configs.js

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,7 @@ function getNormalizedConfig(config) {
179179

180180
// Encoder-decoder models
181181
case 't5':
182+
case 'chronos2':
182183
case 'mt5':
183184
case 'longt5':
184185
mapping['num_decoder_layers'] = 'num_decoder_layers';

src/models.js

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3009,6 +3009,57 @@ export class T5ForConditionalGeneration extends T5PreTrainedModel { }
30093009
//////////////////////////////////////////////////
30103010

30113011

3012+
//////////////////////////////////////////////////
3013+
// Chronos2 models
3014+
/**
3015+
* An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained models.
3016+
*/
3017+
export class Chronos2PreTrainedModel extends PreTrainedModel {
3018+
forward_params = [
3019+
'context',
3020+
'group_ids',
3021+
'attention_mask',
3022+
];
3023+
};
3024+
3025+
/**
3026+
* The Chronos-2 Model for time series forecasting.
3027+
*
3028+
* Chronos-2 is a family of pretrained time series forecasting models based on T5.
3029+
* It uses a patching mechanism to convert time series into tokens and predicts
3030+
* multiple quantiles for probabilistic forecasting.
3031+
*
3032+
* **Example:** Load and run a Chronos-2 model for forecasting.
3033+
*
3034+
* ```javascript
3035+
* import { Chronos2ForForecasting } from '@huggingface/transformers';
3036+
*
3037+
* const model = await Chronos2ForForecasting.from_pretrained('amazon/chronos-2-small');
3038+
*
3039+
* // Prepare time series input
3040+
* const context = new Float32Array([1.0, 2.0, 3.0, 4.0, ...]); // Your historical data
3041+
* const inputs = {
3042+
* context: context,
3043+
* group_ids: new BigInt64Array([0]), // Group ID for cross-learning
3044+
* attention_mask: new Float32Array(context.length).fill(1.0),
3045+
* };
3046+
*
3047+
* // Generate forecasts
3048+
* const { quantile_preds } = await model(inputs);
3049+
* // Returns quantile predictions: [batch_size, num_quantiles, prediction_length]
3050+
* ```
3051+
*/
3052+
export class Chronos2Model extends Chronos2PreTrainedModel { }
3053+
3054+
/**
3055+
* Chronos2 Model with a forecasting head for time series prediction.
3056+
*
3057+
* This model outputs quantile predictions for probabilistic forecasting.
3058+
*/
3059+
export class Chronos2ForForecasting extends Chronos2PreTrainedModel { }
3060+
//////////////////////////////////////////////////
3061+
3062+
30123063
//////////////////////////////////////////////////
30133064
// LONGT5 models
30143065
/**
@@ -7839,6 +7890,7 @@ const MODEL_MAPPING_NAMES_ENCODER_ONLY = new Map([
78397890

78407891
const MODEL_MAPPING_NAMES_ENCODER_DECODER = new Map([
78417892
['t5', ['T5Model', T5Model]],
7893+
['chronos2', ['Chronos2Model', Chronos2Model]],
78427894
['longt5', ['LongT5Model', LongT5Model]],
78437895
['mt5', ['MT5Model', MT5Model]],
78447896
['bart', ['BartModel', BartModel]],
@@ -7957,6 +8009,7 @@ const MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = new Map([
79578009

79588010
const MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES = new Map([
79598011
['t5', ['T5ForConditionalGeneration', T5ForConditionalGeneration]],
8012+
['chronos2', ['Chronos2ForForecasting', Chronos2ForForecasting]],
79608013
['longt5', ['LongT5ForConditionalGeneration', LongT5ForConditionalGeneration]],
79618014
['mt5', ['MT5ForConditionalGeneration', MT5ForConditionalGeneration]],
79628015
['bart', ['BartForConditionalGeneration', BartForConditionalGeneration]],
@@ -8226,6 +8279,10 @@ const MODEL_FOR_IMAGE_FEATURE_EXTRACTION_MAPPING_NAMES = new Map([
82268279
['jina_clip', ['JinaCLIPVisionModel', JinaCLIPVisionModel]],
82278280
])
82288281

8282+
const MODEL_FOR_FORECASTING_MAPPING_NAMES = new Map([
8283+
['chronos2', ['Chronos2ForForecasting', Chronos2ForForecasting]],
8284+
])
8285+
82298286
const MODEL_CLASS_TYPE_MAPPING = [
82308287
// MODEL_MAPPING_NAMES:
82318288
[MODEL_MAPPING_NAMES_ENCODER_ONLY, MODEL_TYPES.EncoderOnly],
@@ -8263,6 +8320,7 @@ const MODEL_CLASS_TYPE_MAPPING = [
82638320
[MODEL_FOR_TEXT_TO_WAVEFORM_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
82648321
[MODEL_FOR_AUDIO_XVECTOR_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
82658322
[MODEL_FOR_AUDIO_FRAME_CLASSIFICATION_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
8323+
[MODEL_FOR_FORECASTING_MAPPING_NAMES, MODEL_TYPES.Seq2Seq],
82668324

82678325
// Custom:
82688326
[MODEL_FOR_IMAGE_FEATURE_EXTRACTION_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
@@ -8565,6 +8623,30 @@ export class AutoModelForAudioTextToText extends PretrainedMixin {
85658623
static MODEL_CLASS_MAPPINGS = [MODEL_FOR_AUDIO_TEXT_TO_TEXT_MAPPING_NAMES];
85668624
}
85678625

8626+
/**
8627+
* Helper class which is used to instantiate time series forecasting models with the `from_pretrained` function.
8628+
*
8629+
* @example
8630+
* const model = await AutoModelForForecasting.from_pretrained('amazon/chronos-2-small');
8631+
*/
8632+
export class AutoModelForForecasting extends PretrainedMixin {
8633+
static MODEL_CLASS_MAPPINGS = [MODEL_FOR_FORECASTING_MAPPING_NAMES];
8634+
8635+
/** @type {typeof PreTrainedModel.from_pretrained} */
8636+
static async from_pretrained(pretrained_model_name_or_path, options = {}) {
8637+
// First, load the config to check if it has chronos_config
8638+
const config = options.config || await AutoConfig.from_pretrained(pretrained_model_name_or_path, options);
8639+
8640+
// If model has chronos_config, route to Chronos2ForForecasting regardless of model_type
8641+
if (config.chronos_config) {
8642+
return await Chronos2ForForecasting.from_pretrained(pretrained_model_name_or_path, { ...options, config });
8643+
}
8644+
8645+
// Otherwise, use the standard mapping-based routing
8646+
return await super.from_pretrained(pretrained_model_name_or_path, { ...options, config });
8647+
}
8648+
}
8649+
85688650
//////////////////////////////////////////////////
85698651

85708652
//////////////////////////////////////////////////

0 commit comments

Comments
 (0)