
Commit c14132c

[bugfix] fix megatron non-padding_free qwen3_vl cp (#7233)
1 parent 9213961 commit c14132c

File tree: 8 files changed, 18 additions & 25 deletions


docs/source/Instruction/Command-line-parameters.md

Lines changed: 1 addition & 1 deletion
@@ -466,7 +466,7 @@ Vera uses the three parameters `target_modules`, `target_regex`, and `modules_to_save`,
  - add_version: Add directory to output_dir with `'<version>-<timestamp>'` to prevent weight overwrite, default is True.
  - check_model: Check local model files for corruption or modification and give a prompt, default is True. **If in an offline environment, please set to False.**
  - 🔥create_checkpoint_symlink: Creates additional checkpoint symlinks to facilitate writing automated training scripts. The symlink paths for `best_model` and `last_model` are `f'{output_dir}/best'` and `f'{output_dir}/last'` respectively.
- - 🔥packing: Packs data samples of different lengths into samples of **approximately** uniform length (packing ensures that complete sequences are not split), achieving load balancing across nodes and processes during training (preventing long texts from slowing down short text training), thereby improving GPU utilization and maintaining stable memory usage. When using `--attn_impl flash_attn`, it ensures that different sequences within packed samples remain independent and invisible to each other. This parameter defaults to `False` and currently supports CPT/SFT/DPO/KTO/GKD. Note: **packing will reduce the number of dataset samples, please adjust gradient accumulation steps and learning rate accordingly**.
+ - 🔥packing: Use the `padding_free` method to pack data samples of different lengths into samples of **approximately** uniform length (packing ensures that complete sequences are not split), achieving load balancing across nodes and processes during training (preventing long texts from slowing down short text training), thereby improving GPU utilization and maintaining stable memory usage. When using `--attn_impl flash_attn`, it ensures that different sequences within packed samples remain independent and invisible to each other. This parameter defaults to `False` and currently supports CPT/SFT/DPO/KTO/GKD. Note: **packing will reduce the number of dataset samples, please adjust gradient accumulation steps and learning rate accordingly**.
  - "ms-swift>=3.12" has newly added support for packing in embedding/reranker/seq_cls tasks.
  - packing_length: the length to use for packing. Defaults to None, in which case it is set to max_length.
  - packing_num_proc: Number of processes for packing, default is 1. Note that different values of `packing_num_proc` will result in different packed datasets. (This parameter does not take effect during streaming packing). Usually there is no need to modify this value, as packing speed is much faster than tokenization speed.
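The `padding_free` packing described above can be illustrated with a minimal standalone sketch (this is not ms-swift's actual implementation; `pack_samples` is a hypothetical helper): samples are concatenated with no pad tokens, position_ids restart at each sample boundary, and the cumulative boundaries (cu_seqlens) are what lets a varlen flash-attention kernel keep the packed sequences mutually invisible.

```python
import torch

def pack_samples(samples):
    # Concatenate token ids end to end: no pad tokens anywhere.
    input_ids = torch.cat(samples)
    # position_ids restart at 0 for every sample, so the sequence
    # boundaries survive packing.
    position_ids = torch.cat([torch.arange(len(s)) for s in samples])
    # cu_seqlens marks the cumulative boundaries, e.g. [0, 3, 7]; a varlen
    # flash-attention kernel uses it to block cross-sequence attention.
    lengths = torch.tensor([len(s) for s in samples])
    cu_seqlens = torch.zeros(len(samples) + 1, dtype=torch.int32)
    cu_seqlens[1:] = torch.cumsum(lengths, dim=0)
    return input_ids, position_ids, cu_seqlens

ids, pos, cu = pack_samples([torch.tensor([11, 12, 13]),
                             torch.tensor([21, 22, 23, 24])])
print(ids.tolist())  # [11, 12, 13, 21, 22, 23, 24]
print(pos.tolist())  # [0, 1, 2, 0, 1, 2, 3]
print(cu.tolist())   # [0, 3, 7]
```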

docs/source/Megatron-SWIFT/Command-line-parameters.md

Lines changed: 1 addition & 1 deletion
@@ -300,7 +300,7 @@ Megatron training parameters are inherited from Megatron parameters and basic parameters (**shared with ms-swift
  - Note: The "learning rate" printed in the logs is the learning rate of the LLM.
  - aligner_lr: Specifies the learning rate for the aligner module in multimodal models. Default is `None`, same as `learning_rate`.
  - gradient_checkpointing_kwargs: Arguments passed to `torch.utils.checkpoint`. For example: set `--gradient_checkpointing_kwargs '{"use_reentrant": false}'`. Defaults to `None`. This parameter only takes effect when `vit_gradient_checkpointing` is enabled.
- - 🔥packing: Packs data samples of different lengths into samples of **approximately** uniform length (packing ensures that complete sequences are not split), achieving load balancing across nodes and processes during training (preventing long texts from slowing down short text training), thereby improving GPU utilization and maintaining stable memory usage. When using `--attention_backend flash`, it ensures that different sequences within packed samples remain independent and invisible to each other (except for Qwen3-Next, which contains linear-attention). This parameter defaults to `False`. All training tasks in Megatron-SWIFT support this parameter. Note: **packing will reduce the number of dataset samples, please adjust gradient accumulation steps and learning rate accordingly**.
+ - 🔥packing: Use the `padding_free` method to pack data samples of different lengths into samples of **approximately** uniform length (packing ensures that complete sequences are not split), achieving load balancing across nodes and processes during training (preventing long texts from slowing down short text training), thereby improving GPU utilization and maintaining stable memory usage. When using `--attention_backend flash`, it ensures that different sequences within packed samples remain independent and invisible to each other (except for Qwen3-Next, which contains linear-attention). This parameter defaults to `False`. All training tasks in Megatron-SWIFT support this parameter. Note: **packing will reduce the number of dataset samples, please adjust gradient accumulation steps and learning rate accordingly**.
  - packing_length: the length to use for packing. Defaults to None, in which case it is set to max_length.
  - packing_num_proc: Number of processes for packing, default is 1. Note that different values of `packing_num_proc` will result in different packed datasets. (This parameter does not take effect during streaming packing). Usually there is no need to modify this value, as packing speed is much faster than tokenization speed.
  - streaming: Stream data loading and processing, default is False. (The shuffling of streaming datasets is not thorough, which may lead to severe loss fluctuations.)

docs/source_en/Instruction/Command-line-parameters.md

Lines changed: 1 addition & 1 deletion
@@ -476,7 +476,7 @@ Training arguments include the [base arguments](#base-arguments), [Seq2SeqTraine
  - add_version: Add directory to output_dir with `'<version>-<timestamp>'` to prevent weight overwrite, default is True.
  - check_model: Check local model files for corruption or modification and give a prompt, default is True. **If in an offline environment, please set to False.**
  - 🔥create_checkpoint_symlink: Creates additional checkpoint symlinks to facilitate writing automated training scripts. The symlink paths for `best_model` and `last_model` are `f'{output_dir}/best'` and `f'{output_dir}/last'` respectively.
- - 🔥packing: Packs data samples of different lengths into samples of **approximately** uniform length (packing ensures that complete sequences are not split), achieving load balancing across nodes and processes during training (preventing long texts from slowing down short text training), thereby improving GPU utilization and maintaining stable memory usage. When using `--attn_impl flash_attn`, it ensures that different sequences within packed samples remain independent and invisible to each other. This parameter defaults to `False` and currently supports CPT/SFT/DPO/KTO/GKD. Note: **packing will reduce the number of dataset samples, please adjust gradient accumulation steps and learning rate accordingly**.
+ - 🔥packing: Use the `padding_free` method to pack data samples of different lengths into samples of **approximately** uniform length (packing ensures that complete sequences are not split), achieving load balancing across nodes and processes during training (preventing long texts from slowing down short text training), thereby improving GPU utilization and maintaining stable memory usage. When using `--attn_impl flash_attn`, it ensures that different sequences within packed samples remain independent and invisible to each other. This parameter defaults to `False` and currently supports CPT/SFT/DPO/KTO/GKD. Note: **packing will reduce the number of dataset samples, please adjust gradient accumulation steps and learning rate accordingly**.
  - "ms-swift>=3.12" has newly added support for packing in embedding/reranker/seq_cls tasks.
  - packing_length: the length to use for packing. Defaults to None, in which case it is set to max_length.
  - packing_num_proc: Number of processes for packing, default is 1. Note that different values of `packing_num_proc` will result in different packed datasets. (This parameter does not take effect during streaming packing). Usually there is no need to modify this value, as packing speed is much faster than tokenization speed.

docs/source_en/Megatron-SWIFT/Command-line-parameters.md

Lines changed: 1 addition & 1 deletion
@@ -319,7 +319,7 @@ Megatron training parameters are inherited from Megatron parameters and basic pa
  - Note: The "learning rate" printed in the logs is the learning rate of the LLM.
  - aligner_lr: Specifies the learning rate for the aligner module in multimodal models. Default is `None`, same as `learning_rate`.
  - gradient_checkpointing_kwargs: Arguments passed to `torch.utils.checkpoint`. For example: set `--gradient_checkpointing_kwargs '{"use_reentrant": false}'`. Defaults to `None`. This parameter only takes effect when `vit_gradient_checkpointing` is enabled.
- - 🔥packing: Packs data samples of different lengths into samples of **approximately** uniform length (packing ensures that complete sequences are not split), achieving load balancing across nodes and processes during training (preventing long texts from slowing down short text training), thereby improving GPU utilization and maintaining stable memory usage. When using `--attention_backend flash`, it ensures that different sequences within packed samples remain independent and invisible to each other (except for Qwen3-Next, which contains linear-attention). This parameter defaults to `False`. All training tasks in Megatron-SWIFT support this parameter. Note: **packing will reduce the number of dataset samples, please adjust gradient accumulation steps and learning rate accordingly**.
+ - 🔥packing: Use the `padding_free` method to pack data samples of different lengths into samples of **approximately** uniform length (packing ensures that complete sequences are not split), achieving load balancing across nodes and processes during training (preventing long texts from slowing down short text training), thereby improving GPU utilization and maintaining stable memory usage. When using `--attention_backend flash`, it ensures that different sequences within packed samples remain independent and invisible to each other (except for Qwen3-Next, which contains linear-attention). This parameter defaults to `False`. All training tasks in Megatron-SWIFT support this parameter. Note: **packing will reduce the number of dataset samples, please adjust gradient accumulation steps and learning rate accordingly**.
  - packing_length: the length to use for packing. Defaults to None, in which case it is set to max_length.
  - packing_num_proc: Number of processes for packing, default is 1. Note that different values of `packing_num_proc` will result in different packed datasets. (This parameter does not take effect during streaming packing). Usually there is no need to modify this value, as packing speed is much faster than tokenization speed.
  - streaming: Stream data loading and processing, default is False. (The shuffling of streaming datasets is not thorough, which may lead to severe loss fluctuations.)

swift/megatron/init.py

Lines changed: 1 addition & 7 deletions
@@ -670,10 +670,8 @@ def _write_item(self, *args, **kwargs):

  def _patch_mrope():
      from megatron.core.models.common.embeddings.rotary_pos_embedding import MultimodalRotaryEmbedding
-     from megatron.core import parallel_state
      import megatron.core
-     from megatron.core.models.common.embeddings.rope_utils import (get_pos_emb_on_this_cp_rank,
-                                                                    _apply_rotary_pos_emb_bshd)
+     from megatron.core.models.common.embeddings.rope_utils import _apply_rotary_pos_emb_bshd
      from megatron.core.models.common.embeddings import rope_utils
      from megatron.training import get_args

@@ -729,10 +727,6 @@ def forward(self, position_ids, mrope_section: List[int], packed_seq: bool = Fal

          # shape (seq_length, bs, 1, 2 * dim)
          emb = emb[..., None, :].transpose(0, 1).contiguous()
-         if parallel_state.get_context_parallel_world_size() > 1 and not packed_seq:
-             # slice rotary_pos_emb along sequence dimension and select the parition of the current
-             # CP rank
-             emb = get_pos_emb_on_this_cp_rank(emb, 0, parallel_state.get_context_parallel_group())
          return emb

      MultimodalRotaryEmbedding.forward = forward
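The deleted branch applied megatron-core's `get_pos_emb_on_this_cp_rank` to re-slice the rotary embedding per context-parallel rank whenever the batch was not packed. Presumably (an inference, not stated in the diff) this is safe to remove because ms-swift's `get_batch_on_this_cp_rank` now splits non-packed batches itself, so the position_ids reaching this `forward` are already rank-local and slicing the embedding again would discard valid positions. For reference, a standalone sketch of the zigzag slice such a helper performs (`zigzag_cp_slice` is a mock, not the mcore API):

```python
import torch

def zigzag_cp_slice(emb: torch.Tensor, cp_rank: int, cp_size: int, dim: int = 0):
    # Split the sequence dimension into 2 * cp_size chunks and keep chunk
    # cp_rank plus its mirror chunk (2 * cp_size - 1 - cp_rank), so that
    # under causal attention every rank gets one cheap and one expensive chunk.
    chunks = emb.chunk(2 * cp_size, dim=dim)
    return torch.cat([chunks[cp_rank], chunks[2 * cp_size - 1 - cp_rank]], dim=dim)

positions = torch.arange(8)  # seq_len = 8, cp_size = 2
print(zigzag_cp_slice(positions, cp_rank=0, cp_size=2).tolist())  # [0, 1, 6, 7]
print(zigzag_cp_slice(positions, cp_rank=1, cp_size=2).tolist())  # [2, 3, 4, 5]
```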

swift/megatron/model/mm_gpt/qwen3_vl.py

Lines changed: 3 additions & 3 deletions
@@ -122,12 +122,12 @@ def _get_inputs_embeds(inputs_embeds, inputs, visual, processor, config):
      # compat cp
      args = get_args()
      if args.context_parallel_size > 1:
-         assert packed_seq_params is not None
          device = visual_pos_masks.device
          cp_mask = torch.full(visual_pos_masks.shape[:1], -1, dtype=torch.long, device=device)
          cp_mask[visual_pos_masks[:, 0]] = torch.arange(visual_pos_masks.sum(), device=device)
-         cp_mask = split_cp_inputs(cp_mask, packed_seq_params.cu_seqlens_q, 0)
-         visual_pos_masks = split_cp_inputs(visual_pos_masks, packed_seq_params.cu_seqlens_q, 0)
+         cu_seqlens = getattr(packed_seq_params, 'cu_seqlens_q', None)
+         cp_mask = split_cp_inputs(cp_mask, cu_seqlens, 0)
+         visual_pos_masks = split_cp_inputs(visual_pos_masks, cu_seqlens, 0)
          deepstack_visual_embeds = deepstack_visual_embeds[:, cp_mask[(cp_mask != -1)]]
      # compat sp
      tp_world_size = parallel_state.get_tensor_model_parallel_world_size()

swift/megatron/model/mm_gpt_model.py

Lines changed: 1 addition & 1 deletion
@@ -67,7 +67,7 @@ def forward(_self, input_):
          kwargs.update(res)
          res = inputs_embeds
          if args.context_parallel_size > 1:
-             res = split_cp_inputs(res, packed_seq_params.cu_seqlens_q, 1)
+             res = split_cp_inputs(res, getattr(packed_seq_params, 'cu_seqlens_q', None), 1)
          if reduce_scatter_embeddings:
              res = res.transpose(0, 1).contiguous()
          group_kwargs = {'group': _self.tp_group} if mcore_013 else {}

swift/megatron/trainers/utils.py

Lines changed: 9 additions & 10 deletions
@@ -13,7 +13,6 @@
  from megatron.core.distributed import DistributedDataParallel as DDP
  from megatron.core.optimizer import ChainedOptimizer
  from megatron.core.packed_seq_params import PackedSeqParams
- from megatron.core.utils import get_batch_on_this_cp_rank as mcore_get_batch_on_this_cp_rank
  from megatron.training import get_args, get_wandb_writer
  from packaging import version

@@ -86,17 +85,19 @@ def get_packed_seq_params(position_ids: torch.Tensor) -> PackedSeqParams:
                               qkv_format='thd')


- def split_cp_inputs(inputs: torch.Tensor, cu_seqlens: torch.Tensor, dim: int):
-     # TODO: compat bshd
+ def split_cp_inputs(inputs: torch.Tensor, cu_seqlens: Optional[torch.Tensor], dim: int):
      if dim < 0:
          dim = (dim + inputs.ndim) % inputs.ndim
      new_inputs = []
      cp_size = mpu.get_context_parallel_world_size()
      cp_rank = mpu.get_context_parallel_rank()
-     for i in range(cu_seqlens.shape[0] - 1):
-         slices = [slice(None)] * inputs.ndim
-         slices[dim] = slice(cu_seqlens[i], cu_seqlens[i + 1])
-         val = inputs[tuple(slices)]
+     for i in range(1 if cu_seqlens is None else (cu_seqlens.shape[0] - 1)):
+         if cu_seqlens is None:
+             val = inputs
+         else:
+             slices = [slice(None)] * inputs.ndim
+             slices[dim] = slice(cu_seqlens[i], cu_seqlens[i + 1])
+             val = inputs[tuple(slices)]
          view_shape = (*inputs.shape[:dim], 2 * cp_size, val.shape[dim] // (2 * cp_size), *inputs.shape[dim + 1:])
          val = val.view(view_shape)
          index = torch.tensor([cp_rank, (2 * cp_size - cp_rank - 1)], device='cpu',

@@ -127,15 +128,13 @@ def get_batch_on_this_cp_rank(batch: Dict[str, Any]):
          keys.append('input_ids')

      packed_seq_params = batch.get('packed_seq_params')
-     if packed_seq_params is None:
-         return mcore_get_batch_on_this_cp_rank(batch)
      for key, val in batch.items():
          if key not in keys:
              continue
          if args.task_type == 'seq_cls' and key == 'labels':
              continue
          if val is not None:
-             batch[key] = split_cp_inputs(val, packed_seq_params.cu_seqlens_q, -1)
+             batch[key] = split_cp_inputs(val, getattr(packed_seq_params, 'cu_seqlens_q', None), -1)

      return batch
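This is the core of the fix: `split_cp_inputs` now accepts `cu_seqlens=None` (the non-padding_free case, where no `PackedSeqParams` exist) and treats the whole tensor as a single sequence, and `get_batch_on_this_cp_rank` no longer falls back to megatron-core when `packed_seq_params` is absent. A standalone sketch of both modes (`split_cp` is a simplified mock without the Megatron parallel-state dependencies):

```python
import torch

def split_cp(inputs, cu_seqlens, cp_rank, cp_size, dim=-1):
    # Mirror of the patched split_cp_inputs: each sequence is split into
    # 2 * cp_size chunks along `dim`; this rank keeps chunk cp_rank plus its
    # mirror chunk, balancing causal-attention cost across CP ranks.
    dim = dim % inputs.ndim
    bounds = [0, inputs.shape[dim]] if cu_seqlens is None else cu_seqlens.tolist()
    outs = []
    for lo, hi in zip(bounds[:-1], bounds[1:]):
        sl = [slice(None)] * inputs.ndim
        sl[dim] = slice(lo, hi)
        chunks = inputs[tuple(sl)].chunk(2 * cp_size, dim=dim)
        outs += [chunks[cp_rank], chunks[2 * cp_size - 1 - cp_rank]]
    return torch.cat(outs, dim=dim)

x = torch.arange(12).view(1, 12)
# padding_free: two packed sequences [0..7] and [8..11]
print(split_cp(x, torch.tensor([0, 8, 12]), cp_rank=0, cp_size=2))  # tensor([[0, 1, 6, 7, 8, 11]])
# non-padding_free (cu_seqlens=None): one sequence of length 12
print(split_cp(x, None, cp_rank=0, cp_size=2))  # tensor([[0, 1, 2, 9, 10, 11]])
```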
