Commit c97afe3

Support gemma3n (#4836)

1 parent 28f77bf · commit c97afe3

9 files changed: +182 −2 lines changed

docs/source/Instruction/支持的模型和数据集.md

Lines changed: 4 additions & 0 deletions
@@ -508,6 +508,10 @@
 |[LLM-Research/gemma-2-27b-it](https://modelscope.cn/models/LLM-Research/gemma-2-27b-it)|gemma2|gemma|transformers>=4.42|✘|-|[google/gemma-2-27b-it](https://huggingface.co/google/gemma-2-27b-it)|
 |[LLM-Research/gemma-3-1b-pt](https://modelscope.cn/models/LLM-Research/gemma-3-1b-pt)|gemma3_text|gemma3_text|transformers>=4.49|✘|-|[google/gemma-3-1b-pt](https://huggingface.co/google/gemma-3-1b-pt)|
 |[LLM-Research/gemma-3-1b-it](https://modelscope.cn/models/LLM-Research/gemma-3-1b-it)|gemma3_text|gemma3_text|transformers>=4.49|✘|-|[google/gemma-3-1b-it](https://huggingface.co/google/gemma-3-1b-it)|
+|[google/gemma-3n-E2B](https://www.modelscope.cn/models/google/gemma-3n-E2B)|gemma3n|gemma3n|transformers>=4.53.1|✘|-|[google/gemma-3n-E2B](https://huggingface.co/google/gemma-3n-E2B)|
+|[google/gemma-3n-E2B-it](https://modelscope.cn/models/google/gemma-3n-E2B-it)|gemma3n|gemma3n|transformers>=4.53.1|✘|-|[google/gemma-3n-E2B-it](https://huggingface.co/google/gemma-3n-E2B-it)|
+|[google/gemma-3n-E4B](https://www.modelscope.cn/models/google/gemma-3n-E4B)|gemma3n|gemma3n|transformers>=4.53.1|✘|-|[google/gemma-3n-E4B](https://huggingface.co/google/gemma-3n-E4B)|
+|[google/gemma-3n-E4B-it](https://www.modelscope.cn/models/google/gemma-3n-E4B-it)|gemma3n|gemma3n|transformers>=4.53.1|✘|-|[google/gemma-3n-E4B-it](https://huggingface.co/google/gemma-3n-E4B-it)|
 |[skywork/Skywork-13B-base](https://modelscope.cn/models/skywork/Skywork-13B-base)|skywork|skywork|-|✘|-|[skywork/Skywork-13B-base](https://huggingface.co/skywork/Skywork-13B-base)|
 |[skywork/Skywork-13B-chat](https://modelscope.cn/models/skywork/Skywork-13B-chat)|skywork|skywork|-|✘|-|-|
 |[AI-ModelScope/Skywork-o1-Open-Llama-3.1-8B](https://modelscope.cn/models/AI-ModelScope/Skywork-o1-Open-Llama-3.1-8B)|skywork_o1|skywork_o1|transformers>=4.43|✔|-|[Skywork/Skywork-o1-Open-Llama-3.1-8B](https://huggingface.co/Skywork/Skywork-o1-Open-Llama-3.1-8B)|

docs/source_en/Instruction/Supported-models-and-datasets.md

Lines changed: 4 additions & 0 deletions
@@ -508,6 +508,10 @@ The table below introduces the models integrated with ms-swift:
 |[LLM-Research/gemma-2-27b-it](https://modelscope.cn/models/LLM-Research/gemma-2-27b-it)|gemma2|gemma|transformers>=4.42|✘|-|[google/gemma-2-27b-it](https://huggingface.co/google/gemma-2-27b-it)|
 |[LLM-Research/gemma-3-1b-pt](https://modelscope.cn/models/LLM-Research/gemma-3-1b-pt)|gemma3_text|gemma3_text|transformers>=4.49|✘|-|[google/gemma-3-1b-pt](https://huggingface.co/google/gemma-3-1b-pt)|
 |[LLM-Research/gemma-3-1b-it](https://modelscope.cn/models/LLM-Research/gemma-3-1b-it)|gemma3_text|gemma3_text|transformers>=4.49|✘|-|[google/gemma-3-1b-it](https://huggingface.co/google/gemma-3-1b-it)|
+|[google/gemma-3n-E2B](https://www.modelscope.cn/models/google/gemma-3n-E2B)|gemma3n|gemma3n|transformers>=4.53.1|✘|-|[google/gemma-3n-E2B](https://huggingface.co/google/gemma-3n-E2B)|
+|[google/gemma-3n-E2B-it](https://modelscope.cn/models/google/gemma-3n-E2B-it)|gemma3n|gemma3n|transformers>=4.53.1|✘|-|[google/gemma-3n-E2B-it](https://huggingface.co/google/gemma-3n-E2B-it)|
+|[google/gemma-3n-E4B](https://www.modelscope.cn/models/google/gemma-3n-E4B)|gemma3n|gemma3n|transformers>=4.53.1|✘|-|[google/gemma-3n-E4B](https://huggingface.co/google/gemma-3n-E4B)|
+|[google/gemma-3n-E4B-it](https://www.modelscope.cn/models/google/gemma-3n-E4B-it)|gemma3n|gemma3n|transformers>=4.53.1|✘|-|[google/gemma-3n-E4B-it](https://huggingface.co/google/gemma-3n-E4B-it)|
 |[skywork/Skywork-13B-base](https://modelscope.cn/models/skywork/Skywork-13B-base)|skywork|skywork|-|✘|-|[skywork/Skywork-13B-base](https://huggingface.co/skywork/Skywork-13B-base)|
 |[skywork/Skywork-13B-chat](https://modelscope.cn/models/skywork/Skywork-13B-chat)|skywork|skywork|-|✘|-|-|
 |[AI-ModelScope/Skywork-o1-Open-Llama-3.1-8B](https://modelscope.cn/models/AI-ModelScope/Skywork-o1-Open-Llama-3.1-8B)|skywork_o1|skywork_o1|transformers>=4.43|✔|-|[Skywork/Skywork-o1-Open-Llama-3.1-8B](https://huggingface.co/Skywork/Skywork-o1-Open-Llama-3.1-8B)|

swift/llm/model/constant.py

Lines changed: 1 addition & 0 deletions
@@ -232,6 +232,7 @@ class MLLMModelType:
     megrez_omni = 'megrez_omni'
     valley = 'valley'
     gemma3_vision = 'gemma3_vision'
+    gemma3n = 'gemma3n'
     mistral_2503 = 'mistral_2503'

swift/llm/model/model/gemma.py

Lines changed: 36 additions & 0 deletions
@@ -4,6 +4,7 @@
 from swift.llm import TemplateType
 from ..constant import LLMModelType, MLLMModelType
 from ..model_arch import ModelArch
+from ..patcher import patch_output_to_input_device
 from ..register import (Model, ModelGroup, ModelMeta, get_model_tokenizer_multimodal,
                         get_model_tokenizer_with_flash_attn, register_model)
 from ..utils import ModelInfo
@@ -163,3 +164,38 @@ def get_model_tokenizer_gemma3_vision(model_dir: str,
         model_arch=ModelArch.llava_hf,
         requires=['transformers>=4.49'],
     ))
+
+
+def get_model_tokenizer_gemma3n(model_dir: str,
+                                model_info: ModelInfo,
+                                model_kwargs: Dict[str, Any],
+                                load_model: bool = True,
+                                **kwargs):
+    from transformers import Gemma3nForConditionalGeneration
+    kwargs['automodel_class'] = kwargs['automodel_class'] or Gemma3nForConditionalGeneration
+    model, processor = get_model_tokenizer_multimodal(model_dir, model_info, model_kwargs, load_model, **kwargs)
+
+    if load_model and model is not None:
+        patch_output_to_input_device(model.model.embed_vision)
+        patch_output_to_input_device(model.model.embed_audio)
+
+    return model, processor
+
+
+register_model(
+    ModelMeta(
+        MLLMModelType.gemma3n,
+        [
+            ModelGroup([
+                Model('google/gemma-3n-E2B', 'google/gemma-3n-E2B'),
+                Model('google/gemma-3n-E4B', 'google/gemma-3n-E4B'),
+                Model('google/gemma-3n-E2B-it', 'google/gemma-3n-E2B-it'),
+                Model('google/gemma-3n-E4B-it', 'google/gemma-3n-E4B-it'),
+            ], ),
+        ],
+        TemplateType.gemma3n,
+        get_model_tokenizer_gemma3n,
+        architectures=['Gemma3nForConditionalGeneration'],
+        model_arch=ModelArch.gemma3n,
+        requires=['transformers>=4.53.1'],
+    ))
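
With this registration in place, the gemma3n checkpoints resolve through the same generic loader as the other multimodal models. A minimal sketch of what that looks like from user code, assuming the usual get_model_tokenizer entry point exported by swift.llm (the call below is illustrative, not part of this commit):

# Hypothetical usage sketch, not part of this commit.
# Needs transformers>=4.53.1, per the `requires` field in the registration above.
from swift.llm import get_model_tokenizer

# Resolves the model id registered above and dispatches to
# get_model_tokenizer_gemma3n, which patches the vision/audio embedders.
model, processor = get_model_tokenizer('google/gemma-3n-E2B-it')
print(type(model).__name__)  # expected: Gemma3nForConditionalGeneration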

swift/llm/model/model_arch.py

Lines changed: 9 additions & 0 deletions
@@ -70,6 +70,7 @@ class MLLMModelArch:
     emu3_chat = 'emu3_chat'
     megrez_omni = 'megrez_omni'
     valley = 'valley'
+    gemma3n = 'gemma3n'
     mistral_2503 = 'mistral_2503'
@@ -594,6 +595,14 @@ def register_model_arch(model_arch: ModelKeys, *, exist_ok: bool = False) -> None:
         vision_tower=['model.vision_tower', 'model.qwen2vl_vision_tower'],
     ))

+register_model_arch(
+    MultiModelKeys(
+        MLLMModelArch.gemma3n,
+        language_model='model.language_model',
+        aligner=['model.embed_vision', 'model.embed_audio'],
+        vision_tower=['model.vision_tower', 'model.audio_tower'],
+    ))
+

 def get_model_arch(arch_name: Optional[str]) -> Optional[MultiModelKeys]:
     return MODEL_ARCH_MAPPING.get(arch_name)
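
This MultiModelKeys entry is what lets the rest of swift tell the Gemma 3n language model apart from its vision/audio towers and embedders, for example when freezing the towers or restricting tuning to the LLM. A small sketch of looking the mapping up (import path follows this file's location; the fields may be normalized to lists internally):

# Sketch, not part of this commit: inspect the arch mapping registered above.
from swift.llm.model.model_arch import get_model_arch

arch = get_model_arch('gemma3n')
print(arch.language_model)  # registered above as 'model.language_model'
print(arch.aligner)         # ['model.embed_vision', 'model.embed_audio']
print(arch.vision_tower)    # ['model.vision_tower', 'model.audio_tower']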

swift/llm/template/constant.py

Lines changed: 1 addition & 0 deletions
@@ -188,6 +188,7 @@ class MLLMTemplateType:
     megrez_omni = 'megrez_omni'
     valley = 'valley'
     gemma3_vision = 'gemma3_vision'
+    gemma3n = 'gemma3n'
     mistral_2503 = 'mistral_2503'

swift/llm/template/template/gemma.py

Lines changed: 100 additions & 0 deletions
@@ -11,6 +11,7 @@
 from ..register import TemplateMeta, register_template
 from ..template_inputs import StdTemplateInputs
 from ..utils import Context, Prompt, findall
+from ..vision_utils import load_audio


 @dataclass
@@ -129,3 +130,102 @@ def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:


 register_template(GemmaTemplateMeta(MLLMTemplateType.gemma3_vision, template_cls=Gemma3VisionTemplate))
+
+
+class Gemma3nTemplate(Gemma3Template):
+    boi_token_id = 255999
+    boa_token_id = 256000
+    placeholder_tokens = ['<start_of_image>', '<start_of_audio>']
+
+    def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index: int,
+                    inputs: StdTemplateInputs) -> List[Context]:
+        if media_type == 'image':
+            return ['<start_of_image>']
+        elif media_type == 'audio':
+            inputs.audios[index] = load_audio(inputs.audios[index], self.processor.feature_extractor.sampling_rate)
+            return ['<start_of_audio>']
+        else:
+            raise ValueError(f'Unsupported media type: {media_type}. Supported types are: image, audio')
+
+    def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
+        from transformers.models.gemma3n.processing_gemma3n import Gemma3nProcessorKwargs
+
+        # Input validation
+        if not inputs.images and not inputs.audios and not inputs.messages:
+            raise ValueError('Provide at least one of `images`, `audios`, or `messages`.')
+
+        encoded = super()._encode(inputs)
+        processor = self.processor
+        input_ids = encoded['input_ids']
+        labels = encoded['labels']
+
+        # Initialize token_type_ids and other outputs
+        array_ids = np.array(input_ids)
+        mm_token_type_ids = np.zeros_like(input_ids)
+
+        # Handle images
+        if inputs.images:
+            idx_list = findall(input_ids, self.boi_token_id)
+            img_tokens = self._tokenize(processor.full_image_sequence)
+            input_ids, labels = self._extend_tokens(input_ids, labels, idx_list, lambda _: img_tokens)
+
+            # Process images
+            processor_kwargs = Gemma3nProcessorKwargs._defaults.get('images_kwargs', {})
+            image_inputs = processor.image_processor(inputs.images, **processor_kwargs)
+            image_inputs['pixel_values'] = torch.as_tensor(np.array(image_inputs['pixel_values']))
+            if 'num_crops' in image_inputs:
+                image_inputs.pop('num_crops')
+            encoded.update(image_inputs)

+        # Handle audios
+        if inputs.audios:
+            audio_idx_list = findall(input_ids, self.boa_token_id)
+            if audio_idx_list:
+                # Get audio token sequence from processor
+                audio_tokens = self._tokenize(processor.full_audio_sequence)
+                input_ids, labels = self._extend_tokens(input_ids, labels, audio_idx_list, lambda _: audio_tokens)
+
+            # Process audios
+            processor_kwargs = Gemma3nProcessorKwargs._defaults.get('audio_kwargs', {})
+            audio_inputs = processor.feature_extractor(inputs.audios, **processor_kwargs)
+
+            if 'input_features' in audio_inputs:
+                audio_inputs['input_features'] = torch.tensor(audio_inputs['input_features']).to(
+                    self.model_info.torch_dtype)
+            if 'input_features_mask' in audio_inputs:
+                audio_inputs['input_features_mask'] = torch.tensor(audio_inputs['input_features_mask'])
+            encoded.update(audio_inputs)
+
+        # Update array_ids after token extension
+        array_ids = np.array(input_ids)
+        mm_token_type_ids = np.zeros_like(input_ids)
+
+        if hasattr(processor, 'image_token_id') and processor.image_token_id is not None:
+            mm_token_type_ids[array_ids == processor.image_token_id] = 1
+
+        if hasattr(processor, 'audio_token_id') and processor.audio_token_id is not None:
+            mm_token_type_ids[array_ids == processor.audio_token_id] = 3
+
+        encoded['token_type_ids'] = mm_token_type_ids.tolist()
+        encoded['input_ids'] = input_ids
+        encoded['labels'] = labels
+
+        return encoded
+
+    def _data_collator_mm_data(self, batch: List[Dict[str, Any]]) -> Dict[str, Any]:
+        """Handle multimodal data collation for Gemma3n, including audio features"""
+        res = super()._data_collator_mm_data(batch)
+
+        # Handle audio features like other templates do
+        input_features = [b['input_features'] for b in batch if b.get('input_features') is not None]
+        input_features_mask = [b['input_features_mask'] for b in batch if b.get('input_features_mask') is not None]
+
+        if input_features:
+            res['input_features'] = torch.concat(input_features)
+        if input_features_mask:
+            res['input_features_mask'] = torch.concat(input_features_mask)
+
+        return res
+
+
+register_template(GemmaTemplateMeta(MLLMTemplateType.gemma3n, template_cls=Gemma3nTemplate))
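
The easiest part of _encode to misread is the token_type_ids bookkeeping: after the <start_of_image>/<start_of_audio> placeholders are expanded into the full image/audio token sequences, every position holding the image token id is tagged 1 and every audio position is tagged 3, with 0 for plain text. A self-contained sketch of just that step, using made-up token ids (the real ones come from the Gemma3n processor):

# Standalone illustration; IMAGE_TOKEN_ID/AUDIO_TOKEN_ID are hypothetical values.
import numpy as np

IMAGE_TOKEN_ID = 262145
AUDIO_TOKEN_ID = 262273

# input_ids after placeholder expansion (shortened for the example)
input_ids = [2, 106, IMAGE_TOKEN_ID, IMAGE_TOKEN_ID, 1735, AUDIO_TOKEN_ID, 107]
array_ids = np.array(input_ids)

mm_token_type_ids = np.zeros_like(array_ids)
mm_token_type_ids[array_ids == IMAGE_TOKEN_ID] = 1  # image positions
mm_token_type_ids[array_ids == AUDIO_TOKEN_ID] = 3  # audio positions

print(mm_token_type_ids.tolist())  # [0, 0, 1, 1, 0, 3, 0]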

tests/test_align/test_template/test_audio.py

Lines changed: 12 additions & 1 deletion
@@ -65,6 +65,16 @@ def test_qwen2_5_omni():
     assert response == response2


+def test_gemma3n():
+    pt_engine = PtEngine('google/gemma-3n-E4B-it')
+    messages = [{'role': 'user', 'content': '<audio>Transcribe this audio and complete the statement'}]
+    audios = ['https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2-Audio/audio/guess_age_gender.wav']
+    response = _infer_model(pt_engine, messages=messages, audios=audios)
+    pt_engine.default_template.template_backend = 'jinja'
+    response2 = _infer_model(pt_engine, messages=messages, audios=audios)
+    assert response == response2
+
+
 if __name__ == '__main__':
     from swift.llm import PtEngine, RequestConfig
     from swift.utils import get_logger, seed_everything
@@ -73,4 +83,5 @@ def test_qwen2_5_omni():
     # test_qwen2_audio()
     # test_xcomposer2d5_ol()
     # test_step_audio_chat()
-    test_qwen2_5_omni()
+    # test_qwen2_5_omni()
+    test_gemma3n()

tests/test_align/test_template/test_vision.py

Lines changed: 15 additions & 1 deletion
@@ -578,6 +578,19 @@ def test_glm4_1v():
     assert response == response2


+def test_gemma3n():
+    pt_engine = PtEngine('google/gemma-3n-E2B-it')
+    messages = [{'role': 'user', 'content': '<image><image>What is the difference between the two images?'}]
+    images = [
+        'http://modelscope-open.oss-cn-hangzhou.aliyuncs.com/images/cat.png',
+        'http://modelscope-open.oss-cn-hangzhou.aliyuncs.com/images/animal.png'
+    ]
+    response = _infer_model(pt_engine, messages=messages, images=images)
+    pt_engine.default_template.template_backend = 'jinja'
+    response2 = _infer_model(pt_engine, messages=messages, images=images)
+    assert response == response2
+
+
 if __name__ == '__main__':
     from swift.llm import PtEngine, RequestConfig
     from swift.utils import get_logger, seed_everything
@@ -632,4 +645,5 @@ def test_glm4_1v():
     # test_internvl3_9b()
     # test_kimi_vl()
     # test_kimi_vl_thinking()
-    test_glm4_1v()
+    # test_glm4_1v()
+    test_gemma3n()
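
The two tests exercise image and audio input separately; since the template handles both placeholder types in one prompt, a combined request should also work through the public engine API. A hedged sketch (request/response field names are assumed from how PtEngine is used elsewhere in the test suite, not verified in this commit):

# Sketch, not part of this commit: mixed image + audio inference with gemma3n.
from swift.llm import InferRequest, PtEngine, RequestConfig

engine = PtEngine('google/gemma-3n-E2B-it')
request = InferRequest(
    messages=[{'role': 'user', 'content': '<image>Describe the image, then <audio>transcribe the audio.'}],
    images=['http://modelscope-open.oss-cn-hangzhou.aliyuncs.com/images/cat.png'],
    audios=['https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2-Audio/audio/guess_age_gender.wav'])
responses = engine.infer([request], RequestConfig(max_tokens=128, temperature=0))
print(responses[0].choices[0].message.content)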
