
Commit 7e0223f (1 parent: dd24629)

save potential fix

File tree: 3 files changed, +13 −5 lines

tensorrt_llm/_torch/models/modeling_deepseekv3.py

Lines changed: 1 addition & 1 deletion

@@ -132,7 +132,7 @@ def moe_reduce_add_shared_output(routed_output, shared_output):
     return shared_output + routed_output


-TENSOR_SAVE_DIR = "/home/bbuddharaju/scratch/TensorRT-LLM/pureTP4_redo/"
+TENSOR_SAVE_DIR = "/home/bbuddharaju/scratch/TensorRT-LLM/mixedTP2CP2_fix/"
 def save_tensor(tensor: torch.Tensor, filename: str, rank: int, cp_rank: int, tp_rank: int):
     os.makedirs(TENSOR_SAVE_DIR, exist_ok=True)
     filepath = os.path.join(TENSOR_SAVE_DIR, f"rank{rank}_cp{cp_rank}_tp{tp_rank}_{filename}.pt")
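For context, save_tensor is a temporary debugging helper that dumps one file per (rank, cp_rank, tp_rank). A hypothetical call site and offline comparison might look like the sketch below; hidden_states, the "attn_input" tag, and the exact dump paths are illustrative, not from the commit:

# Hypothetical call site inside a module's forward pass: dump a tensor on
# every rank so two runs (here pureTP4_redo vs. mixedTP2CP2_fix) can be
# diffed tensor-by-tensor, wherever the per-rank shapes line up.
save_tensor(hidden_states, "attn_input",
            rank=self.mapping.rank,
            cp_rank=self.mapping.cp_rank,
            tp_rank=self.mapping.tp_rank)

# Offline comparison, following the helper's filename pattern:
import torch
ref = torch.load("pureTP4_redo/rank0_cp0_tp0_attn_input.pt")
new = torch.load("mixedTP2CP2_fix/rank0_cp0_tp0_attn_input.pt")
print((ref.float() - new.float()).abs().max())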

tensorrt_llm/_torch/modules/attention.py

Lines changed: 10 additions & 2 deletions

@@ -678,7 +678,7 @@ def fp8_block_scaling_bmm_out(
     else:
         raise NotImplementedError(f"SM{sm_version} is not supported")

-TENSOR_SAVE_DIR = "/home/bbuddharaju/scratch/TensorRT-LLM/pureTP4_redo/"
+TENSOR_SAVE_DIR = "/home/bbuddharaju/scratch/TensorRT-LLM/mixedTP2CP2_fix/"
 def save_tensor_mla(tensor: torch.Tensor, filename: str, rank: int, cp_rank: int, tp_rank: int):
     os.makedirs(TENSOR_SAVE_DIR, exist_ok=True)
     filepath = os.path.join(TENSOR_SAVE_DIR, f"rank{rank}_cp{cp_rank}_tp{tp_rank}_{filename}.pt")
@@ -899,12 +899,20 @@ def __init__(
             requires_grad=False,
         )

+        # Compute the correct rank for the combined TP*CP mapping.
+        # The attention heads are split first by TP, then by CP within each TP group.
+        # Original rank order: pp_rank * tp_size * cp_size + cp_rank * tp_size + tp_rank
+        # For o_proj, we need: pp_rank * tp_size * cp_size + tp_rank * cp_size + cp_rank
+        # This ensures weight slices align with the actual head partitions.
+        new_rank_for_o = (self.mapping.pp_rank * tp_size * cp_size +
+                          self.mapping.tp_rank * cp_size + self.mapping.cp_rank)
+        print(f"[MLA::create_weights][rank {self.mapping.rank}][cp_rank {self.mapping.cp_rank}][tp_rank {self.mapping.tp_rank}]: new_rank_for_o: {new_rank_for_o}")
         mapping_o = Mapping(
             world_size=tp_size * pp_size * cp_size,
             tp_size=tp_size * cp_size,
             pp_size=pp_size,
             cp_size=1,
-            rank=self.mapping.rank,
+            rank=new_rank_for_o,
             gpus_per_node=self.mapping.gpus_per_node,
             enable_attention_dp=self.mapping.enable_attention_dp,
         )
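To see why the remap matters: mapping_o re-declares o_proj as a pure-TP layer of size tp_size * cp_size, so each rank must pick the weight slice matching the head block it owns (tp-major, cp-minor), not its position in the global rank order (cp-major, tp-minor). A minimal standalone sketch (not part of the commit) enumerating both orderings for the layout this commit targets, pp_size=1, tp_size=2, cp_size=2:

tp_size, cp_size, pp_size = 2, 2, 1  # the mixed TP2 x CP2 layout
for pp_rank in range(pp_size):
    for cp_rank in range(cp_size):
        for tp_rank in range(tp_size):
            # Global rank order: CP varies before TP within each PP group.
            global_rank = pp_rank * tp_size * cp_size + cp_rank * tp_size + tp_rank
            # o_proj slice order: TP varies before CP, matching head ownership.
            new_rank_for_o = pp_rank * tp_size * cp_size + tp_rank * cp_size + cp_rank
            print(f"rank {global_rank} (cp{cp_rank}, tp{tp_rank}) -> o_proj slice {new_rank_for_o}")

# Output:
# rank 0 (cp0, tp0) -> o_proj slice 0
# rank 1 (cp0, tp1) -> o_proj slice 2
# rank 2 (cp1, tp0) -> o_proj slice 1
# rank 3 (cp1, tp1) -> o_proj slice 3

Ranks 1 and 2 are exactly the ones that would load each other's slices under the old rank=self.mapping.rank, which is consistent with the kind of TP2/CP2 mismatch the tensor dumps above are meant to catch.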

tests/integration/defs/disaggregated/test_configs/disagg_config_ctxtp2_gentp1cp2_deepseek_v3_lite_bf16_tllm_gen.yaml

Lines changed: 2 additions & 2 deletions

@@ -20,9 +20,9 @@ context_servers:
     - "localhost:8001"
 generation_servers:
   num_instances: 1
-  tensor_parallel_size: 4
+  tensor_parallel_size: 2
   pipeline_parallel_size: 1
-  context_parallel_size: 1
+  context_parallel_size: 2
   enable_chunked_prefill: False
   cp_config:
     cp_type: HELIX
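Note that the generation server's GPU footprint is unchanged by this edit; only the parallelism layout moves from pure TP4 to mixed TP2 x CP2. A quick sanity check (a sketch, not part of the commit):

tp, pp, cp = 2, 1, 2  # values from the config above
assert tp * pp * cp == 4  # same world size as the previous tensor_parallel_size: 4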
