
Commit 2b58dba

danielafrimi authored and chzblych committed
[https://nvbugs/5524714][fix] Fix TP sharding of fused-QKV weight scales in W4A16 AWQ (NVIDIA#8432)
Signed-off-by: Daniel Afrimi <[email protected]>
Signed-off-by: Mike Iovine <[email protected]>
1 parent ce23e24 commit 2b58dba

File tree

1 file changed: +3, -2 lines


tensorrt_llm/_torch/modules/linear.py

Lines changed: 3 additions & 2 deletions
@@ -1393,7 +1393,7 @@ def create_weights(self, module: Linear, in_features: int,
         group_size = module.quant_config.group_size
         if in_features % group_size != 0:
             raise ValueError(
-                f"in_features ({self.in_features}) must be divisible by group_size ({group_size}) "
+                f"in_features ({in_features}) must be divisible by group_size ({group_size}) "
                 f"for INT4 per-group quantization scale dimensions.")
 
         module.weight_scale = Parameter(torch.empty(
@@ -1492,7 +1492,8 @@ def load_weights_fused_qkv_linear(self, module: Linear,
 
         copy_weight(module.weight, fused_weight)
 
-        weight_scales = self.load_weight_scales(weights)
+        weight_scales = self.load_weight_scales(weights, module.tp_size,
+                                                module.tp_rank, module.tp_mode)
 
         # Create concatenated weight scale tensor
         cat_weight_scale = torch.cat(weight_scales, dim=0).T.contiguous()
