Skip to content

Commit 7b8de20

Browse files
CrownStar7Jintao-Huang
authored and committed
[loss_scale] fix multimodal loss_scale when encountering <image>, <audio>, <video> tokens (#4922)
1 parent c733573 commit 7b8de20

File tree

14 files changed

+62
-2
lines changed

14 files changed

+62
-2
lines changed

swift/llm/template/base.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -309,6 +309,20 @@ def _extend_tokens(input_ids: List[int], labels: Optional[List[int]], replace_id
309309
added_tokens_len += token_len - 1
310310
return input_ids, labels
311311

312+
@staticmethod
313+
def _extend_loss_scale(loss_scale: Optional[List[float]], replace_idx_list: List[int],
314+
get_new_tokens: Callable[[int], List[int]]) -> Optional[List[float]]:
315+
if loss_scale:
316+
added_tokens_len = 0
317+
for i, idx in enumerate(replace_idx_list):
318+
new_tokens = get_new_tokens(i)
319+
token_len = len(new_tokens)
320+
scale_idx = loss_scale[idx + added_tokens_len]
321+
loss_scale = loss_scale[:idx + added_tokens_len] + [scale_idx] * token_len + loss_scale[added_tokens_len
322+
+ idx + 1:]
323+
added_tokens_len += token_len - 1
324+
return loss_scale
325+
312326
def forward_context(self, model, inputs):
313327
return nullcontext()
314328

swift/llm/template/template/emu3.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,7 @@ def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
165165
images = inputs.images
166166
input_ids = encoded['input_ids']
167167
labels = encoded['labels']
168+
loss_scale = encoded.get('loss_scale', None)
168169
image_tokens = self.processor.tokenize_image(images)
169170
image_prompts = []
170171
idx_list = findall(input_ids, self.tokenizer.encode(self.image_placeholder))
@@ -179,7 +180,8 @@ def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
179180

180181
# Insert image tokens into input_ids
181182
input_ids, labels = self._extend_tokens(input_ids, labels, idx_list, lambda i: image_prompts[i])
182-
return {'input_ids': input_ids, 'labels': labels}
183+
loss_scale = self._extend_loss_scale(loss_scale, idx_list, lambda i: image_prompts[i])
184+
return {'input_ids': input_ids, 'labels': labels, 'loss_scale': loss_scale}
183185

184186

185187
register_template(

swift/llm/template/template/gemma.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,9 +109,11 @@ def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
109109
if inputs.images:
110110
input_ids = encoded['input_ids']
111111
labels = encoded['labels']
112+
loss_scale = encoded.get('loss_scale', None)
112113
idx_list = findall(input_ids, self.boi_token_id)
113114
img_tokens = self._tokenize(self.processor.full_image_sequence)
114115
input_ids, labels = self._extend_tokens(input_ids, labels, idx_list, lambda _: img_tokens)
116+
loss_scale = self._extend_loss_scale(loss_scale, idx_list, lambda _: img_tokens)
115117

116118
# TODO: customize
117119
processor_kwargs = Gemma3ProcessorKwargs._defaults['images_kwargs']
@@ -126,6 +128,7 @@ def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
126128
encoded['input_ids'] = input_ids
127129
encoded['pixel_values'] = image_inputs['pixel_values']
128130
encoded['labels'] = labels
131+
encoded['loss_scale'] = loss_scale
129132
return encoded
130133

131134

@@ -158,6 +161,7 @@ def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
158161
processor = self.processor
159162
input_ids = encoded['input_ids']
160163
labels = encoded['labels']
164+
loss_scale = encoded.get('loss_scale', None)
161165

162166
# Initialize token_type_ids and other outputs
163167
array_ids = np.array(input_ids)
@@ -168,6 +172,7 @@ def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
168172
idx_list = findall(input_ids, self.boi_token_id)
169173
img_tokens = self._tokenize(processor.full_image_sequence)
170174
input_ids, labels = self._extend_tokens(input_ids, labels, idx_list, lambda _: img_tokens)
175+
loss_scale = self._extend_loss_scale(loss_scale, idx_list, lambda _: img_tokens)
171176

172177
# Process images
173178
processor_kwargs = Gemma3nProcessorKwargs._defaults.get('images_kwargs', {})
@@ -184,6 +189,7 @@ def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
184189
# Get audio token sequence from processor
185190
audio_tokens = self._tokenize(processor.full_audio_sequence)
186191
input_ids, labels = self._extend_tokens(input_ids, labels, audio_idx_list, lambda _: audio_tokens)
192+
loss_scale = self._extend_loss_scale(loss_scale, audio_idx_list, lambda _: audio_tokens)
187193

188194
# Process audios
189195
processor_kwargs = Gemma3nProcessorKwargs._defaults.get('audio_kwargs', {})
@@ -209,7 +215,7 @@ def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
209215
encoded['token_type_ids'] = mm_token_type_ids.tolist()
210216
encoded['input_ids'] = input_ids
211217
encoded['labels'] = labels
212-
218+
encoded['loss_scale'] = loss_scale
213219
return encoded
214220

215221
def _data_collator_mm_data(self, batch: List[Dict[str, Any]]) -> Dict[str, Any]:

swift/llm/template/template/internvl.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,7 @@ def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
121121
input_ids = encoded['input_ids']
122122
idx_list = findall(input_ids, -100)
123123
labels = encoded['labels']
124+
loss_scale = encoded.get('loss_scale', None)
124125
images = inputs.images
125126
if images:
126127
has_video = bool(inputs.videos)
@@ -146,6 +147,7 @@ def _get_new_tokens(i):
146147
return img_tokens
147148

148149
encoded['input_ids'], encoded['labels'] = self._extend_tokens(input_ids, labels, idx_list, _get_new_tokens)
150+
encoded['loss_scale'] = self._extend_loss_scale(loss_scale, idx_list, _get_new_tokens)
149151
encoded['pixel_values'] = pixel_values
150152
return encoded
151153

swift/llm/template/template/kwai.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,8 @@ def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
5454
processor = self.processor
5555
input_ids = encoded['input_ids']
5656
labels = encoded['labels']
57+
loss_scale = encoded.get('loss_scale', None)
58+
5759
images = inputs.images
5860
videos = inputs.videos
5961
for media_type in ['images', 'videos']:
@@ -82,10 +84,12 @@ def _get_new_tokens(i):
8284
return [media_token] * token_len
8385

8486
input_ids, labels = self._extend_tokens(input_ids, labels, idx_list, _get_new_tokens)
87+
loss_scale = self._extend_loss_scale(loss_scale, idx_list, _get_new_tokens)
8588
encoded.update(media_inputs)
8689

8790
encoded['input_ids'] = input_ids
8891
encoded['labels'] = labels
92+
encoded['loss_scale'] = loss_scale
8993
return encoded
9094

9195
def _post_encode(self, model, inputs: Dict[str, Any]) -> Dict[str, Any]:

swift/llm/template/template/llama.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,7 @@ def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
127127
if images:
128128
split_token = self._tokenize('\n')
129129
input_ids, labels = encoded['input_ids'], encoded['labels']
130+
loss_scale = encoded['loss_scale']
130131
idx_list = findall(input_ids, -100)
131132
media_inputs = self.processor(
132133
text='\n'.join(['<|image|>'] * len(idx_list)),
@@ -137,6 +138,7 @@ def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
137138

138139
encoded['input_ids'], encoded['labels'] = self._extend_tokens(input_ids, labels, idx_list,
139140
lambda i: splited_tokens[i])
141+
encoded['loss_scale'] = self._extend_loss_scale(loss_scale, idx_list, lambda i: splited_tokens[i])
140142
encoded['pixel_values'] = media_inputs['pixel_values']
141143
return encoded
142144

swift/llm/template/template/megrez.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
4040
encoded = super()._encode(inputs)
4141
input_ids = encoded['input_ids']
4242
labels = encoded['labels']
43+
loss_scale = encoded.get('loss_scale', None)
4344

4445
for mm_key in ['images', 'audios']:
4546
mm_data = getattr(inputs, mm_key)
@@ -70,8 +71,11 @@ def _get_new_tokens(i):
7071
return self._tokenize(padding[i])
7172

7273
input_ids, labels = self._extend_tokens(input_ids, labels, idx_list, _get_new_tokens)
74+
loss_scale = self._extend_loss_scale(loss_scale, idx_list, _get_new_tokens)
75+
7376
encoded['input_ids'] = input_ids
7477
encoded['labels'] = labels
78+
encoded['loss_scale'] = loss_scale
7579
return encoded
7680

7781
def _post_encode(self, model: nn.Module, inputs: Dict[str, Any]) -> Dict[str, Any]:

swift/llm/template/template/microsoft.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,7 @@ def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
169169
encoded = super()._encode(inputs)
170170
input_ids = encoded['input_ids']
171171
labels = encoded['labels']
172+
loss_scale = encoded.get('loss_scale', None)
172173
images_idx = findall(input_ids, -100)
173174
audios_idx = findall(input_ids, -200)
174175
text = '\n'.join(['<|image_1|>'] * len(inputs.images) + ['<|audio_1|>'] * len(inputs.audios))
@@ -181,6 +182,7 @@ def _get_new_tokens(i):
181182

182183
encoded['input_ids'], encoded['labels'] = self._extend_tokens(input_ids, labels, images_idx + audios_idx,
183184
_get_new_tokens)
185+
encoded['loss_scale'] = self._extend_loss_scale(loss_scale, images_idx + audios_idx, _get_new_tokens)
184186
new_encoded.pop('attention_mask')
185187
encoded.update(new_encoded)
186188
return encoded

swift/llm/template/template/minicpm.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,7 @@ def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
179179
use_image_id = False
180180
input_ids = encoded['input_ids']
181181
labels = encoded['labels']
182+
loss_scale = encoded.get('loss_scale', None)
182183
idx_list = findall(input_ids, -100)
183184

184185
image_processor = self.processor.image_processor
@@ -192,6 +193,8 @@ def _get_new_tokens(i):
192193
return self.processor.encode(placeholder, add_special_tokens=False)
193194

194195
input_ids, labels = self._extend_tokens(input_ids, labels, idx_list, _get_new_tokens)
196+
loss_scale = self._extend_loss_scale(loss_scale, idx_list, _get_new_tokens)
197+
195198
if inputs.images:
196199
input_tensor_ids = torch.tensor(input_ids)
197200
unk_token = self.processor.encode('<unk>', add_special_tokens=False)[0]
@@ -211,6 +214,7 @@ def _get_new_tokens(i):
211214
encoded = {
212215
'input_ids': input_ids,
213216
'labels': labels,
217+
'loss_scale': loss_scale,
214218
'image_bound': image_bound,
215219
'pixel_values': image_inputs['pixel_values'],
216220
'tgt_sizes': image_inputs['tgt_sizes']

swift/llm/template/template/mistral.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
2626
images = inputs.images
2727
input_ids = encoded['input_ids']
2828
labels = encoded['labels']
29+
loss_scale = encoded.get('loss_scale', None)
2930
idx_list = findall(input_ids, self.image_token)
3031
if idx_list:
3132
image_inputs = processor.image_processor(images, patch_size=processor.patch_size, return_tensors='pt')
@@ -45,6 +46,7 @@ def _get_new_tokens(i):
4546
return processor.encode(replace_str, add_special_tokens=False)
4647

4748
encoded['input_ids'], encoded['labels'] = self._extend_tokens(input_ids, labels, idx_list, _get_new_tokens)
49+
encoded['loss_scale'] = self._extend_loss_scale(loss_scale, idx_list, _get_new_tokens)
4850

4951
return encoded
5052

0 commit comments

Comments
 (0)