Commit 7cebe56

useless changes to fix weight loading - pure CP tests fail with 99% mismatch
1 parent 4dce0ea

File tree

1 file changed (+9 -4 lines)

tests/unittest/_torch/modules/test_mla_helix.py

Lines changed: 9 additions & 4 deletions
@@ -649,20 +649,25 @@ def _run_mla_distributed(
     _copy_to_tp_then_cp(weights, "o_proj.weight", dim=1, tp_rank=tp_rank, tp_size=tp_size,
                         cp_rank=cp_rank, cp_size=cp_size)
 
-    # 2. v_b_proj: Shape (num_heads, v_head_dim, kv_lora_rank)
+    # 2. q_proj.weight: Column parallel by both TP and CP
+    # Shape: (num_heads * qk_head_dim, hidden_size) -> shard dim 0
+    _copy_to_tp_then_cp(weights, "q_proj.weight", dim=0, tp_rank=tp_rank, tp_size=tp_size,
+                        cp_rank=cp_rank, cp_size=cp_size)
+
+    # 3. v_b_proj: Shape (num_heads, v_head_dim, kv_lora_rank)
     # Sharded by both TP and CP on head dimension (dim 0)
     _copy_to_tp_then_cp(weights, "v_b_proj", dim=0, tp_rank=tp_rank, tp_size=tp_size,
                         cp_rank=cp_rank, cp_size=cp_size)
 
-    # 3. k_b_proj_trans: Shape (num_heads_tp, kv_lora_rank, qk_nope_head_dim)
+    # 4. k_b_proj_trans: Shape (num_heads_tp, kv_lora_rank, qk_nope_head_dim)
     # Sharded by TP only (not CP) - used in generation phase
     _copy_to_tp(weights, "k_b_proj_trans", dim=0, tp_rank=tp_rank, tp_size=tp_size)
 
-    # 4. q_b_proj.weight: Column parallel by TP only
+    # 5. q_b_proj.weight: Column parallel by TP only
     # Shape: (num_heads * qk_head_dim, q_lora_rank) -> shard dim 0
     _copy_to_tp(weights, "q_b_proj.weight", dim=0, tp_rank=tp_rank, tp_size=tp_size)
 
-    # 5. kv_b_proj.weight: Column parallel by TP only
+    # 6. kv_b_proj.weight: Column parallel by TP only
     # Shape: (num_heads * (qk_nope_head_dim + v_head_dim), kv_lora_rank) -> shard dim 0
     _copy_to_tp(weights, "kv_b_proj.weight", dim=0, tp_rank=tp_rank, tp_size=tp_size)

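A note on the helpers: this hunk shows only the call sites of _copy_to_tp and _copy_to_tp_then_cp; their definitions live elsewhere in test_mla_helix.py and are not part of this change. As a rough, self-contained sketch of the sharding pattern the call sites and comments imply (TP-only slicing versus TP-then-CP slicing along one dimension), something like the following would be consistent with the diff. The function names copy_to_tp / copy_to_tp_then_cp are hypothetical stand-ins, the explicit src/dst dicts simplify whatever the test actually does with module parameters, and the toy shapes are invented for illustration.

    import torch

    def shard_1d(t: torch.Tensor, dim: int, rank: int, world_size: int) -> torch.Tensor:
        # Return the rank-th of world_size equal slices of t along dim.
        assert t.shape[dim] % world_size == 0
        chunk = t.shape[dim] // world_size
        return t.narrow(dim, rank * chunk, chunk).contiguous()

    def copy_to_tp(dst, src, name, dim, tp_rank, tp_size):
        # TP-only: this rank keeps one contiguous slice of the full weight.
        dst[name] = shard_1d(src[name], dim, tp_rank, tp_size)

    def copy_to_tp_then_cp(dst, src, name, dim, tp_rank, tp_size, cp_rank, cp_size):
        # TP first, then CP: the TP shard is subdivided again, so along `dim`
        # this rank holds slice index tp_rank * cp_size + cp_rank out of
        # tp_size * cp_size total slices.
        dst[name] = shard_1d(shard_1d(src[name], dim, tp_rank, tp_size),
                             dim, cp_rank, cp_size)

    # Toy usage mirroring the new q_proj.weight case: shape
    # (num_heads * qk_head_dim, hidden_size), sharded on dim 0 by TP then CP.
    full = {"q_proj.weight": torch.randn(8 * 16, 32)}
    local = {}
    copy_to_tp_then_cp(local, full, "q_proj.weight", dim=0,
                       tp_rank=1, tp_size=2, cp_rank=0, cp_size=2)
    assert local["q_proj.weight"].shape == (8 * 16 // 4, 32)

One consequence of this TP-first-then-CP scheme is that the flattened shard order along the head dimension is (tp_rank, cp_rank); if the attention module enumerates heads across the combined TP x CP mesh in any other order, each rank loads plausible-looking but misplaced weights, which is one plausible way to get the near-total (99%) mismatch the commit message reports.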