
Commit 72b815c

Add support for Gemma3n (#1348)
* Align llava processor with python library
* Add support for llava_qwen2
* Update llava unit tests
* Fix test
* Update florence2 processor & tests
* Update florence2 unit tests
* Add support for gemma3n
* Pass input_features_mask to audio encoder
* Implement gemma3n feature extraction
* npm audit fix
* Add model to supported list
* Fix JSDoc
1 parent 1f49a13 commit 72b815c

19 files changed: +387 -67 lines

README.md

Lines changed: 1 addition & 0 deletions
@@ -332,6 +332,7 @@ You can refine your search by selecting the task you're interested in (e.g., [te
 1. **[Gemma](https://huggingface.co/docs/transformers/main/model_doc/gemma)** (from Google) released with the paper [Gemma: Open Models Based on Gemini Technology and Research](https://blog.google/technology/developers/gemma-open-models/) by the Gemma Google team.
 1. **[Gemma2](https://huggingface.co/docs/transformers/main/model_doc/gemma2)** (from Google) released with the paper [Gemma2: Open Models Based on Gemini Technology and Research](https://blog.google/technology/developers/google-gemma-2/) by the Gemma Google team.
 1. **[Gemma3](https://huggingface.co/docs/transformers/main/model_doc/gemma3)** (from Google) released with the paper [Introducing Gemma 3: The most capable model you can run on a single GPU or TPU](https://blog.google/technology/developers/gemma-3/) by the Gemma Google team.
+1. **[Gemma3n](https://huggingface.co/docs/transformers/main/model_doc/gemma3n)** (from Google) released with the paper [Announcing Gemma 3n preview: powerful, efficient, mobile-first AI](https://developers.googleblog.com/en/introducing-gemma-3n/) by the Gemma Google team.
 1. **[GLM](https://huggingface.co/docs/transformers/main/model_doc/glm)** (from the GLM Team, THUDM & ZhipuAI) released with the paper [ChatGLM: A Family of Large Language Models from GLM-130B to GLM-4 All Tools](https://huggingface.co/papers/2406.12793v2) by Team GLM: Aohan Zeng, Bin Xu, Bowen Wang, Chenhui Zhang, Da Yin, Dan Zhang, Diego Rojas, Guanyu Feng, Hanlin Zhao, Hanyu Lai, Hao Yu, Hongning Wang, Jiadai Sun, Jiajie Zhang, Jiale Cheng, Jiayi Gui, Jie Tang, Jing Zhang, Jingyu Sun, Juanzi Li, Lei Zhao, Lindong Wu, Lucen Zhong, Mingdao Liu, Minlie Huang, Peng Zhang, Qinkai Zheng, Rui Lu, Shuaiqi Duan, Shudan Zhang, Shulin Cao, Shuxun Yang, Weng Lam Tam, Wenyi Zhao, Xiao Liu, Xiao Xia, Xiaohan Zhang, Xiaotao Gu, Xin Lv, Xinghan Liu, Xinyi Liu, Xinyue Yang, Xixuan Song, Xunkai Zhang, Yifan An, Yifan Xu, Yilin Niu, Yuantao Yang, Yueyan Li, Yushi Bai, Yuxiao Dong, Zehan Qi, Zhaoyu Wang, Zhen Yang, Zhengxiao Du, Zhenyu Hou, Zihan Wang.
 1. **[GLPN](https://huggingface.co/docs/transformers/model_doc/glpn)** (from KAIST) released with the paper [Global-Local Path Networks for Monocular Depth Estimation with Vertical CutDepth](https://huggingface.co/papers/2201.07436) by Doyeon Kim, Woonghyun Ga, Pyungwhan Ahn, Donggyu Joo, Sehwan Chun, Junmo Kim.
 1. **[GPT Neo](https://huggingface.co/docs/transformers/model_doc/gpt_neo)** (from EleutherAI) released in the repository [EleutherAI/gpt-neo](https://github.com/EleutherAI/gpt-neo) by Sid Black, Stella Biderman, Leo Gao, Phil Wang and Connor Leahy.

docs/snippets/6_supported-models.snippet

Lines changed: 1 addition & 0 deletions
@@ -46,6 +46,7 @@
 1. **[Gemma](https://huggingface.co/docs/transformers/main/model_doc/gemma)** (from Google) released with the paper [Gemma: Open Models Based on Gemini Technology and Research](https://blog.google/technology/developers/gemma-open-models/) by the Gemma Google team.
 1. **[Gemma2](https://huggingface.co/docs/transformers/main/model_doc/gemma2)** (from Google) released with the paper [Gemma2: Open Models Based on Gemini Technology and Research](https://blog.google/technology/developers/google-gemma-2/) by the Gemma Google team.
 1. **[Gemma3](https://huggingface.co/docs/transformers/main/model_doc/gemma3)** (from Google) released with the paper [Introducing Gemma 3: The most capable model you can run on a single GPU or TPU](https://blog.google/technology/developers/gemma-3/) by the Gemma Google team.
+1. **[Gemma3n](https://huggingface.co/docs/transformers/main/model_doc/gemma3n)** (from Google) released with the paper [Announcing Gemma 3n preview: powerful, efficient, mobile-first AI](https://developers.googleblog.com/en/introducing-gemma-3n/) by the Gemma Google team.
 1. **[GLM](https://huggingface.co/docs/transformers/main/model_doc/glm)** (from the GLM Team, THUDM & ZhipuAI) released with the paper [ChatGLM: A Family of Large Language Models from GLM-130B to GLM-4 All Tools](https://huggingface.co/papers/2406.12793v2) by Team GLM: Aohan Zeng, Bin Xu, Bowen Wang, Chenhui Zhang, Da Yin, Dan Zhang, Diego Rojas, Guanyu Feng, Hanlin Zhao, Hanyu Lai, Hao Yu, Hongning Wang, Jiadai Sun, Jiajie Zhang, Jiale Cheng, Jiayi Gui, Jie Tang, Jing Zhang, Jingyu Sun, Juanzi Li, Lei Zhao, Lindong Wu, Lucen Zhong, Mingdao Liu, Minlie Huang, Peng Zhang, Qinkai Zheng, Rui Lu, Shuaiqi Duan, Shudan Zhang, Shulin Cao, Shuxun Yang, Weng Lam Tam, Wenyi Zhao, Xiao Liu, Xiao Xia, Xiaohan Zhang, Xiaotao Gu, Xin Lv, Xinghan Liu, Xinyi Liu, Xinyue Yang, Xixuan Song, Xunkai Zhang, Yifan An, Yifan Xu, Yilin Niu, Yuantao Yang, Yueyan Li, Yushi Bai, Yuxiao Dong, Zehan Qi, Zhaoyu Wang, Zhen Yang, Zhengxiao Du, Zhenyu Hou, Zihan Wang.
 1. **[GLPN](https://huggingface.co/docs/transformers/model_doc/glpn)** (from KAIST) released with the paper [Global-Local Path Networks for Monocular Depth Estimation with Vertical CutDepth](https://huggingface.co/papers/2201.07436) by Doyeon Kim, Woonghyun Ga, Pyungwhan Ahn, Donggyu Joo, Sehwan Chun, Junmo Kim.
 1. **[GPT Neo](https://huggingface.co/docs/transformers/model_doc/gpt_neo)** (from EleutherAI) released in the repository [EleutherAI/gpt-neo](https://github.com/EleutherAI/gpt-neo) by Sid Black, Stella Biderman, Leo Gao, Phil Wang and Connor Leahy.

package-lock.json

Lines changed: 23 additions & 38 deletions
Some generated files are not rendered by default.

src/base/processing_utils.js

Lines changed: 13 additions & 6 deletions
@@ -19,11 +19,11 @@
  *
  * @module processors
  */
-import { PROCESSOR_NAME } from '../utils/constants.js';
+import { PROCESSOR_NAME, CHAT_TEMPLATE_NAME } from '../utils/constants.js';
 import {
     Callable,
 } from '../utils/generic.js';
-import { getModelJSON } from '../utils/hub.js';
+import { getModelJSON, getModelText } from '../utils/hub.js';

 /**
  * @typedef {Object} ProcessorProperties Additional processor-specific properties.
@@ -42,16 +42,19 @@ export class Processor extends Callable {
         'feature_extractor_class',
     ]
     static uses_processor_config = false;
+    static uses_chat_template_file = false;

     /**
      * Creates a new Processor with the given components
      * @param {Object} config
      * @param {Record<string, Object>} components
+     * @param {string} chat_template
      */
-    constructor(config, components) {
+    constructor(config, components, chat_template) {
         super();
         this.config = config;
         this.components = components;
+        this.chat_template = chat_template;
     }

     /**
@@ -86,6 +89,7 @@ export class Processor extends Callable {
         }
         return this.tokenizer.apply_chat_template(messages, {
             tokenize: false, // default to false
+            chat_template: this.chat_template ?? undefined,
             ...options,
         });
     }
@@ -146,7 +150,7 @@
      */
     static async from_pretrained(pretrained_model_name_or_path, options) {

-        const [config, components] = await Promise.all([
+        const [config, components, chat_template] = await Promise.all([
             // TODO:
             this.uses_processor_config
                 ? getModelJSON(pretrained_model_name_or_path, PROCESSOR_NAME, true, options)
@@ -158,9 +162,12 @@
                 const component = await this[cls].from_pretrained(pretrained_model_name_or_path, options);
                 return [cls.replace(/_class$/, ''), component];
             })
-        ).then(Object.fromEntries)
+        ).then(Object.fromEntries),
+            this.uses_chat_template_file
+                ? getModelText(pretrained_model_name_or_path, CHAT_TEMPLATE_NAME, true, options)
+                : null,
         ]);

-        return new this(config, components);
+        return new this(config, components, chat_template);
     }
 }
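For context, a minimal sketch of how the new fallback behaves from the caller's side. The model id below is an assumption for illustration; any repo whose processor class sets `uses_chat_template_file = true` and ships a separate chat-template file works the same way:

```js
import { AutoProcessor } from '@huggingface/transformers';

// Hypothetical model id, for illustration only.
const processor = await AutoProcessor.from_pretrained('onnx-community/gemma-3n-E2B-it-ONNX');

// When `uses_chat_template_file` is set, the template is fetched as a separate
// text file at load time and stored on the instance; apply_chat_template then
// forwards it to the tokenizer unless the caller passes their own template.
const prompt = processor.apply_chat_template(
    [{ role: 'user', content: 'Hello!' }],
    { add_generation_prompt: true },
);
```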

src/configs.js

Lines changed: 2 additions & 0 deletions
@@ -73,6 +73,7 @@ function getNormalizedConfig(config) {
         case 'idefics3':
         case 'ultravox':
         case 'smolvlm':
+        case 'gemma3n':
             // @ts-expect-error TS2339
             init_normalized_config = getNormalizedConfig(config.text_config);
             break;
@@ -130,6 +131,7 @@ function getNormalizedConfig(config) {
         case 'gemma':
         case 'gemma2':
         case 'gemma3_text':
+        case 'gemma3n_text':
         case 'glm':
         case 'helium':
             mapping['num_heads'] = 'num_key_value_heads';
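To make the two new cases concrete: `gemma3n` is a composite config whose generation-relevant fields live under `text_config`, so normalization recurses into it, and the inner `gemma3n_text` config maps `num_key_value_heads` onto the normalized `num_heads` key. A hypothetical, abridged shape (field values are placeholders, not taken from a real checkpoint):

```js
// Abridged, hypothetical config.json for a gemma3n checkpoint.
const config = {
    model_type: 'gemma3n',          // composite config: normalization recurses into text_config
    text_config: {
        model_type: 'gemma3n_text', // text backbone: num_heads <- num_key_value_heads
        num_hidden_layers: 30,      // placeholder value
        num_key_value_heads: 2,     // placeholder value
        head_dim: 256,              // placeholder value
    },
};
```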

src/models.js

Lines changed: 126 additions & 1 deletion
@@ -136,6 +136,7 @@ const MODEL_TYPES = {
     Phi3V: 9,
     AudioTextToText: 10,
     AutoEncoder: 11,
+    ImageAudioTextToText: 12,
 }
 //////////////////////////////////////////////////
@@ -1057,6 +1058,7 @@ export class PreTrainedModel extends Callable {
                 this._prepare_inputs_for_generation = multimodal_text_to_text_prepare_inputs_for_generation;
                 break;
             case MODEL_TYPES.Phi3V:
+            case MODEL_TYPES.ImageAudioTextToText:
                 this.can_generate = true;
                 this._prepare_inputs_for_generation = multimodal_text_to_text_prepare_inputs_for_generation;
                 break;
@@ -1210,7 +1212,19 @@
                 generation_config: 'generation_config.json',
             }, options),
         ]);
-
+        } else if (modelType === MODEL_TYPES.ImageAudioTextToText) {
+            const sessions = {
+                embed_tokens: 'embed_tokens',
+                audio_encoder: 'audio_encoder',
+                vision_encoder: 'vision_encoder',
+                decoder_model_merged: 'decoder_model_merged',
+            }
+            info = await Promise.all([
+                constructSessions(pretrained_model_name_or_path, sessions, options),
+                getOptionalConfigs(pretrained_model_name_or_path, {
+                    generation_config: 'generation_config.json',
+                }, options),
+            ]);
         } else if (modelType === MODEL_TYPES.Musicgen) {
             info = await Promise.all([
                 constructSessions(pretrained_model_name_or_path, {
@@ -3795,6 +3809,114 @@ export class LlavaQwen2ForCausalLM extends LlavaPreTrainedModel {
     }
 }

+export class Gemma3nPreTrainedModel extends PreTrainedModel {
+    forward_params = [
+        'input_ids',
+        'attention_mask',
+        'inputs_embeds',
+        'per_layer_inputs',
+
+        'position_ids',
+        'pixel_values',
+        'input_features',
+        'input_features_mask',
+        'past_key_values',
+    ];
+}
+export class Gemma3nForConditionalGeneration extends Gemma3nPreTrainedModel {
+
+    async forward({
+        // Produced by the tokenizer/processor:
+        input_ids = null,
+        attention_mask = null,
+        pixel_values = null,
+        input_features = null,
+        input_features_mask = null,
+
+        // Used during generation:
+        position_ids = null,
+        inputs_embeds = null,
+        per_layer_inputs = null,
+        past_key_values = null,
+
+        // Generic generation parameters
+        generation_config = null,
+        logits_processor = null,
+
+        // TODO: needed?
+        ...kwargs
+    }) {
+        if (!inputs_embeds || !per_layer_inputs) {
+            // 1. Extract the text embeddings.
+            ({ inputs_embeds, per_layer_inputs } = await sessionRun(this.sessions['embed_tokens'], {
+                input_ids,
+            }));
+            if (input_ids.dims[1] !== 1) {
+                if (pixel_values) {
+                    // Encode the image
+                    const { image_features } = await sessionRun(this.sessions['vision_encoder'], {
+                        pixel_values,
+                    });
+                    ({ inputs_embeds, attention_mask } = this._merge_input_ids_with_image_features({
+                        image_features,
+                        inputs_embeds,
+                        input_ids,
+                        attention_mask,
+                    }));
+                }
+
+                if (input_features) {
+                    // Encode the audio
+                    const { audio_features } = await sessionRun(this.sessions['audio_encoder'], {
+                        input_features,
+                        input_features_mask,
+                    });
+                    ({ inputs_embeds, attention_mask } = this._merge_input_ids_with_audio_features({
+                        audio_features,
+                        inputs_embeds,
+                        input_ids,
+                        attention_mask,
+                    }));
+                }
+            }
+        }
+
+        const outputs = await decoderForward(this, {
+            inputs_embeds,
+            per_layer_inputs,
+            past_key_values,
+            attention_mask,
+            position_ids,
+            generation_config,
+            logits_processor,
+        }, true);
+        return outputs;
+    }
+
+    _merge_input_ids_with_image_features(kwargs) {
+        const vision_hidden_size = kwargs.image_features.dims.at(-1);
+        const reshaped_image_hidden_states = kwargs.image_features.view(-1, vision_hidden_size);
+        return default_merge_input_ids_with_image_features({
+            // @ts-ignore
+            image_token_id: this.config.image_token_id,
+            ...kwargs,
+            image_features: reshaped_image_hidden_states,
+        });
+    }
+    _merge_input_ids_with_audio_features(kwargs) {
+        const audio_hidden_size = kwargs.audio_features.dims.at(-1);
+        const reshaped_audio_features = kwargs.audio_features.view(-1, audio_hidden_size);
+
+        return default_merge_input_ids_with_audio_features({
+            // @ts-ignore
+            audio_token_id: this.config.audio_token_id,
+            ...kwargs,
+            audio_features: reshaped_audio_features,
+        })
+    }
+}
+
+
 //////////////////////////////////////////////////
 // Idefics3 Models
 export class Idefics3PreTrainedModel extends PreTrainedModel {
@@ -7799,6 +7921,7 @@ const MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES = new Map([
     ['smolvlm', ['SmolVLMForConditionalGeneration', SmolVLMForConditionalGeneration]],
     ['paligemma', ['PaliGemmaForConditionalGeneration', PaliGemmaForConditionalGeneration]],
     ['llava_qwen2', ['LlavaQwen2ForCausalLM', LlavaQwen2ForCausalLM]],
+    ['gemma3n', ['Gemma3nForConditionalGeneration', Gemma3nForConditionalGeneration]],
 ]);

 const MODEL_FOR_AUDIO_TEXT_TO_TEXT_MAPPING_NAMES = new Map([
@@ -8015,6 +8138,8 @@ const CUSTOM_MAPPING = [
     ['MimiDecoderModel', MimiDecoderModel, MODEL_TYPES.EncoderOnly],
     ['SnacEncoderModel', SnacEncoderModel, MODEL_TYPES.EncoderOnly],
     ['SnacDecoderModel', SnacDecoderModel, MODEL_TYPES.EncoderOnly],
+
+    ['Gemma3nForConditionalGeneration', Gemma3nForConditionalGeneration, MODEL_TYPES.ImageAudioTextToText],
 ]
 for (const [name, model, type] of CUSTOM_MAPPING) {
     MODEL_TYPE_MAPPING.set(name, type);
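A hedged end-to-end sketch of how the new class might be driven. The model id and media URLs are placeholders, and the exact processor call signature is an assumption (transformers.js processors differ in how they accept images and audio); treat this as a sketch, not the canonical usage:

```js
import {
    AutoProcessor,
    AutoModelForImageTextToText,
    RawImage,
    read_audio,
} from '@huggingface/transformers';

// Assumed ONNX export of Gemma 3n; substitute your own repo.
const model_id = 'onnx-community/gemma-3n-E2B-it-ONNX';
const processor = await AutoProcessor.from_pretrained(model_id);
const model = await AutoModelForImageTextToText.from_pretrained(model_id);

// The chat template is expected to expand the image/audio entries into the
// placeholder tokens that _merge_input_ids_with_image_features and
// _merge_input_ids_with_audio_features later scatter encoder outputs into.
const messages = [{
    role: 'user',
    content: [
        { type: 'image' },
        { type: 'audio' },
        { type: 'text', text: 'Describe the image and transcribe the audio.' },
    ],
}];
const prompt = processor.apply_chat_template(messages, { add_generation_prompt: true });

const image = await RawImage.fromURL('https://example.com/cat.png');      // placeholder URL
const audio = await read_audio('https://example.com/speech.wav', 16000);  // placeholder URL

// Assumed call order: text, then image, then audio.
const inputs = await processor(prompt, image, audio);
const output_ids = await model.generate({ ...inputs, max_new_tokens: 128 });
console.log(processor.batch_decode(output_ids, { skip_special_tokens: true })[0]);
```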
