diff --git a/swift/llm/template/base.py b/swift/llm/template/base.py
index 70a06f924e..60d365f189 100644
--- a/swift/llm/template/base.py
+++ b/swift/llm/template/base.py
@@ -309,6 +309,20 @@ def _extend_tokens(input_ids: List[int], labels: Optional[List[int]], replace_id
             added_tokens_len += token_len - 1
         return input_ids, labels
 
+    @staticmethod
+    def _extend_loss_scale(loss_scale: Optional[List[float]], replace_idx_list: List[int],
+                           get_new_tokens: Callable[[int], List[int]]) -> Optional[List[float]]:
+        if loss_scale:
+            added_tokens_len = 0
+            for i, idx in enumerate(replace_idx_list):
+                new_tokens = get_new_tokens(i)
+                token_len = len(new_tokens)
+                scale_idx = loss_scale[idx + added_tokens_len]
+                loss_scale = loss_scale[:idx + added_tokens_len] + [scale_idx] * token_len + loss_scale[added_tokens_len
+                                                                                                        + idx + 1:]
+                added_tokens_len += token_len - 1
+        return loss_scale
+
     def forward_context(self, model, inputs):
         return nullcontext()
 
diff --git a/swift/llm/template/template/emu3.py b/swift/llm/template/template/emu3.py
index fb0ca1d42f..debfab7abc 100644
--- a/swift/llm/template/template/emu3.py
+++ b/swift/llm/template/template/emu3.py
@@ -165,6 +165,7 @@ def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
         images = inputs.images
         input_ids = encoded['input_ids']
         labels = encoded['labels']
+        loss_scale = encoded.get('loss_scale', None)
         image_tokens = self.processor.tokenize_image(images)
         image_prompts = []
         idx_list = findall(input_ids, self.tokenizer.encode(self.image_placeholder))
@@ -179,7 +180,8 @@ def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
 
         # Insert image tokens into input_ids
         input_ids, labels = self._extend_tokens(input_ids, labels, idx_list, lambda i: image_prompts[i])
-        return {'input_ids': input_ids, 'labels': labels}
+        loss_scale = self._extend_loss_scale(loss_scale, idx_list, lambda i: image_prompts[i])
+        return {'input_ids': input_ids, 'labels': labels, 'loss_scale': loss_scale}
 
 
 register_template(
diff --git a/swift/llm/template/template/gemma.py b/swift/llm/template/template/gemma.py
index 5e8689f576..bc4358039d 100644
--- a/swift/llm/template/template/gemma.py
+++ b/swift/llm/template/template/gemma.py
@@ -109,9 +109,11 @@ def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
         if inputs.images:
             input_ids = encoded['input_ids']
             labels = encoded['labels']
+            loss_scale = encoded.get('loss_scale', None)
             idx_list = findall(input_ids, self.boi_token_id)
             img_tokens = self._tokenize(self.processor.full_image_sequence)
             input_ids, labels = self._extend_tokens(input_ids, labels, idx_list, lambda _: img_tokens)
+            loss_scale = self._extend_loss_scale(loss_scale, idx_list, lambda _: img_tokens)
 
             # TODO: customize
             processor_kwargs = Gemma3ProcessorKwargs._defaults['images_kwargs']
@@ -126,6 +128,7 @@ def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
             encoded['input_ids'] = input_ids
             encoded['pixel_values'] = image_inputs['pixel_values']
             encoded['labels'] = labels
+            encoded['loss_scale'] = loss_scale
 
         return encoded
 
@@ -158,6 +161,7 @@ def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
         processor = self.processor
         input_ids = encoded['input_ids']
         labels = encoded['labels']
+        loss_scale = encoded.get('loss_scale', None)
 
         # Initialize token_type_ids and other outputs
         array_ids = np.array(input_ids)
@@ -168,6 +172,7 @@ def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
             idx_list = findall(input_ids, self.boi_token_id)
             img_tokens = self._tokenize(processor.full_image_sequence)
             input_ids, labels = self._extend_tokens(input_ids, labels, idx_list, lambda _: img_tokens)
+            loss_scale = self._extend_loss_scale(loss_scale, idx_list, lambda _: img_tokens)
 
             # Process images
             processor_kwargs = Gemma3nProcessorKwargs._defaults.get('images_kwargs', {})
@@ -184,6 +189,7 @@ def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
             # Get audio token sequence from processor
             audio_tokens = self._tokenize(processor.full_audio_sequence)
             input_ids, labels = self._extend_tokens(input_ids, labels, audio_idx_list, lambda _: audio_tokens)
+            loss_scale = self._extend_loss_scale(loss_scale, audio_idx_list, lambda _: audio_tokens)
 
             # Process audios
             processor_kwargs = Gemma3nProcessorKwargs._defaults.get('audio_kwargs', {})
@@ -209,7 +215,7 @@ def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
         encoded['token_type_ids'] = mm_token_type_ids.tolist()
         encoded['input_ids'] = input_ids
         encoded['labels'] = labels
-
+        encoded['loss_scale'] = loss_scale
         return encoded
 
     def _data_collator_mm_data(self, batch: List[Dict[str, Any]]) -> Dict[str, Any]:
diff --git a/swift/llm/template/template/internvl.py b/swift/llm/template/template/internvl.py
index c0e3876571..0ea8d58d02 100644
--- a/swift/llm/template/template/internvl.py
+++ b/swift/llm/template/template/internvl.py
@@ -121,6 +121,7 @@ def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
         input_ids = encoded['input_ids']
         idx_list = findall(input_ids, -100)
         labels = encoded['labels']
+        loss_scale = encoded.get('loss_scale', None)
         images = inputs.images
         if images:
             has_video = bool(inputs.videos)
@@ -146,6 +147,7 @@ def _get_new_tokens(i):
                 return img_tokens
 
             encoded['input_ids'], encoded['labels'] = self._extend_tokens(input_ids, labels, idx_list, _get_new_tokens)
+            encoded['loss_scale'] = self._extend_loss_scale(loss_scale, idx_list, _get_new_tokens)
             encoded['pixel_values'] = pixel_values
 
         return encoded
diff --git a/swift/llm/template/template/kwai.py b/swift/llm/template/template/kwai.py
index 7ed199fcaa..7d8d5d1255 100644
--- a/swift/llm/template/template/kwai.py
+++ b/swift/llm/template/template/kwai.py
@@ -54,6 +54,8 @@ def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
         processor = self.processor
         input_ids = encoded['input_ids']
         labels = encoded['labels']
+        loss_scale = encoded.get('loss_scale', None)
+
         images = inputs.images
         videos = inputs.videos
         for media_type in ['images', 'videos']:
@@ -82,10 +84,12 @@ def _get_new_tokens(i):
                     return [media_token] * token_len
 
                 input_ids, labels = self._extend_tokens(input_ids, labels, idx_list, _get_new_tokens)
+                loss_scale = self._extend_loss_scale(loss_scale, idx_list, _get_new_tokens)
                 encoded.update(media_inputs)
 
         encoded['input_ids'] = input_ids
         encoded['labels'] = labels
+        encoded['loss_scale'] = loss_scale
         return encoded
 
     def _post_encode(self, model, inputs: Dict[str, Any]) -> Dict[str, Any]:
diff --git a/swift/llm/template/template/llama.py b/swift/llm/template/template/llama.py
index b39fa79e58..4c248a3515 100644
--- a/swift/llm/template/template/llama.py
+++ b/swift/llm/template/template/llama.py
@@ -127,6 +127,7 @@ def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
         if images:
             split_token = self._tokenize('\n')
             input_ids, labels = encoded['input_ids'], encoded['labels']
+            loss_scale = encoded['loss_scale']
             idx_list = findall(input_ids, -100)
             media_inputs = self.processor(
                 text='\n'.join(['<|image|>'] * len(idx_list)),
@@ -137,6 +138,7 @@ def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
 
             encoded['input_ids'], encoded['labels'] = self._extend_tokens(input_ids, labels, idx_list,
                                                                           lambda i: splited_tokens[i])
+            encoded['loss_scale'] = self._extend_loss_scale(loss_scale, idx_list, lambda i: splited_tokens[i])
             encoded['pixel_values'] = media_inputs['pixel_values']
         return encoded
 
diff --git a/swift/llm/template/template/megrez.py b/swift/llm/template/template/megrez.py
index 91b89e7406..9e5e8301ef 100644
--- a/swift/llm/template/template/megrez.py
+++ b/swift/llm/template/template/megrez.py
@@ -40,6 +40,7 @@ def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
         encoded = super()._encode(inputs)
         input_ids = encoded['input_ids']
         labels = encoded['labels']
+        loss_scale = encoded.get('loss_scale', None)
 
         for mm_key in ['images', 'audios']:
             mm_data = getattr(inputs, mm_key)
@@ -70,8 +71,11 @@ def _get_new_tokens(i):
                     return self._tokenize(padding[i])
 
                 input_ids, labels = self._extend_tokens(input_ids, labels, idx_list, _get_new_tokens)
+                loss_scale = self._extend_loss_scale(loss_scale, idx_list, _get_new_tokens)
+
         encoded['input_ids'] = input_ids
         encoded['labels'] = labels
+        encoded['loss_scale'] = loss_scale
         return encoded
 
     def _post_encode(self, model: nn.Module, inputs: Dict[str, Any]) -> Dict[str, Any]:
diff --git a/swift/llm/template/template/microsoft.py b/swift/llm/template/template/microsoft.py
index 047df80366..a0ceeb1f26 100644
--- a/swift/llm/template/template/microsoft.py
+++ b/swift/llm/template/template/microsoft.py
@@ -169,6 +169,7 @@ def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
         encoded = super()._encode(inputs)
         input_ids = encoded['input_ids']
         labels = encoded['labels']
+        loss_scale = encoded.get('loss_scale', None)
         images_idx = findall(input_ids, -100)
         audios_idx = findall(input_ids, -200)
         text = '\n'.join(['<|image_1|>'] * len(inputs.images) + ['<|audio_1|>'] * len(inputs.audios))
@@ -181,6 +182,7 @@ def _get_new_tokens(i):
 
         encoded['input_ids'], encoded['labels'] = self._extend_tokens(input_ids, labels, images_idx + audios_idx,
                                                                       _get_new_tokens)
+        encoded['loss_scale'] = self._extend_loss_scale(loss_scale, images_idx + audios_idx, _get_new_tokens)
         new_encoded.pop('attention_mask')
         encoded.update(new_encoded)
         return encoded
diff --git a/swift/llm/template/template/minicpm.py b/swift/llm/template/template/minicpm.py
index 88e9566730..eb90e6532e 100644
--- a/swift/llm/template/template/minicpm.py
+++ b/swift/llm/template/template/minicpm.py
@@ -179,6 +179,7 @@ def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
             use_image_id = False
         input_ids = encoded['input_ids']
         labels = encoded['labels']
+        loss_scale = encoded.get('loss_scale', None)
         idx_list = findall(input_ids, -100)
 
         image_processor = self.processor.image_processor
@@ -192,6 +193,8 @@ def _get_new_tokens(i):
             return self.processor.encode(placeholder, add_special_tokens=False)
 
         input_ids, labels = self._extend_tokens(input_ids, labels, idx_list, _get_new_tokens)
+        loss_scale = self._extend_loss_scale(loss_scale, idx_list, _get_new_tokens)
+
         if inputs.images:
             input_tensor_ids = torch.tensor(input_ids)
             unk_token = self.processor.encode('<unk>', add_special_tokens=False)[0]
@@ -211,6 +214,7 @@ def _get_new_tokens(i):
         encoded = {
             'input_ids': input_ids,
             'labels': labels,
+            'loss_scale': loss_scale,
             'image_bound': image_bound,
             'pixel_values': image_inputs['pixel_values'],
             'tgt_sizes': image_inputs['tgt_sizes']
diff --git a/swift/llm/template/template/mistral.py b/swift/llm/template/template/mistral.py
index cbea49d34d..e85873fcf5 100644
--- a/swift/llm/template/template/mistral.py
+++ b/swift/llm/template/template/mistral.py
@@ -26,6 +26,7 @@ def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
         images = inputs.images
         input_ids = encoded['input_ids']
         labels = encoded['labels']
+        loss_scale = encoded.get('loss_scale', None)
         idx_list = findall(input_ids, self.image_token)
         if idx_list:
             image_inputs = processor.image_processor(images, patch_size=processor.patch_size, return_tensors='pt')
@@ -45,6 +46,7 @@ def _get_new_tokens(i):
                 return processor.encode(replace_str, add_special_tokens=False)
 
             encoded['input_ids'], encoded['labels'] = self._extend_tokens(input_ids, labels, idx_list, _get_new_tokens)
+            encoded['loss_scale'] = self._extend_loss_scale(loss_scale, idx_list, _get_new_tokens)
         return encoded
 
diff --git a/swift/llm/template/template/moonshot.py b/swift/llm/template/template/moonshot.py
index 770ab6179d..c531ab8c86 100644
--- a/swift/llm/template/template/moonshot.py
+++ b/swift/llm/template/template/moonshot.py
@@ -37,6 +37,7 @@ def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
         encoded = super()._encode(inputs)
         input_ids = encoded['input_ids']
         labels = encoded['labels']
+        loss_scale = encoded.get('loss_scale', None)
         media_token = self._tokenize('<|media_pad|>')[0]
         idx_list = findall(input_ids, media_token)
         if inputs.images:
@@ -50,6 +51,9 @@ def _get_new_tokens(i):
                 return [media_token] * token_len
 
             input_ids, labels = self._extend_tokens(input_ids, labels, idx_list, _get_new_tokens)
+            loss_scale = self._extend_loss_scale(loss_scale, idx_list, _get_new_tokens)
+
+        encoded['loss_scale'] = loss_scale
         encoded['input_ids'] = input_ids
         encoded['labels'] = labels
         encoded.update(image_inputs)
diff --git a/swift/llm/template/template/mplug.py b/swift/llm/template/template/mplug.py
index ace1ebbf61..c6fcbb0048 100644
--- a/swift/llm/template/template/mplug.py
+++ b/swift/llm/template/template/mplug.py
@@ -91,6 +91,7 @@ def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
         cut_enable = not videos
         input_ids = encoded['input_ids']
         labels = encoded['labels']
+        loss_scale = encoded.get('loss_scale', None)
         idx_list = findall(input_ids, -100)
         processor = self.processor
         encoded = {}
@@ -108,6 +109,8 @@ def _get_new_tokens(i):
             return token_list
 
         input_ids, labels = self._extend_tokens(input_ids, labels, idx_list, _get_new_tokens)
+        loss_scale = self._extend_loss_scale(loss_scale, idx_list, _get_new_tokens)
+
         image_token_idx = torch.tensor(findall(input_ids, image_token_list))
         if self.version == '241101':
             media_offset = image_token_idx
@@ -121,6 +124,7 @@ def _get_new_tokens(i):
             })
         encoded['input_ids'] = input_ids
         encoded['labels'] = labels
+        encoded['loss_scale'] = loss_scale
         return encoded
 
     def _post_encode(self, model: nn.Module, inputs: Dict[str, Any]) -> Dict[str, Any]:
diff --git a/swift/llm/template/template/pixtral.py b/swift/llm/template/template/pixtral.py
index 5a8acf7e7d..8585ce5753 100644
--- a/swift/llm/template/template/pixtral.py
+++ b/swift/llm/template/template/pixtral.py
@@ -18,6 +18,7 @@ def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
         images = inputs.images
         input_ids = encoded['input_ids']
         labels = encoded['labels']
+        loss_scale = encoded.get('loss_scale', None)
         idx_list = findall(input_ids, 10)
         if idx_list:
             image_inputs = processor.image_processor(images, patch_size=processor.patch_size, return_tensors='pt')
@@ -37,6 +38,7 @@ def _get_new_tokens(i):
                 return img_tokens
 
             encoded['input_ids'], encoded['labels'] = self._extend_tokens(input_ids, labels, idx_list, _get_new_tokens)
+            encoded['loss_scale'] = self._extend_loss_scale(loss_scale, idx_list, _get_new_tokens)
         return encoded
 
diff --git a/swift/llm/template/template/qwen.py b/swift/llm/template/template/qwen.py
index 52695a2f90..7b0958fc38 100644
--- a/swift/llm/template/template/qwen.py
+++ b/swift/llm/template/template/qwen.py
@@ -265,6 +265,7 @@ def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
         processor = self.processor
         input_ids = encoded['input_ids']
         labels = encoded['labels']
+        loss_scale = encoded.get('loss_scale', None)
         images = inputs.images
         videos = inputs.videos
         for media_type in ['images', 'videos']:
@@ -295,10 +296,12 @@ def _get_new_tokens(i):
                     return [media_token] * token_len
 
                 input_ids, labels = self._extend_tokens(input_ids, labels, idx_list, _get_new_tokens)
+                loss_scale = self._extend_loss_scale(loss_scale, idx_list, _get_new_tokens)
                 encoded.update(media_inputs)
 
         encoded['input_ids'] = input_ids
         encoded['labels'] = labels
+        encoded['loss_scale'] = loss_scale
         return encoded
 
     def forward_context(self, model, inputs):
@@ -492,6 +495,7 @@ def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
         media_inputs = to_float_dtype(media_inputs, self.model_info.torch_dtype)
         input_ids = encoded['input_ids']
         labels = encoded['labels']
+        loss_scale = encoded.get('loss_scale', None)
         # audio
         audio_token_id = self._tokenize('<|AUDIO|>')
         idx_list = findall(input_ids, audio_token_id)
@@ -510,6 +514,7 @@ def _get_new_audio_tokens(i):
                 return audio_token_id * audio_lengths[i]
 
             input_ids, labels = self._extend_tokens(input_ids, labels, idx_list, _get_new_audio_tokens)
+            loss_scale = self._extend_loss_scale(loss_scale, idx_list, _get_new_audio_tokens)
 
         for media_type in ['image', 'video']:
             token = f'<|{media_type.upper()}|>'
@@ -548,6 +553,7 @@ def _get_new_tokens_use_audio_in_video(i):
 
                     input_ids, labels = self._extend_tokens(input_ids, labels, idx_list,
                                                             _get_new_tokens_use_audio_in_video)
+                    loss_scale = self._extend_loss_scale(loss_scale, idx_list, _get_new_tokens_use_audio_in_video)
 
                 else:
@@ -556,9 +562,11 @@ def _get_new_tokens(i):
                         return token_id * token_len
 
                     input_ids, labels = self._extend_tokens(input_ids, labels, idx_list, _get_new_tokens)
+                    loss_scale = self._extend_loss_scale(loss_scale, idx_list, _get_new_tokens)
 
         encoded['input_ids'] = input_ids
         encoded['labels'] = labels
+        encoded['loss_scale'] = loss_scale
         encoded.update(media_inputs)
         return encoded
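
Below is a minimal standalone sketch of the `_extend_loss_scale` splice logic added to `base.py`, for reviewers who want to sanity-check the index arithmetic outside the `Template` class. The function body mirrors the patch; the helper name, example token id, and scale values are made up for illustration.

```python
from typing import Callable, List, Optional


def extend_loss_scale(loss_scale: Optional[List[float]], replace_idx_list: List[int],
                      get_new_tokens: Callable[[int], List[int]]) -> Optional[List[float]]:
    """Mirror of Template._extend_loss_scale: keeps loss_scale aligned with
    input_ids after each placeholder token is expanded into multiple tokens."""
    if loss_scale:
        added_tokens_len = 0
        for i, idx in enumerate(replace_idx_list):
            token_len = len(get_new_tokens(i))
            # Every inserted token inherits the scale of the placeholder it replaces.
            scale = loss_scale[idx + added_tokens_len]
            loss_scale = (loss_scale[:idx + added_tokens_len] + [scale] * token_len
                          + loss_scale[idx + added_tokens_len + 1:])
            # Track how far subsequent placeholder indices have shifted right.
            added_tokens_len += token_len - 1
    return loss_scale


# Hypothetical 4-token sequence whose placeholder at index 1 expands into
# 3 image tokens (the token id 151655 is illustrative only):
print(extend_loss_scale([1.0, 0.0, 1.0, 1.0], [1], lambda i: [151655] * 3))
# -> [1.0, 0.0, 0.0, 0.0, 1.0, 1.0]  (length now matches the expanded input_ids)
```

Note that the `if loss_scale:` guard lets `None` (and an empty list) pass through unchanged, which is why every call site in this patch can read `encoded.get('loss_scale', None)` and invoke the helper unconditionally, matching the existing `_extend_tokens` call right above it.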