Commit 90c98a6

[megatron] support GLM-5 megatron (#8085)
1 parent: d0eedd5

14 files changed: +301 −17 lines

docs/source/Instruction/Supported-models-and-datasets.md (1 addition, 1 deletion)

@@ -414,7 +414,7 @@
 |[ZhipuAI/GLM-4.7](https://modelscope.cn/models/ZhipuAI/GLM-4.7)|glm4_moe|glm4_7|transformers>=4.54|✔|-|[zai-org/GLM-4.7](https://huggingface.co/zai-org/GLM-4.7)|
 |[ZhipuAI/GLM-4.7-FP8](https://modelscope.cn/models/ZhipuAI/GLM-4.7-FP8)|glm4_moe|glm4_7|transformers>=4.54|✘|-|[zai-org/GLM-4.7-FP8](https://huggingface.co/zai-org/GLM-4.7-FP8)|
 |[ZhipuAI/GLM-4.7-Flash](https://modelscope.cn/models/ZhipuAI/GLM-4.7-Flash)|glm4_moe_lite|glm4_7|transformers>=5.0.0.dev|✔|-|[zai-org/GLM-4.7-Flash](https://huggingface.co/zai-org/GLM-4.7-Flash)|
-|[ZhipuAI/GLM-5](https://modelscope.cn/models/ZhipuAI/GLM-5)|glm_moe_dsa|glm4_7|transformers>=5.2.0|✘|-|[zai-org/GLM-5](https://huggingface.co/zai-org/GLM-5)|
+|[ZhipuAI/GLM-5](https://modelscope.cn/models/ZhipuAI/GLM-5)|glm_moe_dsa|glm4_7|transformers>=5.2.0|✔|-|[zai-org/GLM-5](https://huggingface.co/zai-org/GLM-5)|
 |[ZhipuAI/glm-edge-1.5b-chat](https://modelscope.cn/models/ZhipuAI/glm-edge-1.5b-chat)|glm_edge|chatglm4|transformers>=4.46|✘|-|[zai-org/glm-edge-1.5b-chat](https://huggingface.co/zai-org/glm-edge-1.5b-chat)|
 |[ZhipuAI/glm-edge-4b-chat](https://modelscope.cn/models/ZhipuAI/glm-edge-4b-chat)|glm_edge|chatglm4|transformers>=4.46|✘|-|[zai-org/glm-edge-4b-chat](https://huggingface.co/zai-org/glm-edge-4b-chat)|
 |[codefuse-ai/CodeFuse-CodeGeeX2-6B](https://modelscope.cn/models/codefuse-ai/CodeFuse-CodeGeeX2-6B)|codefuse_codegeex2|codefuse|transformers<4.34|✘|coding|[codefuse-ai/CodeFuse-CodeGeeX2-6B](https://huggingface.co/codefuse-ai/CodeFuse-CodeGeeX2-6B)|

docs/source/Megatron-SWIFT/Command-line-parameters.md (5 additions, 0 deletions)

@@ -187,6 +187,11 @@
 - moe_pad_expert_input_to_capacity: Pad each expert's input so that its length aligns with the expert capacity length. Default is False. This option only takes effect when `--moe_expert_capacity_factor` is set.
 - moe_token_drop_policy: Options are 'probs' and 'position'. Default is 'probs'.
 
+**DSA Parameters**
+- dsa_indexer_loss_coeff: Coefficient of the DSA indexer KL-divergence loss. Set to 0 to disable the indexer loss. Default is None.
+- dsa_indexer_use_sparse_loss: Whether to use the sparse DSA indexer loss. If True, the indexer loss is computed over the top-k indices. Default is False.
+
+
 **MTP Parameters**
 - mtp_num_layers: Number of multi-token prediction (MTP) layers. MTP extends the prediction scope at each position to multiple future tokens. This MTP implementation uses D sequential modules to predict D additional tokens in turn. Default is None. (requires "megatron-core>=0.14")
 - Note: the value of mtp_num_layers is not read automatically from config.json and must be set manually; you can fill it in by referring to the `num_nextn_predict_layers` field in config.json. When using mcore-bridge, MTP weights are loaded from the safetensors files first; if they cannot be found, they are randomly initialized. (To use blockwise fp8 + mtp, use mcore>=0.15.)
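The two DSA parameters above control an auxiliary objective: the indexer's score distribution is pulled toward the true attention distribution via a KL term, optionally restricted to the top-k indices. A toy sketch of that idea (hypothetical helper, not the actual Megatron-LM implementation; here a coefficient of `None` is also treated as disabled for simplicity):

```python
import torch
import torch.nn.functional as F


def indexer_kl_loss(index_scores, attn_scores, loss_coeff, topk_indices=None):
    """Toy KL(attn || indexer) loss over each query's key distribution.

    index_scores, attn_scores: [batch, seqlen, seqlen] raw scores.
    topk_indices: [batch, seqlen, k]; if given, restrict the loss to the
    top-k keys (the sparse variant toggled by dsa_indexer_use_sparse_loss).
    """
    if not loss_coeff:  # None or 0 disables the loss (toy simplification)
        return index_scores.new_zeros(())
    if topk_indices is not None:
        index_scores = index_scores.gather(-1, topk_indices)
        attn_scores = attn_scores.gather(-1, topk_indices)
    log_p = F.log_softmax(index_scores, dim=-1)  # indexer distribution
    q = F.softmax(attn_scores, dim=-1)           # target attention distribution
    kl = (q * (q.clamp_min(1e-9).log() - log_p)).sum(-1)  # per query position
    return loss_coeff * kl.mean()
```

When the indexer scores already match the attention scores, the loss is near zero; the coefficient simply scales the gradient pressure on the indexer.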

docs/source/Megatron-SWIFT/Quick-start.md (1 addition, 1 deletion)

@@ -66,7 +66,7 @@ modelscope-registry.us-west-1.cr.aliyuncs.com/modelscope-repo/modelscope:ubuntu2
 | python | >=3.9 | 3.10/3.11 | |
 | cuda | | cuda12 | |
 | torch | >=2.0 | 2.8.0 | |
-| transformer_engine | >=2.3 | 2.10.0 | |
+| transformer_engine | >=2.3 | 2.12.0 | |
 | apex | | 0.1 | |
 | megatron_core | >=0.12,<0.16 | 0.15 | |
 | flash_attn | | 2.8.3/3.0.0b1 | |
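The minimum/recommended pairs in the table can be checked mechanically. A small convenience sketch (not part of ms-swift; assumes the third-party `packaging` library is installed):

```python
from packaging.version import Version


def meets_minimum(installed: str, minimum: str) -> bool:
    """True if an installed version satisfies a table row's minimum."""
    return Version(installed) >= Version(minimum)


# Examples against the table above:
print(meets_minimum('2.12.0', '2.3'))   # transformer_engine recommended vs >=2.3
print(meets_minimum('0.15', '0.12'))    # megatron_core recommended vs >=0.12
```

`packaging` handles pre-release tags like `3.0.0b1` correctly, which naive tuple comparison of dotted strings does not.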

docs/source_en/Instruction/Supported-models-and-datasets.md (1 addition, 1 deletion)

@@ -415,7 +415,7 @@ The table below introduces the models integrated with ms-swift:
 |[ZhipuAI/GLM-4.7](https://modelscope.cn/models/ZhipuAI/GLM-4.7)|glm4_moe|glm4_7|transformers>=4.54|✔|-|[zai-org/GLM-4.7](https://huggingface.co/zai-org/GLM-4.7)|
 |[ZhipuAI/GLM-4.7-FP8](https://modelscope.cn/models/ZhipuAI/GLM-4.7-FP8)|glm4_moe|glm4_7|transformers>=4.54|✘|-|[zai-org/GLM-4.7-FP8](https://huggingface.co/zai-org/GLM-4.7-FP8)|
 |[ZhipuAI/GLM-4.7-Flash](https://modelscope.cn/models/ZhipuAI/GLM-4.7-Flash)|glm4_moe_lite|glm4_7|transformers>=5.0.0.dev|✔|-|[zai-org/GLM-4.7-Flash](https://huggingface.co/zai-org/GLM-4.7-Flash)|
-|[ZhipuAI/GLM-5](https://modelscope.cn/models/ZhipuAI/GLM-5)|glm_moe_dsa|glm4_7|transformers>=5.2.0|✘|-|[zai-org/GLM-5](https://huggingface.co/zai-org/GLM-5)|
+|[ZhipuAI/GLM-5](https://modelscope.cn/models/ZhipuAI/GLM-5)|glm_moe_dsa|glm4_7|transformers>=5.2.0|✔|-|[zai-org/GLM-5](https://huggingface.co/zai-org/GLM-5)|
 |[ZhipuAI/glm-edge-1.5b-chat](https://modelscope.cn/models/ZhipuAI/glm-edge-1.5b-chat)|glm_edge|chatglm4|transformers>=4.46|✘|-|[zai-org/glm-edge-1.5b-chat](https://huggingface.co/zai-org/glm-edge-1.5b-chat)|
 |[ZhipuAI/glm-edge-4b-chat](https://modelscope.cn/models/ZhipuAI/glm-edge-4b-chat)|glm_edge|chatglm4|transformers>=4.46|✘|-|[zai-org/glm-edge-4b-chat](https://huggingface.co/zai-org/glm-edge-4b-chat)|
 |[codefuse-ai/CodeFuse-CodeGeeX2-6B](https://modelscope.cn/models/codefuse-ai/CodeFuse-CodeGeeX2-6B)|codefuse_codegeex2|codefuse|transformers<4.34|✘|coding|[codefuse-ai/CodeFuse-CodeGeeX2-6B](https://huggingface.co/codefuse-ai/CodeFuse-CodeGeeX2-6B)|

docs/source_en/Megatron-SWIFT/Command-line-parameters.md (5 additions, 0 deletions)

@@ -198,6 +198,11 @@ For guidance on selecting parallelization strategies, please refer to the [Train
 - moe_pad_expert_input_to_capacity: Pad the input of each expert so that its length aligns with the expert capacity length. Default is `False`. This option only takes effect if `--moe_expert_capacity_factor` is set.
 - moe_token_drop_policy: Options are 'probs' and 'position'. Default is 'probs'.
 
+**DSA Parameters**
+- dsa_indexer_loss_coeff: Coefficient for the DSA indexer KL divergence loss. Set to 0 to disable indexer loss. Default is None.
+- dsa_indexer_use_sparse_loss: Whether to use sparse DSA indexer loss. If True, the indexer loss will be computed using the top-k indices. Default is False.
+
+
 **MTP Parameters**
 - mtp_num_layers: Number of Multi-Token Prediction (MTP) layers. MTP extends the prediction scope at each position to multiple future tokens. This MTP implementation uses D sequential modules to sequentially predict D additional tokens. Default is None. (requires "megatron-core>=0.14")

docs/source_en/Megatron-SWIFT/Quick-start.md (1 addition, 1 deletion)

@@ -66,7 +66,7 @@ Recommended Operating Environment:
 | python | >=3.9 | 3.10/3.11 | |
 | cuda | | cuda12 | |
 | torch | >=2.0 | 2.8.0 | |
-| transformer_engine | >=2.3 | 2.10.0 | |
+| transformer_engine | >=2.3 | 2.12.0 | |
 | apex | | 0.1 | |
 | megatron_core | >=0.12,<0.16 | 0.15 | |
 | flash_attn | | 2.8.3/3.0.0b1 | |

swift/megatron/arguments/megatron_args.py (4 additions, 0 deletions)

@@ -528,6 +528,10 @@ class MegatronArguments(RLHFMegatronArgumentsMixin, MegatronTunerMixin):
     attn_impl: Optional[str] = None
     gradient_checkpointing_kwargs: Optional[Union[dict, str]] = None
 
+    # dsa
+    dsa_indexer_loss_coeff: Optional[float] = None
+    dsa_indexer_use_sparse_loss: bool = False
+
     # other
     check_model: bool = True
     torch_dtype: Optional[Union[torch.dtype, str]] = None
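The `Optional[float] = None` typing of the new coefficient is deliberate: `None` leaves the model default in place, while an explicit `0` disables the loss. A standalone sketch of that three-way convention (hypothetical class, mirroring only the two new fields):

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class DSAArgs:
    # Mirrors the two new MegatronArguments fields above.
    dsa_indexer_loss_coeff: Optional[float] = None
    dsa_indexer_use_sparse_loss: bool = False

    def indexer_loss_enabled(self) -> bool:
        # None -> defer to the model's default coefficient (loss stays on);
        # 0 -> explicitly disabled; any other float -> enabled at that weight.
        return self.dsa_indexer_loss_coeff != 0
```

This is the same pattern used by other optional overrides in the arguments class: only a concrete value changes behavior, and zero is the explicit off switch.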

swift/megatron/init.py (159 additions, 0 deletions)

@@ -136,13 +136,22 @@ def forward(
             core_attn_out = self._checkpointed_attention_forward(
                 query, key, value, attention_mask, packed_seq_params=packed_seq_params)
         else:
+            extra_kwargs = {}
+            if self.config.experimental_attention_variant == 'dsa':
+                # For dsa we need to pass in the original hidden states and the compressed
+                # query representation.
+                extra_kwargs['x'] = hidden_states
+                extra_kwargs['qr'] = q_compressed
+                # for easy injection of rotary_pos_emb (patch)
+                packed_seq_params = (packed_seq_params, rotary_pos_emb)
             core_attn_out = self.core_attention(
                 query,
                 key,
                 value,
                 attention_mask,
                 packed_seq_params=packed_seq_params,
                 attn_mask_type=attn_mask_type,
+                **extra_kwargs,
             )
         if thd_qkv_format:
             if core_attn_out.ndim == 2:
@@ -789,6 +798,152 @@ def _new_load_inline(*args, **kwargs):
     cpp_extension.load_inline = load_inline
 
 
+def _patch_dsa():
+    from megatron.core.models.common.embeddings.rope_utils import apply_rotary_pos_emb
+    from megatron.core.models.gpt import experimental_attention_variant_module_specs
+    from megatron.core.packed_seq_params import PackedSeqParams
+    from megatron.core.tensor_parallel.mappings import gather_from_sequence_parallel_region
+    from megatron.core.transformer.experimental_attention_variant.dsa import rotate_activation
+    DSAIndexer = experimental_attention_variant_module_specs.DSAIndexer
+
+    class NewDSAIndexer(DSAIndexer):
+
+        def forward_before_topk(
+            self,
+            x: torch.Tensor,
+            qr: torch.Tensor,
+            packed_seq_params: Optional[PackedSeqParams] = None,
+        ):
+            """All computations before topk."""
+            # =========================================
+            # Gather inputs if sp is enabled
+            # =========================================
+            packed_seq_params, rotary_pos_emb = packed_seq_params  # patch
+            assert packed_seq_params is None, 'Packed sequence is not supported for DSAttention'
+
+            if self.config.sequence_parallel and self.pg_collection.tp.size() > 1:
+                x = gather_from_sequence_parallel_region(x, group=self.pg_collection.tp)
+                qr = gather_from_sequence_parallel_region(qr, group=self.pg_collection.tp)
+
+            # =========================================
+            # Get sequence length and batch size
+            # =========================================
+            seqlen, bsz, _ = x.size()
+
+            # =========================================
+            # q linear and apply rope to q
+            # =========================================
+            # [seqlen, batch, q_lora_rank] -> [seqlen, batch, index_n_heads * index_head_dim]
+            q, _ = self.linear_wq_b(qr)
+            # [seqlen, batch, index_n_heads * index_head_dim]
+            # -> [seqlen, batch, index_n_heads, index_head_dim]
+            q = q.reshape(seqlen, bsz, self.index_n_heads, self.index_head_dim)
+            q = self._apply_rope(q, rotary_pos_emb)  # mscale will be passed in by patch
+
+            # =========================================
+            # k linear and apply rope to k
+            # =========================================
+            # [seqlen, batch, hidden_size] -> [seqlen, batch, index_head_dim]
+            k, _ = self.linear_wk(x)
+            k = self.k_norm(k)
+            # [seqlen, batch, index_head_dim] -> [seqlen, batch, 1, index_head_dim]
+            k = k.reshape(seqlen, bsz, 1, self.index_head_dim)
+            k = self._apply_rope(k, rotary_pos_emb)
+            # [seqlen, batch, 1, index_head_dim] -> [seqlen, batch, index_head_dim]
+            k = k.reshape(seqlen, bsz, self.index_head_dim)
+
+            # =========================================
+            # Rotate activation
+            # =========================================
+            q = rotate_activation(q)
+            k = rotate_activation(k)
+
+            # =========================================
+            # Prepare weights for index scores
+            # =========================================
+            # [seqlen, batch, hidden_size] -> [seqlen, batch, index_n_heads]
+            weights, _ = self.linear_weights_proj(x)
+            weights = weights * (self.index_n_heads**-0.5) * self.softmax_scale
+
+            return q, k, weights
+
+        def _apply_rope(self, x: torch.Tensor, rotary_pos_emb: torch.Tensor):
+            """Apply RoPE to the input tensor."""
+            # x_nope [seqlen, batch, *, index_head_dim - qk_pos_emb_head_dim]
+            # x_pe [seqlen, batch, *, qk_pos_emb_head_dim]
+            x_pe, x_nope = torch.split(
+                x, [self.index_head_dim - self.qk_pos_emb_head_dim, self.qk_pos_emb_head_dim], dim=-1)
+            x_pe = apply_rotary_pos_emb(
+                x_pe,
+                rotary_pos_emb,
+                config=self.config,
+                cu_seqlens=None,
+                cp_group=self.pg_collection.cp,
+            )
+            # [seqlen, batch, *, index_head_dim]
+            x = torch.cat([x_pe, x_nope], dim=-1)
+            return x
+
+        def forward_with_scores(
+            self,
+            x: torch.Tensor,
+            qr: torch.Tensor,
+            mask: Optional[torch.Tensor] = None,
+            packed_seq_params: Optional[PackedSeqParams] = None,
+        ) -> Tuple[torch.Tensor, torch.Tensor]:
+            """
+            Forward pass for DSA Indexer that returns both index scores and top-k indices.
+
+            This is used when KL loss is enabled to compare indexer scores with true attention scores.
+
+            Args:
+                x: hidden states [seqlen, batch, hidden_size].
+                qr: Low-rank query tensor [seqlen, batch, q_lora_rank].
+                mask: Attention mask [batch, seqlen, seqlen].
+                packed_seq_params: Packed sequence parameters for variable length sequences.
+
+            Returns:
+                index_scores: Index scores [batch, seqlen, seqlen].
+                topk_indices: Top-k indices [batch, seqlen, index_topk].
+            """
+            try:
+                from megatron.core.transformer.experimental_attention_variant.dsa import fused_qk_topk_naive
+            except ImportError:
+                raise ImportError('fused_qk_topk_naive is not available. Please install megatron-core from source. '
+                                  '`pip install git+https://github.com/NVIDIA/Megatron-LM.git`')
+            # [seqlen, batch, index_n_heads * index_head_dim]
+            # [seqlen, batch, index_head_dim]
+            # [seqlen, batch, index_n_heads]
+            q, k, weights = self.forward_before_topk(x, qr, packed_seq_params)
+
+            # [batch, seqlen, seqlen], [batch, seqlen, index_topk]
+            index_scores, topk_indices = fused_qk_topk_naive(q, k, weights, self.index_topk, mask)
+
+            return index_scores, topk_indices
+
+        def forward(self,
+                    x: torch.Tensor,
+                    qr: torch.Tensor,
+                    mask: Optional[torch.Tensor] = None,
+                    packed_seq_params: Optional[PackedSeqParams] = None):
+            """
+            Forward pass for DSA Indexer.
+
+            Args:
+                x: hidden states [seqlen, batch, hidden_size].
+                qr: Low-rank query tensor [seqlen, batch, q_lora_rank].
+                mask: Attention mask [batch, seqlen, seqlen].
+                packed_seq_params: Packed sequence parameters for variable length sequences.
+
+            Returns:
+                topk_indices: Top-k indices for sparse attention [batch, seqlen, index_topk].
+            """
+            _, topk_indices = self.forward_with_scores(x, qr, mask, packed_seq_params)
+            return topk_indices
+
+    experimental_attention_variant_module_specs.DSAIndexer = NewDSAIndexer
+
+
 def init_megatron_env():
     os.environ.pop('VLLM_USE_MODELSCOPE', None)
     logging_level = logging.root.level
@@ -804,6 +959,10 @@ def init_megatron_env():
     _patch_mrope()
     _patch__write_item()
     _patch_mtp()
+    try:
+        _patch_dsa()
+    except ImportError:
+        pass
     logging.root.setLevel(logging_level)  # revert logger level
     from swift.megatron import tuners  # patch lora
     try:
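The indexer's scoring step that the patch delegates to megatron-core's `fused_qk_topk_naive` reduces to: per-head q·k similarities, a learned per-head mixing weight, then a top-k over key positions. A toy, unfused sketch of that step (hypothetical function; shapes follow the comments in the patched code above):

```python
import torch


def naive_qk_topk(q, k, weights, topk, mask=None):
    """Toy version of the indexer's score + top-k step.

    q: [seqlen_q, batch, heads, dim]; k: [seqlen_k, batch, dim];
    weights: [seqlen_q, batch, heads] per-head mixing weights.
    Returns scores [batch, seqlen_q, seqlen_k] and top-k key indices.
    """
    # Per-head similarities: sim[b, h, s, t] = <q[s, b, h], k[t, b]>
    sim = torch.einsum('sbhd,tbd->bhst', q, k)
    # Combine heads with the learned weights -> [batch, seqlen_q, seqlen_k]
    scores = torch.einsum('bhst,sbh->bst', sim, weights)
    if mask is not None:
        scores = scores.masked_fill(mask, float('-inf'))
    topk_indices = scores.topk(min(topk, scores.size(-1)), dim=-1).indices
    return scores, topk_indices
```

The fused kernel computes the same quantities in one pass; the sparse attention then only attends to the returned `topk_indices` per query position.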

swift/megatron/model/gpt_bridge.py (29 additions, 6 deletions)

@@ -737,7 +737,7 @@ def _get_hf_grouped(self):
         if self.model_type in {
                 'qwen2_moe', 'qwen3_moe', 'deepseek_v2', 'deepseek_v3', 'dots1', 'ernie4_5_moe', 'glm4_moe',
                 'glm4_moe_lite', 'glm4v_moe', 'minimax_m2', 'olmoe', 'qwen3_next', 'kimi_vl', 'qwen3_omni_moe',
-                'qwen3_5_moe'
+                'qwen3_5_moe', 'glm_moe_dsa'
         }:
             return False, False
         return None, None
@@ -1257,6 +1257,22 @@ def _set_mlp_state(
         hf_state_dict = self._add_prefix(hf_state_dict, hf_prefix)
         return hf_state_dict
 
+    def _set_indexer(self, mg_indexer, hf_state_dict, hf_prefix: str, to_mcore: bool):
+        if to_mcore:
+            hf_state_dict = self._remove_prefix(hf_state_dict, hf_prefix)
+        else:
+            hf_state_dict = {}
+        self._set_state_dict(mg_indexer, 'linear_wq_b.weight', hf_state_dict, 'wq_b.weight', to_mcore)
+        self._set_state_dict(mg_indexer, 'linear_wk.weight', hf_state_dict, 'wk.weight', to_mcore)
+        self._set_state_dict(mg_indexer, 'k_norm.weight', hf_state_dict, 'k_norm.weight', to_mcore)
+        self._set_state_dict(mg_indexer, 'k_norm.bias', hf_state_dict, 'k_norm.bias', to_mcore)
+        self._set_state_dict(mg_indexer, 'linear_weights_proj.weight', hf_state_dict, 'weights_proj.weight', to_mcore)
+        if to_mcore:
+            hf_state_dict = {}
+        else:
+            hf_state_dict = self._add_prefix(hf_state_dict, hf_prefix)
+        return hf_state_dict
+
     def _set_mla_attn_state(
         self,
         mg_attn,
@@ -1279,11 +1295,18 @@ def _set_mla_attn_state(
                             to_mcore)
         self._set_state_dict(mg_attn, 'linear_kv_up_proj.weight', hf_state_dict, 'kv_b_proj.weight', to_mcore)
         if self.config.qk_layernorm:
-            if self.config.q_lora_rank is not None:
-                self._set_state_dict(mg_attn, 'linear_q_up_proj.layer_norm_weight', hf_state_dict,
-                                     'q_a_layernorm.weight', to_mcore)
-            self._set_state_dict(mg_attn, 'linear_kv_up_proj.layer_norm_weight', hf_state_dict, 'kv_a_layernorm.weight',
-                                 to_mcore)
+            if self.config.experimental_attention_variant == 'dsa':
+                if self.config.q_lora_rank is not None:
+                    self._set_state_dict(mg_attn, 'q_layernorm.weight', hf_state_dict, 'q_a_layernorm.weight', to_mcore)
+                self._set_state_dict(mg_attn, 'kv_layernorm.weight', hf_state_dict, 'kv_a_layernorm.weight', to_mcore)
+            else:
+                if self.config.q_lora_rank is not None:
+                    self._set_state_dict(mg_attn, 'linear_q_up_proj.layer_norm_weight', hf_state_dict,
+                                         'q_a_layernorm.weight', to_mcore)
+                self._set_state_dict(mg_attn, 'linear_kv_up_proj.layer_norm_weight', hf_state_dict,
+                                     'kv_a_layernorm.weight', to_mcore)
+        if self.config.experimental_attention_variant == 'dsa':
+            hf_state_dict.update(self._set_indexer(mg_attn.core_attention.indexer, hf_state_dict, 'indexer.', to_mcore))
         if to_mcore:
             hf_state_dict = {}
         else:
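The new `_set_indexer` follows the bridge's general pattern: on the HF→mcore direction, strip the sub-module prefix before mapping tensors; on the way back, rebuild the HF keys by re-attaching it. A standalone sketch of just that prefix handling (hypothetical helper names, illustrating the mechanism rather than the bridge's actual methods):

```python
def remove_prefix(state_dict, prefix):
    """Keep only keys under `prefix`, with the prefix stripped (HF -> mcore)."""
    return {k[len(prefix):]: v for k, v in state_dict.items() if k.startswith(prefix)}


def add_prefix(state_dict, prefix):
    """Re-attach `prefix` to every key (mcore -> HF export)."""
    return {prefix + k: v for k, v in state_dict.items()}
```

For example, an HF checkpoint key `indexer.wk.weight` becomes `wk.weight` inside the indexer scope, gets assigned to the mcore module's `linear_wk.weight`, and is exported back under `indexer.wk.weight`.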
