
Commit 1347cdd

[train] fix channel_loss (qwen2.5-vl & packing) (#4941)
1 parent b66f661 commit 1347cdd

3 files changed, +16 -4 lines

examples/train/plugins/channel_loss.sh

Lines changed: 2 additions & 0 deletions
@@ -8,6 +8,8 @@
 # {"role": "user", "content": "What color do you like?"},
 # {"role": "assistant", "content": "I like blue."}
 # ]}
+
+# channel_loss is compatible with padding-free and packing.
 CUDA_VISIBLE_DEVICES=0 \
 swift sft \
     --model Qwen/Qwen2.5-0.5B-Instruct \
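
For reference, a minimal sketch of what the full command might look like with packing turned on. The flag spellings are assumed to mirror the TrainArguments fields exercised in tests/train/test_channel.py (packing, attn_impl, channels, loss_type); the dataset path and max_length value are placeholders, not part of this commit:

# Sketch only: flags assumed from the TrainArguments fields used in
# tests/train/test_channel.py; dataset and max_length are placeholders.
CUDA_VISIBLE_DEVICES=0 \
swift sft \
    --model Qwen/Qwen2.5-0.5B-Instruct \
    --dataset channel.jsonl \
    --packing true \
    --attn_impl flash_attn \
    --max_length 2048 \
    --channels aaa abc \
    --loss_type channel_loss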

swift/trainers/trainers.py

Lines changed: 9 additions & 3 deletions
@@ -237,6 +237,7 @@ def prediction_step(
         **gen_kwargs,
     ) -> Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
         if not self.args.predict_with_generate or prediction_loss_only:
+            inputs['_position_ids'] = inputs.get('position_ids')
             with self.template.forward_context(self.model, inputs):
                 return super().prediction_step(
                     model, inputs, prediction_loss_only=prediction_loss_only, ignore_keys=ignore_keys)

@@ -277,15 +278,19 @@ def _prepare_inputs(self, inputs):
             compute_loss_func = get_loss_func('loss_scale')

         sample_channels = inputs.pop('channel', None)
-        if sample_channels is not None and self.args.channels is not None:
+        position_ids = inputs.pop('_position_ids', None)
+        if self.args.channels is not None:
+            assert sample_channels is not None, f'sample_channels: {sample_channels}'
             state = self.state
             setattr(state, 'local_step', getattr(state, 'local_step', 0))
             setattr(state, 'ch_loss_steps', getattr(state, 'ch_loss_steps', {}))

             loss_kwargs['sample_channels'] = sample_channels
             loss_kwargs['trainer'] = self
-            if inputs.get('position_ids') is not None:
-                loss_kwargs['position_ids'] = inputs['position_ids']
+            if position_ids is None:
+                position_ids = inputs.get('position_ids')
+            if position_ids is not None:
+                loss_kwargs['position_ids'] = position_ids

         use_logits_to_keep = self.get_use_logits_to_keep('labels' in inputs and self.label_smoother is None
                                                          and compute_loss_func is None)

@@ -352,5 +357,6 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N
         return (loss, outputs) if return_outputs else loss

     def training_step(self, model, inputs, *args, **kwargs):
+        inputs['_position_ids'] = inputs.get('position_ids')
         with self.template.forward_context(self.model, inputs):
             return super().training_step(model, inputs, *args, **kwargs)
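
The change stashes the incoming position_ids into inputs['_position_ids'] before template.forward_context runs, and _prepare_inputs then pops the stashed copy and prefers it when filling loss_kwargs['position_ids']. The likely motivation, an assumption consistent with the commit title, is that the forward context for packed/padding-free batches (and for qwen2.5-vl in particular) may rewrite position_ids, while the channel loss still needs the original packed positions to attribute each token to the sample it came from. Below is a minimal sketch of how packed position_ids identify sample boundaries; split_packed_samples is a hypothetical helper, not part of this repository:

import torch

# Hypothetical helper: in a packed batch each sample restarts its
# position_ids at 0, so sample boundaries are the indices where
# position_ids == 0. Returns [start, end) spans per sample.
def split_packed_samples(position_ids: torch.Tensor):
    flat = position_ids.flatten()
    starts = torch.nonzero(flat == 0, as_tuple=True)[0].tolist()
    ends = starts[1:] + [flat.numel()]
    return list(zip(starts, ends))

# Example: three samples of lengths 4, 3 and 5 packed into one row.
packed = torch.tensor([[0, 1, 2, 3, 0, 1, 2, 0, 1, 2, 3, 4]])
print(split_packed_samples(packed))  # [(0, 4), (4, 7), (7, 12)]

With those spans, a channel-wise loss can sum per-token losses inside each span and bucket the result under that sample's channel.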

tests/train/test_channel.py

Lines changed: 5 additions & 1 deletion
@@ -10,8 +10,12 @@ def test_channel():
         model='Qwen/Qwen2.5-VL-7B-Instruct',
         dataset=['channel.jsonl#1000'],
         split_dataset_ratio=0.01,
+        packing=True,
+        max_length=128,
         channels=['aaa', 'abc'],
-        loss_type='channel_loss'))
+        attn_impl='flash_attn',
+        loss_type='channel_loss',
+        eval_steps=10))


 if __name__ == '__main__':
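
For context, a rough sketch of what one line of channel.jsonl could look like. The per-sample 'channel' key is implied by inputs.pop('channel', None) in the trainer and by channels=['aaa', 'abc'] above, and the messages reuse the example from examples/train/plugins/channel_loss.sh; the exact field layout is an assumption, not taken from this commit:

import json

# Sketch only: 'channel' must be one of the configured channel names;
# the messages mirror the comment in channel_loss.sh.
sample = {
    'channel': 'aaa',
    'messages': [
        {'role': 'user', 'content': 'What color do you like?'},
        {'role': 'assistant', 'content': 'I like blue.'},
    ],
}
with open('channel.jsonl', 'a', encoding='utf-8') as f:
    f.write(json.dumps(sample, ensure_ascii=False) + '\n')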
