Commit a354329

[megatron] use batched mrope (#6281)
1 parent 13253e9 commit a354329

3 files changed: +24 -19 lines changed


docs/source/Customization/自定义数据集.md

Lines changed: 2 additions & 2 deletions
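@@ -215,8 +215,8 @@ alpaca format:
 {"messages": [{"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": "<image>Help me open Google Chrome"}, {"role": "assistant", "content": "Action: click(start_box='<bbox>')"}], "images": ["/xxx/x.jpg"], "objects": {"ref": [], "bbox": [[615, 226]]}}
 ```
 This format automatically converts the dataset to the corresponding model's grounding task format and selects that model's bbox normalization method. Compared with the general format, it adds an objects field, which contains the following fields:
-- ref: Used to replace `<ref-object>`. The length of ref must match the number of `<ref-object>` placeholders.
-- bbox: Used to replace `<bbox>`. If each box in bbox has length 2, it represents x and y coordinates; if a box has length 4, it represents the x and y coordinates of two points. The length of bbox must match the number of `<bbox>` placeholders.
+- ref: Used to replace `<ref-object>` in messages. The length of ref must match the number of `<ref-object>` placeholders.
+- bbox: Used to replace `<bbox>` in messages. If each box in bbox has length 2, it represents x and y coordinates; if a box has length 4, it represents the x and y coordinates of two points. The length of bbox must match the number of `<bbox>` placeholders.
 - Note: `<ref-object>` and `<bbox>` have no correspondence with each other; ref and bbox each replace their own placeholders.
 - bbox_type: Options are 'real' and 'norm1'. Defaults to 'real', meaning bbox holds the real bbox values. If 'norm1', the bbox is already normalized to 0~1.
 - image_id: Typically used for multi-image grounding tasks. This parameter takes effect only when bbox_type is 'real'. It indicates which image each bbox belongs to and is used to scale the bbox. Indexing starts at 0, and all entries default to image 0. The number of image_id entries must match the number of bbox entries. For example, if bbox has length 10 and images has length 2, then image_id must have length 10, with values in the set `{0, 1}`.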
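
To make the `bbox_type` and `image_id` description concrete, here is a minimal sketch of how 'real' pixel coordinates could be scaled into the 0~1 range using the size of the image each bbox belongs to. This is not the ms-swift implementation: `normalize_bboxes` and the `image_sizes` argument are hypothetical, and the normalization a model actually receives is model-specific.

```python
from typing import List, Sequence, Tuple


def normalize_bboxes(
    bboxes: Sequence[Sequence[float]],
    image_sizes: Sequence[Tuple[int, int]],  # hypothetical: (width, height) per image
    image_ids: Sequence[int] = (),
) -> List[List[float]]:
    """Scale 'real' pixel coordinates to 0~1 using each bbox's own image."""
    # Default described in the docs: every bbox belongs to image 0.
    ids = list(image_ids) if image_ids else [0] * len(bboxes)
    if len(ids) != len(bboxes):
        raise ValueError('image_id must have one entry per bbox')
    normalized = []
    for box, image_index in zip(bboxes, ids):
        width, height = image_sizes[image_index]
        # Values alternate x, y: 2 values are one point, 4 values are two points.
        normalized.append([
            value / (width if i % 2 == 0 else height)
            for i, value in enumerate(box)
        ])
    return normalized


if __name__ == '__main__':
    # Two images of different sizes; the second bbox refers to image 1.
    print(normalize_bboxes([[100, 50], [10, 20, 30, 40]],
                           [(200, 100), (100, 200)], [0, 1]))
```

Passing no `image_ids` mirrors the documented default in which all bboxes are scaled against image 0.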

docs/source_en/Customization/Custom-dataset.md

Lines changed: 2 additions & 2 deletions
@@ -230,8 +230,8 @@ When using this type of data, please note:
 
 The format will automatically convert the dataset format to the corresponding model's grounding task format and select the appropriate model's bbox normalization method. Compared to the general format, this format includes an additional "objects" field, which contains the following subfields:
 
-- ref: Used to replace `<ref-object>`. The length of `ref` should match the number of `<ref-object>` instances.
-- bbox: Used to replace `<bbox>`. If the length of each box in the bbox is 2, it represents the x and y coordinates. If the box length is 4, it represents the x and y coordinates of two points. The length of `bbox` should match the number of `<bbox>` instances.
+- ref: Used to replace the `<ref-object>` placeholder in messages. The length of `ref` should match the number of `<ref-object>` instances.
+- bbox: Used to replace the `<bbox>` placeholder in messages. If the length of each box in the bbox is 2, it represents the x and y coordinates. If the box length is 4, it represents the x and y coordinates of two points. The length of `bbox` should match the number of `<bbox>` instances.
 - Note: `<ref-object>` and `<bbox>` do not have a corresponding relationship; references and bounding boxes replace their own placeholders separately.
 - bbox_type: Optional values are 'real' and 'norm1'. The default is 'real', meaning the bbox represents the actual bounding box value. If set to 'norm1', the bbox is normalized to the range 0~1.
 - image_id: Typically used for multi-image grounding tasks. This parameter only takes effect when bbox_type is 'real', representing which image the bbox corresponds to, used for scaling the bbox. The index starts from 0, and defaults to all being the 0th image. The length of image_id needs to be consistent with the length of bbox. For example: if the length of bbox is 10 and the length of images is 2, then the length of image_id needs to be 10, with values within the set `{0, 1}`.
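
The doc change stresses that the placeholders live in `messages` and that `ref` and `bbox` are consumed independently. The following sketch illustrates that substitution order. It is only an illustration: `fill_placeholders` is a hypothetical helper, and rendering a box with `str()` is an assumption, since the real output format depends on the target model.

```python
import copy
import re
from typing import Any, Dict


def fill_placeholders(sample: Dict[str, Any]) -> Dict[str, Any]:
    """Replace <ref-object> and <bbox> in message contents, left to right."""
    refs = iter(sample['objects'].get('ref', []))
    boxes = iter(sample['objects'].get('bbox', []))
    result = copy.deepcopy(sample)
    for message in result['messages']:
        content = message['content']
        # Each placeholder kind draws from its own list; the two are not paired.
        content = re.sub(r'<ref-object>', lambda _: next(refs), content)
        # Assumption: str() stands in for the model-specific bbox format.
        content = re.sub(r'<bbox>', lambda _: str(next(boxes)), content)
        message['content'] = content
    return result


if __name__ == '__main__':
    sample = {
        'messages': [
            {'role': 'user', 'content': '<image>Find <ref-object>'},
            {'role': 'assistant', 'content': '<bbox>'},
        ],
        'images': ['/xxx/x.jpg'],
        'objects': {'ref': ['the dog'], 'bbox': [[10, 20, 30, 40]]},
    }
    for message in fill_placeholders(sample)['messages']:
        print(message['content'])
```

If a list is shorter than the number of its placeholders, `next()` raises `StopIteration`, which matches the documented requirement that the lengths agree.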

swift/megatron/init.py

Lines changed: 20 additions & 15 deletions
@@ -1,5 +1,6 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
 import concurrent.futures
+import logging
 import os
 import subprocess
 import sys
@@ -619,10 +620,17 @@ def _apply_rotary_pos_emb_thd(
             Tensor: Shape [t, h, d]. The input tensor after applying RoPE.
         """
         args = get_args()
-        if args.position_embedding_type != 'mrope':
+        cu_seqlens_for_batched = cu_seqlens
+        use_batched_mrope = False
+        if cp_group is not None:
+            cp_size = cp_group.size()
+            cu_seqlens_for_batched = cu_seqlens // cp_size
+            use_batched_mrope = (freqs.dim() >= 1 and freqs.shape[0] == cu_seqlens_for_batched[-1]).item()
+        if args.position_embedding_type != 'mrope' and not use_batched_mrope:
+            logger.warning_once('Using non-batched RoPE, which may affect performance.')
             return _origin_apply_rotary_pos_emb_thd(
                 t,
-                cu_seqlens,
+                cu_seqlens_for_batched,
                 freqs,
                 rotary_interleaved=rotary_interleaved,
                 multi_latent_attention=multi_latent_attention,
@@ -632,24 +640,20 @@ def _apply_rotary_pos_emb_thd(
 
         if cp_group is None:
             raise ValueError('cp_group must be provided for THD format RoPE')
-        cp_size = cp_group.size()
-        cu_seqlens = cu_seqlens // cp_size
-        seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
-
-        return torch.cat([
-            _apply_rotary_pos_emb_bshd(
-                x.unsqueeze(1),
-                f,
-                rotary_interleaved=rotary_interleaved,
-                multi_latent_attention=multi_latent_attention,
-                mscale=mscale,
-            ) for x, f in zip(torch.split(t, seqlens), torch.split(freqs, seqlens))
-        ]).squeeze(1)
+
+        return _apply_rotary_pos_emb_bshd(
+            t.unsqueeze(1),
+            freqs,
+            rotary_interleaved=rotary_interleaved,
+            multi_latent_attention=multi_latent_attention,
+            mscale=mscale,
+        ).squeeze(1)
 
     rope_utils._apply_rotary_pos_emb_thd = _apply_rotary_pos_emb_thd
 
 
 def _patch_megatron():
+    logging_level = logging.root.level
     _patch_flash_attn()
     _patch_transformer_engine()
     _patch_TELinear()
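
A note on why the per-sequence loop in `_apply_rotary_pos_emb_thd` can be replaced by a single call: when `freqs` already carries one row of angles per packed token, RoPE rotates each token independently, so splitting by sequence and concatenating gives the same result as one pass over the packed tensor. The sketch below shows this in plain Python rather than torch. `rotate_token`, `rope_per_sequence`, and `rope_batched` are hypothetical names, not Megatron or ms-swift functions, and the angle values are made up for illustration.

```python
import math
from typing import List

Vector = List[float]


def rotate_token(x: Vector, angles: Vector) -> Vector:
    """Apply RoPE to one token: x * cos(angle) + rotate_half(x) * sin(angle)."""
    half = len(x) // 2
    rotated = [-v for v in x[half:]] + x[:half]  # rotate_half
    return [a * math.cos(t) + b * math.sin(t)
            for a, b, t in zip(x, rotated, angles)]


def rope_per_sequence(tokens: List[Vector], freqs: List[Vector],
                      cu_seqlens: List[int]) -> List[Vector]:
    """Old style: split packed tokens by sequence, rotate, concatenate."""
    out: List[Vector] = []
    for start, end in zip(cu_seqlens[:-1], cu_seqlens[1:]):
        out.extend(rotate_token(x, f)
                   for x, f in zip(tokens[start:end], freqs[start:end]))
    return out


def rope_batched(tokens: List[Vector], freqs: List[Vector]) -> List[Vector]:
    """New style: a single pass over all packed tokens."""
    return [rotate_token(x, f) for x, f in zip(tokens, freqs)]


# Two packed sequences of lengths 2 and 3 (cumulative lengths 0, 2, 5).
cu_seqlens = [0, 2, 5]
tokens = [[float(i + j) for j in range(4)] for i in range(5)]
# Made-up per-token angles; positions restart at each sequence boundary.
positions = [0, 1, 0, 1, 2]
freqs = [[p * 1.0, p * 0.1] * 2 for p in positions]

assert rope_per_sequence(tokens, freqs, cu_seqlens) == rope_batched(tokens, freqs)
```

The equivalence only holds when `freqs` has exactly one row per token, which appears to be what the `freqs.shape[0] == cu_seqlens_for_batched[-1]` check in the diff guards.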
@@ -660,6 +664,7 @@ def _patch_megatron():
     _patch_compile_helpers()
     _patch_build_train_valid_test_datasets()
     _patch_mrope()
+    logging.root.setLevel(logging_level) # revert logger level
     from swift.megatron import tuners # patch lora
     try:
         _patch_torch_FileSystemReader()
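
The two `logging` lines added to `_patch_megatron` save the root logger level before the patches run and restore it afterwards, presumably because one of the patch steps changes it as a side effect. A minimal standalone sketch of that pattern follows. `noisy_patch` is a hypothetical stand-in for such a step, and the `try`/`finally` is a variation that the diff itself does not use.

```python
import logging


def noisy_patch() -> None:
    """Hypothetical patch step that changes the root logger level."""
    logging.root.setLevel(logging.DEBUG)


def apply_patches() -> None:
    saved_level = logging.root.level  # remember the level before patching
    try:
        noisy_patch()
    finally:
        # Restore the caller's level even if a patch step raises.
        logging.root.setLevel(saved_level)


if __name__ == '__main__':
    logging.root.setLevel(logging.WARNING)
    apply_patches()
    print(logging.getLevelName(logging.root.level))
```

Restoring in `finally` keeps the caller's logging configuration intact on the error path too, at the cost of a little more nesting than the straight-line version in the commit.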
