 from ..register import TemplateMeta, register_template
 from ..template_inputs import StdTemplateInputs
 from ..utils import Context, Prompt, findall
+from ..vision_utils import load_audio


 @dataclass
@@ -129,3 +130,102 @@ def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]: |


 register_template(GemmaTemplateMeta(MLLMTemplateType.gemma3_vision, template_cls=Gemma3VisionTemplate))
+
+
+class Gemma3nTemplate(Gemma3Template):
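+    """Template for Gemma 3n multimodal models.
+
+    Expands the '<start_of_image>' / '<start_of_audio>' placeholders into the
+    processor's full image/audio token sequences and prepares the pixel values
+    and audio features expected by the model.
+    """
+
+    # Vocabulary ids of the '<start_of_image>' and '<start_of_audio>' placeholder tokens.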
+    boi_token_id = 255999
+    boa_token_id = 256000
+    placeholder_tokens = ['<start_of_image>', '<start_of_audio>']
+
+    def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index: int,
+                    inputs: StdTemplateInputs) -> List[Context]:
+        if media_type == 'image':
+            return ['<start_of_image>']
+        elif media_type == 'audio':
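+            # Decode and resample the audio to the feature extractor's sampling rate up front,
+            # so _encode can pass the raw arrays straight to the feature extractor.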
+            inputs.audios[index] = load_audio(inputs.audios[index], self.processor.feature_extractor.sampling_rate)
+            return ['<start_of_audio>']
+        else:
+            raise ValueError(f'Unsupported media type: {media_type}. Supported types are: image, audio')
+
+    def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
+        from transformers.models.gemma3n.processing_gemma3n import Gemma3nProcessorKwargs
+
+        # Input validation
+        if not inputs.images and not inputs.audios and not inputs.messages:
+            raise ValueError('Provide at least one of `images`, `audios`, or `messages`.')
+
+        encoded = super()._encode(inputs)
+        processor = self.processor
+        input_ids = encoded['input_ids']
+        labels = encoded['labels']
+
+        # Initialize token_type_ids and other outputs
+        array_ids = np.array(input_ids)
+        mm_token_type_ids = np.zeros_like(input_ids)
+
+        # Handle images
+        if inputs.images:
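+            # Expand each '<start_of_image>' placeholder into the processor's full image token sequence.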
+            idx_list = findall(input_ids, self.boi_token_id)
+            img_tokens = self._tokenize(processor.full_image_sequence)
+            input_ids, labels = self._extend_tokens(input_ids, labels, idx_list, lambda _: img_tokens)
+
+            # Process images
+            processor_kwargs = Gemma3nProcessorKwargs._defaults.get('images_kwargs', {})
+            image_inputs = processor.image_processor(inputs.images, **processor_kwargs)
+            image_inputs['pixel_values'] = torch.as_tensor(np.array(image_inputs['pixel_values']))
+            if 'num_crops' in image_inputs:
+                image_inputs.pop('num_crops')
+            encoded.update(image_inputs)
+
+        # Handle audios
+        if inputs.audios:
+            audio_idx_list = findall(input_ids, self.boa_token_id)
+            if audio_idx_list:
+                # Get audio token sequence from processor
+                audio_tokens = self._tokenize(processor.full_audio_sequence)
+                input_ids, labels = self._extend_tokens(input_ids, labels, audio_idx_list, lambda _: audio_tokens)
+
+                # Process audios
+                processor_kwargs = Gemma3nProcessorKwargs._defaults.get('audio_kwargs', {})
+                audio_inputs = processor.feature_extractor(inputs.audios, **processor_kwargs)
+
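+                # Convert the extracted features to tensors; input_features are cast to the model's dtype.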
+                if 'input_features' in audio_inputs:
+                    audio_inputs['input_features'] = torch.tensor(audio_inputs['input_features']).to(
+                        self.model_info.torch_dtype)
+                if 'input_features_mask' in audio_inputs:
+                    audio_inputs['input_features_mask'] = torch.tensor(audio_inputs['input_features_mask'])
+                encoded.update(audio_inputs)
+
+        # Update array_ids after token extension
+        array_ids = np.array(input_ids)
+        mm_token_type_ids = np.zeros_like(input_ids)
+
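+        # Mark multimodal positions in token_type_ids: 1 for image tokens, 3 for audio tokens, 0 for text.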
+        if hasattr(processor, 'image_token_id') and processor.image_token_id is not None:
+            mm_token_type_ids[array_ids == processor.image_token_id] = 1
+
+        if hasattr(processor, 'audio_token_id') and processor.audio_token_id is not None:
+            mm_token_type_ids[array_ids == processor.audio_token_id] = 3
+
+        encoded['token_type_ids'] = mm_token_type_ids.tolist()
+        encoded['input_ids'] = input_ids
+        encoded['labels'] = labels
+
+        return encoded
+
+    def _data_collator_mm_data(self, batch: List[Dict[str, Any]]) -> Dict[str, Any]:
+        """Handle multimodal data collation for Gemma3n, including audio features."""
+        res = super()._data_collator_mm_data(batch)
+
+        # Collate audio features across the batch; other multimodal keys are handled by the parent implementation.
+        input_features = [b['input_features'] for b in batch if b.get('input_features') is not None]
+        input_features_mask = [b['input_features_mask'] for b in batch if b.get('input_features_mask') is not None]
+
+        if input_features:
+            res['input_features'] = torch.concat(input_features)
+        if input_features_mask:
+            res['input_features_mask'] = torch.concat(input_features_mask)
+
+        return res
+
+
+register_template(GemmaTemplateMeta(MLLMTemplateType.gemma3n, template_cls=Gemma3nTemplate))