
Commit 3497f86
[train] support qwen2.5-omni mixed data (#5513)
1 parent cf130fa · commit 3497f86

3 files changed: +79 −5 lines

docs/source/Instruction/命令行参数.md

Lines changed: 2 additions & 2 deletions
@@ -101,7 +101,7 @@
   - 'last_round_with_ignore_empty_think': Based on `'last_round'`, additionally ignore the loss calculation for an empty `'<think>\n\n</think>\n\n'` block.
   - 'react', 'hermes', 'qwen': On top of `'default'`, set the loss weight of the `tool_call` part to 2.
 - sequence_parallel_size: Sequence parallelism size, default 1. Currently supports CPT/SFT/DPO/GRPO. For a training script, see [here](https://github.com/modelscope/ms-swift/tree/main/examples/train/sequence_parallel/ulysses/sequence_parallel.sh).
-- response_prefix: The prefix string of the response; for example, QwQ-32B sets response_prefix to `'<think>\n'`. Defaults to None and is set automatically according to the model.
+- response_prefix: The prefix string of the response; for example, QwQ-32B sets response_prefix to `'<think>\n'`. Defaults to None and is set automatically according to the model. (This parameter only takes effect during inference.)
   - Note: If you train a deepseek-r1/qwq model on a dataset that does not contain `<think>...</think>`, please additionally pass `--response_prefix ''` when running inference with the trained model.
 - template_backend: Selects the template backend; options are 'swift' and 'jinja', default 'swift'. If jinja is used, transformers' `apply_chat_template` is applied.
   - Note: The jinja template backend only supports inference, not training.
@@ -400,7 +400,7 @@ Vera uses the three parameters `target_modules`, `target_regex`, and `modules_to_save`.
   - Supported multimodal models: see https://github.com/modelscope/ms-swift/blob/main/examples/train/packing/qwen2_5_vl.sh. Note: please use "ms-swift>=3.6" and follow [this PR](https://github.com/modelscope/ms-swift/pull/4838).
 - packing_length: The length used for packing. Defaults to None, in which case it is set to max_length.
 - lazy_tokenize: Whether to use lazy tokenization. If set to False, all dataset samples are tokenized before training (for multimodal models, this includes reading images from disk). Defaults to False for LLM training and to True for MLLM training, to save memory.
-  - Note: If you want to perform image data augmentation, you need to set lazy_tokenize to True and modify the `encode` method of the Template class.
+  - Note: If you want to perform image data augmentation, you need to set lazy_tokenize (or streaming) to True and modify the `encode` method of the Template class.
 - cached_dataset: Use a cached dataset during training (produced with the `swift export --to_cached_dataset true ...` command) to avoid tokenization taking up GPU time when training on large datasets. Defaults to `[]`.
   - Note: cached_dataset supports `--packing`, but does not support `--lazy_tokenize` or `--streaming`.
 - use_logits_to_keep: Pass `logits_to_keep` in `forward` based on labels to reduce the computation and storage of unnecessary logits, thereby reducing memory usage and speeding up training. Defaults to None, which selects automatically.

docs/source_en/Instruction/Command-line-parameters.md

Lines changed: 2 additions & 2 deletions
@@ -102,7 +102,7 @@ Hints:
   - 'last_round_with_ignore_empty_think': Based on 'last_round', ignore the loss calculation for an empty `'<think>\n\n</think>\n\n'` block.
   - `'react'`, `'hermes'`, `'qwen'`: On top of `'default'`, set the loss weight of the `tool_call` part to 2.
 - sequence_parallel_size: Sequence parallelism size, default is 1. Currently supported in CPT/SFT/DPO/GRPO. For a training script, see [here](https://github.com/modelscope/ms-swift/tree/main/examples/train/sequence_parallel/ulysses/sequence_parallel.sh).
-- response_prefix: The prefix string of the response, for example, setting response_prefix to `'<think>\n'` for QwQ-32B. The default is None, and it is set automatically according to the model.
+- response_prefix: The prefix string of the response, for example, setting response_prefix to `'<think>\n'` for QwQ-32B. The default is None, and it is set automatically according to the model. (This parameter is only effective during inference.)
   - Note: If you train a deepseek-r1/qwq model with a dataset that does not include `<think>...</think>`, please additionally pass `--response_prefix ''` when inferring with the trained model.
 - template_backend: Selection of the template backend. Options are 'swift' and 'jinja', with 'swift' as the default. If jinja is used, transformers' `apply_chat_template` is applied.
   - Note: The jinja template backend supports only inference, not training.
@@ -409,7 +409,7 @@ Training arguments include the [base arguments](#base-arguments), [Seq2SeqTraine
   - Supported multimodal models reference: https://github.com/modelscope/ms-swift/blob/main/examples/train/packing/qwen2_5_vl.sh. Note: Please use "ms-swift>=3.6" and follow [this PR](https://github.com/modelscope/ms-swift/pull/4838).
 - packing_length: The length to use for packing. Defaults to None, in which case it is set to max_length.
 - lazy_tokenize: Whether to use lazy tokenization. If set to False, all dataset samples are tokenized before training (for multimodal models, this includes reading images from disk). This parameter defaults to False for LLM training, and True for MLLM training, to save memory.
-  - Note: If you want to perform image data augmentation, you need to set `lazy_tokenize` to True and modify the `encode` method in the Template class.
+  - Note: If you want to perform image data augmentation, you need to set `lazy_tokenize` (or `streaming`) to True and modify the `encode` method in the Template class.
 - cached_dataset: Use a cached dataset (generated with `swift export --to_cached_dataset true ...`) during training to avoid GPU time spent on tokenizing large datasets. Default: `[]`.
   - Note: cached_dataset supports `--packing` but does not support `--lazy_tokenize` or `--streaming`.
 - use_logits_to_keep: Pass `logits_to_keep` in the `forward` method based on labels to reduce the computation and storage of unnecessary logits, thereby reducing memory usage and accelerating training. The default is `None`, which enables automatic selection.
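The lazy_tokenize note changed above describes a workflow rather than an API: augmentation has to run each time a sample is encoded, which only happens per step when tokenization is lazy (or streaming). Below is a minimal sketch of such an augmentation hook. Only the PIL transforms are concrete; the subclass name, the `encode` override shown in the comment, and the `inputs.images` attribute are assumptions to be checked against the Template class of your installed ms-swift version.

```python
# Minimal sketch of on-the-fly image augmentation (assumption-laden; see above).
import random

from PIL import Image, ImageOps


def augment_image(img: Image.Image) -> Image.Image:
    """Apply a random, lightweight augmentation. With lazy_tokenize/streaming this
    runs every time a sample is encoded; with eager tokenization the result would
    be frozen into the pre-tokenized dataset and the augmentation would be static."""
    if random.random() < 0.5:
        img = ImageOps.mirror(img)  # horizontal flip
    if random.random() < 0.5:
        img = img.rotate(random.uniform(-10, 10))  # small random rotation
    return img


# Hypothetical placement inside a custom Template subclass (names illustrative only):
#
# class AugmentedTemplate(SomeQwenTemplate):
#     def encode(self, inputs, *args, **kwargs):
#         # `inputs.images` is assumed to hold PIL images; verify in your version.
#         inputs.images = [augment_image(img) for img in inputs.images]
#         return super().encode(inputs, *args, **kwargs)
```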

swift/llm/template/template/qwen.py

Lines changed: 75 additions & 1 deletion
@@ -9,6 +9,7 @@
 import transformers
 from packaging import version
 from torch import nn
+from transformers.integrations import is_deepspeed_zero3_enabled
 
 from swift.llm import get_packed_seq_params, to_device, to_float_dtype
 from swift.utils import get_env_args, is_deepspeed_enabled
@@ -592,7 +593,80 @@ def _get_new_tokens(i):
         return encoded
 
     def _post_encode(self, model, inputs: Dict[str, Any]) -> Dict[str, Any]:
-        return Template._post_encode(self, model, inputs)
+        if not self.is_training:
+            return inputs
+
+        input_ids = inputs['input_ids']
+        pixel_values = inputs.get('pixel_values')
+        pixel_values_videos = inputs.get('pixel_values_videos')
+        image_grid_thw = inputs.get('image_grid_thw')
+        video_grid_thw = inputs.get('video_grid_thw')
+        input_features = inputs.get('input_features')
+        feature_attention_mask = inputs.get('feature_attention_mask')
+
+        base_model = self.get_base_model(model)
+        inputs_embeds = base_model.thinker.model.embed_tokens(input_ids)
+        visual = model.thinker.visual
+        dtype = visual.dtype
+        thinker_config = model.config.thinker_config
+        if pixel_values is None and pixel_values_videos is None:  # plain-text
+            if is_deepspeed_enabled():
+                from PIL import Image
+                images = [Image.new('RGB', (32, 32), (0, 0, 0))]
+                media_inputs = self.processor.image_processor(images=images, return_tensors='pt')
+                device = input_ids.device
+                media_inputs = to_device(media_inputs, device)
+                pixel_values = media_inputs['pixel_values'].type(dtype)
+                image_embeds = visual(pixel_values, grid_thw=media_inputs['image_grid_thw'])
+                inputs_embeds = inputs_embeds + image_embeds.mean() * 0.
+        else:
+            if pixel_values is None:
+                pixel_values_mixed = pixel_values_videos
+                grid_thw = video_grid_thw
+            elif pixel_values_videos is None:
+                pixel_values_mixed = pixel_values
+                grid_thw = image_grid_thw
+            else:
+                pixel_values_mixed = torch.concat([pixel_values, pixel_values_videos], dim=0)
+                grid_thw = torch.concat([image_grid_thw, video_grid_thw], dim=0)
+            pixel_values_mixed = pixel_values_mixed.type(dtype)
+            mixed_embeds = visual(pixel_values_mixed, grid_thw=grid_thw)
+            if pixel_values is None:
+                image_embeds = None
+                video_embeds = mixed_embeds
+            elif pixel_values_videos is None:
+                image_embeds = mixed_embeds
+                video_embeds = None
+            else:
+                merge_length = self.processor.image_processor.merge_size**2
+                image_tokens = (image_grid_thw.prod(dim=-1) // merge_length).sum()
+                image_embeds = mixed_embeds[:image_tokens]
+                video_embeds = mixed_embeds[image_tokens:]
+
+            if image_embeds is not None:
+                image_mask = (input_ids == thinker_config.image_token_index).unsqueeze(-1).expand_as(inputs_embeds)
+                image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
+                inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
+
+            if video_embeds is not None:
+                video_mask = (input_ids == thinker_config.video_token_index).unsqueeze(-1).expand_as(inputs_embeds)
+                video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
+                inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)
+
+        if input_features is None:
+            if is_deepspeed_enabled() and not is_deepspeed_zero3_enabled():
+                # Note: ZeRO-3 still results in hangs; for audio training, please use ZeRO-2.
+                input_features = input_ids.new_zeros([1, 128, 128], dtype=dtype)
+                feature_attention_mask = input_ids.new_ones([1, 128], dtype=torch.bool)
+                audio_embeds = model.thinker.get_audio_features(input_features, feature_attention_mask)
+                inputs_embeds = inputs_embeds + audio_embeds.mean() * 0.
+        else:
+            audio_embeds = model.thinker.get_audio_features(input_features, feature_attention_mask)
+            audio_mask = (input_ids == thinker_config.audio_token_index).unsqueeze(-1).expand_as(inputs_embeds)
+            audio_embeds = audio_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
+            inputs_embeds = inputs_embeds.masked_scatter(audio_mask, audio_embeds)
+
+        return {'inputs_embeds': inputs_embeds}
 
     def _get_position_ids(self, inputs: Dict[str, Any]):
         feature_attention_mask = inputs.get('feature_attention_mask')
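The heart of the new `_post_encode` is splicing the vision (and audio) encoder outputs into the text embedding sequence with `masked_scatter`, keyed on the multimodal placeholder token ids. The snippet below is a self-contained PyTorch sketch of just that pattern; the token id is an arbitrary stand-in for `thinker_config.image_token_index`, not a value taken from the model config.

```python
# Standalone sketch of the masked_scatter merge used above (stand-in token id).
import torch

IMAGE_TOKEN_ID = 9999  # stand-in for thinker_config.image_token_index
hidden_size = 8

# A toy sequence with two image placeholder tokens at positions 1 and 2.
input_ids = torch.tensor([[101, IMAGE_TOKEN_ID, IMAGE_TOKEN_ID, 102, 103]])
inputs_embeds = torch.randn(1, 5, hidden_size)   # text-token embeddings
image_embeds = torch.randn(2, hidden_size)       # one row per placeholder token

# Build a boolean mask with the embedding shape, exactly as in the diff,
# then overwrite the masked rows with the visual features, in order.
image_mask = (input_ids == IMAGE_TOKEN_ID).unsqueeze(-1).expand_as(inputs_embeds)
merged = inputs_embeds.masked_scatter(image_mask, image_embeds.to(inputs_embeds.dtype))

assert torch.equal(merged[0, 1], image_embeds[0])
assert torch.equal(merged[0, 2], image_embeds[1])
```

The `image_embeds.mean() * 0.` and `audio_embeds.mean() * 0.` terms in the dummy-input branches serve a different purpose: they run the otherwise unused encoder and attach a zero-valued contribution to the output, which is the usual way to keep every rank producing gradients for those modules so that DeepSpeed's collectives do not hang when a batch contains no images or audio (the in-code comment notes this still does not suffice under ZeRO-3).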
