Skip to content

[loss_scale] fix loss_scale when meeting <image>,<audio>,<video> #4922

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jul 14, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions swift/llm/template/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,6 +309,20 @@ def _extend_tokens(input_ids: List[int], labels: Optional[List[int]], replace_id
added_tokens_len += token_len - 1
return input_ids, labels

@staticmethod
def _extend_loss_scale(loss_scale: Optional[List[float]], replace_idx_list: List[int],
get_new_tokens: Callable[[int], List[int]]) -> Optional[List[float]]:
if loss_scale:
added_tokens_len = 0
for i, idx in enumerate(replace_idx_list):
new_tokens = get_new_tokens(i)
token_len = len(new_tokens)
scale_idx = loss_scale[idx + added_tokens_len]
loss_scale = loss_scale[:idx + added_tokens_len] + [scale_idx] * token_len + loss_scale[added_tokens_len
+ idx + 1:]
added_tokens_len += token_len - 1
return loss_scale

def forward_context(self, model, inputs):
    """Return a fresh ``nullcontext()``.

    This implementation ignores both ``model`` and ``inputs``.
    """
    no_op_ctx = nullcontext()
    return no_op_ctx

Expand Down
4 changes: 3 additions & 1 deletion swift/llm/template/template/emu3.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,7 @@ def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
images = inputs.images
input_ids = encoded['input_ids']
labels = encoded['labels']
loss_scale = encoded.get('loss_scale', None)
image_tokens = self.processor.tokenize_image(images)
image_prompts = []
idx_list = findall(input_ids, self.tokenizer.encode(self.image_placeholder))
Expand All @@ -179,7 +180,8 @@ def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:

# Insert image tokens into input_ids
input_ids, labels = self._extend_tokens(input_ids, labels, idx_list, lambda i: image_prompts[i])
return {'input_ids': input_ids, 'labels': labels}
loss_scale = self._extend_loss_scale(loss_scale, idx_list, lambda i: image_prompts[i])
return {'input_ids': input_ids, 'labels': labels, 'loss_scale': loss_scale}


register_template(
Expand Down
8 changes: 7 additions & 1 deletion swift/llm/template/template/gemma.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,9 +109,11 @@ def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
if inputs.images:
input_ids = encoded['input_ids']
labels = encoded['labels']
loss_scale = encoded.get('loss_scale', None)
idx_list = findall(input_ids, self.boi_token_id)
img_tokens = self._tokenize(self.processor.full_image_sequence)
input_ids, labels = self._extend_tokens(input_ids, labels, idx_list, lambda _: img_tokens)
loss_scale = self._extend_loss_scale(loss_scale, idx_list, lambda _: img_tokens)

# TODO: customize
processor_kwargs = Gemma3ProcessorKwargs._defaults['images_kwargs']
Expand All @@ -126,6 +128,7 @@ def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
encoded['input_ids'] = input_ids
encoded['pixel_values'] = image_inputs['pixel_values']
encoded['labels'] = labels
encoded['loss_scale'] = loss_scale
return encoded


Expand Down Expand Up @@ -158,6 +161,7 @@ def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
processor = self.processor
input_ids = encoded['input_ids']
labels = encoded['labels']
loss_scale = encoded.get('loss_scale', None)

# Initialize token_type_ids and other outputs
array_ids = np.array(input_ids)
Expand All @@ -168,6 +172,7 @@ def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
idx_list = findall(input_ids, self.boi_token_id)
img_tokens = self._tokenize(processor.full_image_sequence)
input_ids, labels = self._extend_tokens(input_ids, labels, idx_list, lambda _: img_tokens)
loss_scale = self._extend_loss_scale(loss_scale, idx_list, lambda _: img_tokens)

# Process images
processor_kwargs = Gemma3nProcessorKwargs._defaults.get('images_kwargs', {})
Expand All @@ -184,6 +189,7 @@ def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
# Get audio token sequence from processor
audio_tokens = self._tokenize(processor.full_audio_sequence)
input_ids, labels = self._extend_tokens(input_ids, labels, audio_idx_list, lambda _: audio_tokens)
loss_scale = self._extend_loss_scale(loss_scale, audio_idx_list, lambda _: audio_tokens)

# Process audios
processor_kwargs = Gemma3nProcessorKwargs._defaults.get('audio_kwargs', {})
Expand All @@ -209,7 +215,7 @@ def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
encoded['token_type_ids'] = mm_token_type_ids.tolist()
encoded['input_ids'] = input_ids
encoded['labels'] = labels

encoded['loss_scale'] = loss_scale
return encoded

def _data_collator_mm_data(self, batch: List[Dict[str, Any]]) -> Dict[str, Any]:
Expand Down
2 changes: 2 additions & 0 deletions swift/llm/template/template/internvl.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@ def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
input_ids = encoded['input_ids']
idx_list = findall(input_ids, -100)
labels = encoded['labels']
loss_scale = encoded.get('loss_scale', None)
images = inputs.images
if images:
has_video = bool(inputs.videos)
Expand All @@ -146,6 +147,7 @@ def _get_new_tokens(i):
return img_tokens

encoded['input_ids'], encoded['labels'] = self._extend_tokens(input_ids, labels, idx_list, _get_new_tokens)
encoded['loss_scale'] = self._extend_loss_scale(loss_scale, idx_list, _get_new_tokens)
encoded['pixel_values'] = pixel_values
return encoded

Expand Down
4 changes: 4 additions & 0 deletions swift/llm/template/template/kwai.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,8 @@ def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
processor = self.processor
input_ids = encoded['input_ids']
labels = encoded['labels']
loss_scale = encoded.get('loss_scale', None)

images = inputs.images
videos = inputs.videos
for media_type in ['images', 'videos']:
Expand Down Expand Up @@ -82,10 +84,12 @@ def _get_new_tokens(i):
return [media_token] * token_len

input_ids, labels = self._extend_tokens(input_ids, labels, idx_list, _get_new_tokens)
loss_scale = self._extend_loss_scale(loss_scale, idx_list, _get_new_tokens)
encoded.update(media_inputs)

encoded['input_ids'] = input_ids
encoded['labels'] = labels
encoded['loss_scale'] = loss_scale
return encoded

def _post_encode(self, model, inputs: Dict[str, Any]) -> Dict[str, Any]:
Expand Down
2 changes: 2 additions & 0 deletions swift/llm/template/template/llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
if images:
split_token = self._tokenize('\n')
input_ids, labels = encoded['input_ids'], encoded['labels']
loss_scale = encoded['loss_scale']
idx_list = findall(input_ids, -100)
media_inputs = self.processor(
text='\n'.join(['<|image|>'] * len(idx_list)),
Expand All @@ -137,6 +138,7 @@ def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:

encoded['input_ids'], encoded['labels'] = self._extend_tokens(input_ids, labels, idx_list,
lambda i: splited_tokens[i])
encoded['loss_scale'] = self._extend_loss_scale(loss_scale, idx_list, lambda i: splited_tokens[i])
encoded['pixel_values'] = media_inputs['pixel_values']
return encoded

Expand Down
4 changes: 4 additions & 0 deletions swift/llm/template/template/megrez.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
encoded = super()._encode(inputs)
input_ids = encoded['input_ids']
labels = encoded['labels']
loss_scale = encoded.get('loss_scale', None)

for mm_key in ['images', 'audios']:
mm_data = getattr(inputs, mm_key)
Expand Down Expand Up @@ -70,8 +71,11 @@ def _get_new_tokens(i):
return self._tokenize(padding[i])

input_ids, labels = self._extend_tokens(input_ids, labels, idx_list, _get_new_tokens)
loss_scale = self._extend_loss_scale(loss_scale, idx_list, _get_new_tokens)

encoded['input_ids'] = input_ids
encoded['labels'] = labels
encoded['loss_scale'] = loss_scale
return encoded

def _post_encode(self, model: nn.Module, inputs: Dict[str, Any]) -> Dict[str, Any]:
Expand Down
2 changes: 2 additions & 0 deletions swift/llm/template/template/microsoft.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,7 @@ def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
encoded = super()._encode(inputs)
input_ids = encoded['input_ids']
labels = encoded['labels']
loss_scale = encoded.get('loss_scale', None)
images_idx = findall(input_ids, -100)
audios_idx = findall(input_ids, -200)
text = '\n'.join(['<|image_1|>'] * len(inputs.images) + ['<|audio_1|>'] * len(inputs.audios))
Expand All @@ -181,6 +182,7 @@ def _get_new_tokens(i):

encoded['input_ids'], encoded['labels'] = self._extend_tokens(input_ids, labels, images_idx + audios_idx,
_get_new_tokens)
encoded['loss_scale'] = self._extend_loss_scale(loss_scale, images_idx + audios_idx, _get_new_tokens)
new_encoded.pop('attention_mask')
encoded.update(new_encoded)
return encoded
Expand Down
4 changes: 4 additions & 0 deletions swift/llm/template/template/minicpm.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,7 @@ def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
use_image_id = False
input_ids = encoded['input_ids']
labels = encoded['labels']
loss_scale = encoded.get('loss_scale', None)
idx_list = findall(input_ids, -100)

image_processor = self.processor.image_processor
Expand All @@ -192,6 +193,8 @@ def _get_new_tokens(i):
return self.processor.encode(placeholder, add_special_tokens=False)

input_ids, labels = self._extend_tokens(input_ids, labels, idx_list, _get_new_tokens)
loss_scale = self._extend_loss_scale(loss_scale, idx_list, _get_new_tokens)

if inputs.images:
input_tensor_ids = torch.tensor(input_ids)
unk_token = self.processor.encode('<unk>', add_special_tokens=False)[0]
Expand All @@ -211,6 +214,7 @@ def _get_new_tokens(i):
encoded = {
'input_ids': input_ids,
'labels': labels,
'loss_scale': loss_scale,
'image_bound': image_bound,
'pixel_values': image_inputs['pixel_values'],
'tgt_sizes': image_inputs['tgt_sizes']
Expand Down
2 changes: 2 additions & 0 deletions swift/llm/template/template/mistral.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
images = inputs.images
input_ids = encoded['input_ids']
labels = encoded['labels']
loss_scale = encoded.get('loss_scale', None)
idx_list = findall(input_ids, self.image_token)
if idx_list:
image_inputs = processor.image_processor(images, patch_size=processor.patch_size, return_tensors='pt')
Expand All @@ -45,6 +46,7 @@ def _get_new_tokens(i):
return processor.encode(replace_str, add_special_tokens=False)

encoded['input_ids'], encoded['labels'] = self._extend_tokens(input_ids, labels, idx_list, _get_new_tokens)
encoded['loss_scale'] = self._extend_loss_scale(loss_scale, idx_list, _get_new_tokens)

return encoded

Expand Down
4 changes: 4 additions & 0 deletions swift/llm/template/template/moonshot.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
encoded = super()._encode(inputs)
input_ids = encoded['input_ids']
labels = encoded['labels']
loss_scale = encoded.get('loss_scale', None)
media_token = self._tokenize('<|media_pad|>')[0]
idx_list = findall(input_ids, media_token)
if inputs.images:
Expand All @@ -50,6 +51,9 @@ def _get_new_tokens(i):
return [media_token] * token_len

input_ids, labels = self._extend_tokens(input_ids, labels, idx_list, _get_new_tokens)
loss_scale = self._extend_loss_scale(loss_scale, idx_list, _get_new_tokens)

encoded['loss_scale'] = loss_scale
encoded['input_ids'] = input_ids
encoded['labels'] = labels
encoded.update(image_inputs)
Expand Down
4 changes: 4 additions & 0 deletions swift/llm/template/template/mplug.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
cut_enable = not videos
input_ids = encoded['input_ids']
labels = encoded['labels']
loss_scale = encoded.get('loss_scale', None)
idx_list = findall(input_ids, -100)
processor = self.processor
encoded = {}
Expand All @@ -108,6 +109,8 @@ def _get_new_tokens(i):
return token_list

input_ids, labels = self._extend_tokens(input_ids, labels, idx_list, _get_new_tokens)
loss_scale = self._extend_loss_scale(loss_scale, idx_list, _get_new_tokens)

image_token_idx = torch.tensor(findall(input_ids, image_token_list))
if self.version == '241101':
media_offset = image_token_idx
Expand All @@ -121,6 +124,7 @@ def _get_new_tokens(i):
})
encoded['input_ids'] = input_ids
encoded['labels'] = labels
encoded['loss_scale'] = loss_scale
return encoded

def _post_encode(self, model: nn.Module, inputs: Dict[str, Any]) -> Dict[str, Any]:
Expand Down
2 changes: 2 additions & 0 deletions swift/llm/template/template/pixtral.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
images = inputs.images
input_ids = encoded['input_ids']
labels = encoded['labels']
loss_scale = encoded.get('loss_scale', None)
idx_list = findall(input_ids, 10)
if idx_list:
image_inputs = processor.image_processor(images, patch_size=processor.patch_size, return_tensors='pt')
Expand All @@ -37,6 +38,7 @@ def _get_new_tokens(i):
return img_tokens

encoded['input_ids'], encoded['labels'] = self._extend_tokens(input_ids, labels, idx_list, _get_new_tokens)
encoded['loss_scale'] = self._extend_loss_scale(loss_scale, idx_list, _get_new_tokens)

return encoded

Expand Down
8 changes: 8 additions & 0 deletions swift/llm/template/template/qwen.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,7 @@ def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
processor = self.processor
input_ids = encoded['input_ids']
labels = encoded['labels']
loss_scale = encoded.get('loss_scale', None)
images = inputs.images
videos = inputs.videos
for media_type in ['images', 'videos']:
Expand Down Expand Up @@ -295,10 +296,12 @@ def _get_new_tokens(i):
return [media_token] * token_len

input_ids, labels = self._extend_tokens(input_ids, labels, idx_list, _get_new_tokens)
loss_scale = self._extend_loss_scale(loss_scale, idx_list, _get_new_tokens)
encoded.update(media_inputs)

encoded['input_ids'] = input_ids
encoded['labels'] = labels
encoded['loss_scale'] = loss_scale
return encoded

def forward_context(self, model, inputs):
Expand Down Expand Up @@ -492,6 +495,7 @@ def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
media_inputs = to_float_dtype(media_inputs, self.model_info.torch_dtype)
input_ids = encoded['input_ids']
labels = encoded['labels']
loss_scale = encoded.get('loss_scale', None)
# audio
audio_token_id = self._tokenize('<|AUDIO|>')
idx_list = findall(input_ids, audio_token_id)
Expand All @@ -510,6 +514,7 @@ def _get_new_audio_tokens(i):
return audio_token_id * audio_lengths[i]

input_ids, labels = self._extend_tokens(input_ids, labels, idx_list, _get_new_audio_tokens)
loss_scale = self._extend_loss_scale(loss_scale, idx_list, _get_new_audio_tokens)

for media_type in ['image', 'video']:
token = f'<|{media_type.upper()}|>'
Expand Down Expand Up @@ -548,6 +553,7 @@ def _get_new_tokens_use_audio_in_video(i):

input_ids, labels = self._extend_tokens(input_ids, labels, idx_list,
_get_new_tokens_use_audio_in_video)
loss_scale = self._extend_loss_scale(loss_scale, idx_list, _get_new_tokens_use_audio_in_video)

else:

Expand All @@ -556,9 +562,11 @@ def _get_new_tokens(i):
return token_id * token_len

input_ids, labels = self._extend_tokens(input_ids, labels, idx_list, _get_new_tokens)
loss_scale = self._extend_loss_scale(loss_scale, idx_list, _get_new_tokens)

encoded['input_ids'] = input_ids
encoded['labels'] = labels
encoded['loss_scale'] = loss_scale
encoded.update(media_inputs)
return encoded

Expand Down