@@ -678,7 +678,7 @@ def fp8_block_scaling_bmm_out(
     else:
         raise NotImplementedError(f"SM{sm_version} is not supported")
 
-TENSOR_SAVE_DIR = "/home/bbuddharaju/scratch/TensorRT-LLM/pureTP4_redo/"
+TENSOR_SAVE_DIR = "/home/bbuddharaju/scratch/TensorRT-LLM/mixedTP2CP2_fix/"
 def save_tensor_mla(tensor: torch.Tensor, filename: str, rank: int, cp_rank: int, tp_rank: int):
     os.makedirs(TENSOR_SAVE_DIR, exist_ok=True)
     filepath = os.path.join(TENSOR_SAVE_DIR, f"rank{rank}_cp{cp_rank}_tp{tp_rank}_{filename}.pt")
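For context, a small runnable sketch (not part of the patch) of the filename scheme the debug helper uses; the "q_after_rope" tag and the example rank values are made up, and only the path construction shown in the hunk above is reproduced here.

import os

TENSOR_SAVE_DIR = "/home/bbuddharaju/scratch/TensorRT-LLM/mixedTP2CP2_fix/"

# e.g. global rank 2 in a TP2xCP2 run corresponds to cp_rank=1, tp_rank=0
rank, cp_rank, tp_rank = 2, 1, 0
filepath = os.path.join(TENSOR_SAVE_DIR, f"rank{rank}_cp{cp_rank}_tp{tp_rank}_q_after_rope.pt")
print(filepath)
# /home/bbuddharaju/scratch/TensorRT-LLM/mixedTP2CP2_fix/rank2_cp1_tp0_q_after_rope.pt

One dump file per (rank, cp, tp) triple makes it possible to diff the pureTP4 and mixedTP2CP2 runs offline, tensor by tensor.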
@@ -899,12 +899,20 @@ def __init__(
             requires_grad=False,
         )
 
+        # Compute the correct rank for the combined TP*CP mapping.
+        # The attention heads are split first by TP, then by CP within each TP group.
+        # Original rank order: pp_rank * tp_size * cp_size + cp_rank * tp_size + tp_rank
+        # For o_proj, we need: pp_rank * tp_size * cp_size + tp_rank * cp_size + cp_rank
+        # This ensures weight slices align with the actual head partitions.
+        new_rank_for_o = (self.mapping.pp_rank * tp_size * cp_size +
+                          self.mapping.tp_rank * cp_size + self.mapping.cp_rank)
+        print(f"[MLA::create_weights][rank {self.mapping.rank}][cp_rank {self.mapping.cp_rank}][tp_rank {self.mapping.tp_rank}]: new_rank_for_o: {new_rank_for_o}")
         mapping_o = Mapping(
             world_size=tp_size * pp_size * cp_size,
             tp_size=tp_size * cp_size,
             pp_size=pp_size,
             cp_size=1,
-            rank=self.mapping.rank,
+            rank=new_rank_for_o,
             gpus_per_node=self.mapping.gpus_per_node,
             enable_attention_dp=self.mapping.enable_attention_dp,
         )
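For clarity, a minimal standalone sketch (not part of the patch) of the rank remapping described in the comment above, assuming tp_size=2, cp_size=2, pp_size=1 as in the mixed TP2/CP2 configuration this branch targets: global ranks are laid out cp-major, while the flattened TP*CP mapping used for o_proj expects tp-major ordering.

# Enumerate all ranks and show where each one lands in the o_proj mapping.
tp_size, cp_size, pp_size = 2, 2, 1

for pp_rank in range(pp_size):
    for cp_rank in range(cp_size):
        for tp_rank in range(tp_size):
            global_rank = pp_rank * tp_size * cp_size + cp_rank * tp_size + tp_rank
            new_rank_for_o = pp_rank * tp_size * cp_size + tp_rank * cp_size + cp_rank
            print(f"global_rank={global_rank} (cp={cp_rank}, tp={tp_rank}) -> o_proj slice {new_rank_for_o}")

# Expected output:
# global_rank=0 (cp=0, tp=0) -> o_proj slice 0
# global_rank=1 (cp=0, tp=1) -> o_proj slice 2
# global_rank=2 (cp=1, tp=0) -> o_proj slice 1
# global_rank=3 (cp=1, tp=1) -> o_proj slice 3

In this 2x2 case the remap simply swaps ranks 1 and 2, so each rank picks up the o_proj weight slice that matches the attention heads it actually holds.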