@@ -649,20 +649,25 @@ def _run_mla_distributed(
     _copy_to_tp_then_cp(weights, "o_proj.weight", dim=1, tp_rank=tp_rank, tp_size=tp_size,
                         cp_rank=cp_rank, cp_size=cp_size)
 
-    # 2. v_b_proj: Shape (num_heads, v_head_dim, kv_lora_rank)
+    # 2. q_proj.weight: Column parallel by both TP and CP
+    # Shape: (num_heads * qk_head_dim, hidden_size) -> shard dim 0
+    _copy_to_tp_then_cp(weights, "q_proj.weight", dim=0, tp_rank=tp_rank, tp_size=tp_size,
+                        cp_rank=cp_rank, cp_size=cp_size)
+
+    # 3. v_b_proj: Shape (num_heads, v_head_dim, kv_lora_rank)
     # Sharded by both TP and CP on head dimension (dim 0)
     _copy_to_tp_then_cp(weights, "v_b_proj", dim=0, tp_rank=tp_rank, tp_size=tp_size,
                         cp_rank=cp_rank, cp_size=cp_size)
 
-    # 3. k_b_proj_trans: Shape (num_heads_tp, kv_lora_rank, qk_nope_head_dim)
+    # 4. k_b_proj_trans: Shape (num_heads_tp, kv_lora_rank, qk_nope_head_dim)
     # Sharded by TP only (not CP) - used in generation phase
     _copy_to_tp(weights, "k_b_proj_trans", dim=0, tp_rank=tp_rank, tp_size=tp_size)
 
-    # 4. q_b_proj.weight: Column parallel by TP only
+    # 5. q_b_proj.weight: Column parallel by TP only
     # Shape: (num_heads * qk_head_dim, q_lora_rank) -> shard dim 0
     _copy_to_tp(weights, "q_b_proj.weight", dim=0, tp_rank=tp_rank, tp_size=tp_size)
 
-    # 5. kv_b_proj.weight: Column parallel by TP only
+    # 6. kv_b_proj.weight: Column parallel by TP only
     # Shape: (num_heads * (qk_nope_head_dim + v_head_dim), kv_lora_rank) -> shard dim 0
     _copy_to_tp(weights, "kv_b_proj.weight", dim=0, tp_rank=tp_rank, tp_size=tp_size)
 
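For context, here is a minimal sketch of what the two-level sharding helpers used above might do, assuming plain PyTorch tensors held in a `weights` dict. The real `_copy_to_tp` / `_copy_to_tp_then_cp` are not part of this hunk, so the dict layout, the `torch.chunk`-based slicing, and the in-place replacement are all assumptions for illustration only.

import torch

def _shard(tensor: torch.Tensor, dim: int, rank: int, size: int) -> torch.Tensor:
    # Split the tensor into `size` equal chunks along `dim` and keep this rank's chunk.
    return torch.chunk(tensor, size, dim=dim)[rank].contiguous()

def _copy_to_tp(weights: dict, name: str, dim: int, tp_rank: int, tp_size: int) -> None:
    # TP-only sharding: each TP rank keeps 1/tp_size of the weight along `dim`.
    # (Sketch: replaces the dict entry instead of copying into a module parameter.)
    weights[name] = _shard(weights[name], dim, tp_rank, tp_size)

def _copy_to_tp_then_cp(weights: dict, name: str, dim: int,
                        tp_rank: int, tp_size: int,
                        cp_rank: int, cp_size: int) -> None:
    # Two-level sharding: slice the weight across TP ranks along `dim`,
    # then slice that TP shard again across CP ranks along the same dimension,
    # so each rank ends up with 1/(tp_size * cp_size) of the weight.
    tp_shard = _shard(weights[name], dim, tp_rank, tp_size)
    weights[name] = _shard(tp_shard, dim, cp_rank, cp_size)

Note the dimension choice in the calls above: the column-parallel projections (q_proj, v_b_proj, q_b_proj, kv_b_proj) are split along dim 0 (the output/head dimension), while o_proj is split along dim 1, consistent with its role as the row-parallel output projection.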