
Commit 521185a

save initial changes for hack
1 parent e5c2de4 commit 521185a

2 files changed: +42 additions, -16 deletions


tensorrt_llm/_torch/modules/attention.py

Lines changed: 11 additions & 9 deletions
@@ -899,23 +899,25 @@ def __init__(
             requires_grad=False,
         )
 
-        # Compute the correct rank for the combined TP*CP mapping.
-        # The attention heads are split first by TP, then by CP within each TP group.
-        # Original rank order: pp_rank * tp_size * cp_size + cp_rank * tp_size + tp_rank
-        # For o_proj, we need: pp_rank * tp_size * cp_size + tp_rank * cp_size + cp_rank
-        # This ensures weight slices align with the actual head partitions.
-        new_rank_for_o = (self.mapping.pp_rank * tp_size * cp_size +
-                          self.mapping.tp_rank * cp_size + self.mapping.cp_rank)
-        print(f"[MLA::create_weights][rank {self.mapping.rank}][cp_rank {self.mapping.cp_rank}][tp_rank {self.mapping.tp_rank}]: new_rank_for_o: {new_rank_for_o}")
+        # # Compute the correct rank for the combined TP*CP mapping.
+        # # The attention heads are split first by TP, then by CP within each TP group.
+        # # Original rank order: pp_rank * tp_size * cp_size + cp_rank * tp_size + tp_rank
+        # # For o_proj, we need: pp_rank * tp_size * cp_size + tp_rank * cp_size + cp_rank
+        # # This ensures weight slices align with the actual head partitions.
+        # new_rank_for_o = (self.mapping.pp_rank * tp_size * cp_size +
+        #                   self.mapping.tp_rank * cp_size + self.mapping.cp_rank)
+        # print(f"[MLA::create_weights][rank {self.mapping.rank}][cp_rank {self.mapping.cp_rank}][tp_rank {self.mapping.tp_rank}]: new_rank_for_o: {new_rank_for_o}")
         mapping_o = Mapping(
             world_size=tp_size * pp_size * cp_size,
             tp_size=tp_size * cp_size,
             pp_size=pp_size,
             cp_size=1,
-            rank=new_rank_for_o,
+            rank=self.mapping.rank,
             gpus_per_node=self.mapping.gpus_per_node,
             enable_attention_dp=self.mapping.enable_attention_dp,
         )
+        # TODO: Update this for all layers.
+        weight_name = "o_proj_with_cp" if self.mapping.has_cp_helix() and self.layer_idx == 0 else None
         self.o_proj = Linear(
             self.num_key_value_heads * self.v_head_dim,
             self.hidden_size,
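Note: the commented-out block above encodes a rank reordering (heads split first by TP, then by CP within each TP group), which this hunk now defers to the weight_name plumbing in linear.py below. A minimal sketch of what that reordering computes, using a hypothetical helper name, not part of the commit:

    # Hypothetical helper illustrating the rank reordering described in the
    # commented-out block; not part of the commit.
    def remap_rank_for_o(pp_rank: int, tp_rank: int, cp_rank: int,
                         tp_size: int, cp_size: int) -> int:
        # Original rank order: pp_rank * tp_size * cp_size + cp_rank * tp_size + tp_rank
        # o_proj rank order:   pp_rank * tp_size * cp_size + tp_rank * cp_size + cp_rank
        return pp_rank * tp_size * cp_size + tp_rank * cp_size + cp_rank

    # Example with tp_size=2, cp_size=2, pp_rank=0: the ranks holding
    # (cp_rank, tp_rank) = (0, 1) and (1, 0) are global ranks 1 and 2 in the
    # original order, and they map to o_proj ranks 2 and 1 respectively,
    # i.e. ranks 1 and 2 swap. This is what the "SWAP RANK 1 / RANK 2" prints
    # in linear.py below point at.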

tensorrt_llm/_torch/modules/linear.py

Lines changed: 31 additions & 7 deletions
@@ -68,7 +68,10 @@ def load_weight_shard(
     tensor_parallel_mode: Optional[TensorParallelMode] = None,
     device: torch.device = torch.device('cpu'),
     return_slice_indices: bool = False,
+    weight_name: Optional[str] = None,
 ) -> torch.Tensor:
+    if weight_name is not None:
+        print(f"[load_weight_shard] weight_name: {weight_name}")
     # Skip device transfers on integrated GPUs to conserve shared memory
     if weight.device.type != device.type and is_device_integrated():
         # For integrated GPU systems (e.g., DGX Spark), CPU and GPU share limited physical memory.
@@ -112,6 +115,10 @@ def maybe_convert_to_torch_tensor(
     if width == 1:
         return maybe_convert_to_torch_tensor(weight)
 
+    if weight_name is not None and tensor_parallel_rank == 1:
+        print(f"[load_weight_shard] THIS IS WHERE YOU SWAP RANK 1.")
+    if weight_name is not None and tensor_parallel_rank == 2:
+        print(f"[load_weight_shard] THIS IS WHERE YOU SWAP RANK 2.")
     slice_width = math.ceil(width / tensor_parallel_size)
     slice_start = tensor_parallel_rank * slice_width
     slice_end = min((tensor_parallel_rank + 1) * slice_width, width)
@@ -140,7 +147,10 @@ def load_weights_vanilla_helper(module: Linear,
                                 weights: List[Dict],
                                 weight_transform=lambda x: x,
                                 bias_transform=lambda x: x,
-                                allow_partial_loading: bool = False):
+                                allow_partial_loading: bool = False,
+                                weight_name: Optional[str] = None):
+    if weight_name is not None:
+        print(f"[load_weights_vanilla_helper] weight_name: {weight_name}")
     assert len(weights) == 1
     if not allow_partial_loading:
         assert "weight" in weights[0]
@@ -150,7 +160,7 @@ def load_weights_vanilla_helper(module: Linear,
 
     weight = load_weight_shard(weights[0]['weight'], module.tp_size,
                                module.tp_rank, module.tp_mode,
-                               device) if "weight" in weights[0] else None
+                               device, weight_name=weight_name) if "weight" in weights[0] else None
 
     if weight is not None:
         if module.has_weight_only_quant:
@@ -167,7 +177,7 @@ def load_weights_vanilla_helper(module: Linear,
     if module.bias is not None:
         bias = load_weight_shard(weights[0]['bias'], module.tp_size,
                                  module.tp_rank, module.tp_mode,
-                                 device) if "bias" in weights[0] else None
+                                 device, weight_name=weight_name) if "bias" in weights[0] else None
         if bias is not None:
             copy_weight(module.bias, bias_transform(bias))
 
@@ -311,7 +321,8 @@ def load_weights(self,
                      module: Linear,
                      weights: List[Dict],
                      weight_mode: WeightMode,
-                     allow_partial_loading: bool = False):
+                     allow_partial_loading: bool = False,
+                     weight_name: Optional[str] = None):
         """
         Load weights from the checkpoint.
         """
@@ -396,10 +407,14 @@ def apply(self, module: Linear, input: torch.Tensor,
     def load_weights_vanilla(self,
                              module: Linear,
                              weights: List[Dict],
-                             allow_partial_loading: bool = False) -> None:
+                             allow_partial_loading: bool = False,
+                             weight_name: Optional[str] = None) -> None:
+        if weight_name is not None:
+            print(f"[UnquantizedLinearMethod::load_weights_vanilla] weight_name: {weight_name}")
         load_weights_vanilla_helper(module,
                                     weights,
-                                    allow_partial_loading=allow_partial_loading)
+                                    allow_partial_loading=allow_partial_loading,
+                                    weight_name=weight_name)
 
     def load_weights_fused_qkv_linear(
         self,
@@ -2058,6 +2073,8 @@ def __init__(
         disable_deep_gemm: bool = False,
         fused_weight_shard_indices_mapping: Optional[dict] = None,
         nvfp4_allowed_backends: Optional[List[str]] = None,
+        weight_name: Optional[str] = None,
+        mapping_with_cp: Optional[Mapping] = None,
     ):
         """
         Args:
@@ -2098,6 +2115,12 @@ def __init__(
             'cutlass', 'cublaslt', 'cuda_core'
         ]
 
+        if mapping_with_cp is not None and weight_name == "o_proj":
+            print("[Linear::__init__] Found o_proj with CP mapping. Setting weight_name to o_proj_with_cp.")
+            self.weight_name = "o_proj_with_cp"
+        else:
+            self.weight_name = None
+
         local_in_features = in_features
         local_out_features = out_features
 
@@ -2284,7 +2307,8 @@ def load_weights(self,
             self,
             weights,
             weight_mode,
-            allow_partial_loading=allow_partial_loading)
+            allow_partial_loading=allow_partial_loading,
+            weight_name=self.weight_name)
 
     def post_load_weights(self):
         self.quant_method.post_load_weights(self)
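For reference, the sharding these hooks wrap is the even TP split shown in the second hunk. A small sketch of that arithmetic, plus where the rank swap hinted at by the debug prints could plug in; the swap itself is an assumption and is not implemented in this commit:

    import math
    from typing import Optional

    def shard_slice(width: int, tp_size: int, tp_rank: int) -> tuple:
        # Same arithmetic as load_weight_shard: ceil-divide the sharded width,
        # then clamp the last shard so trailing ranks do not run past the end.
        slice_width = math.ceil(width / tp_size)
        slice_start = tp_rank * slice_width
        slice_end = min((tp_rank + 1) * slice_width, width)
        return slice_start, slice_end

    def effective_rank(tp_rank: int, weight_name: Optional[str]) -> int:
        # Hypothetical version of the swap the "RANK 1 / RANK 2" prints mark:
        # for the o_proj_with_cp case (tp_size * cp_size == 4 assumed), load
        # rank 1's shard on rank 2 and vice versa before slicing.
        if weight_name == "o_proj_with_cp" and tp_rank in (1, 2):
            return 3 - tp_rank
        return tp_rank

    # Example: width=10, tp_size=4 gives shards [0:3], [3:6], [6:9], [9:10].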
