Commit 7c5c913

danielafrimi authored and mikeiovine committed
[https://nvbugs/5524714][fix] Fix TP sharding of fused-QKV weight scales in W4A16 AWQ (#8432)
Signed-off-by: Daniel Afrimi <dafrimi@nvidia.com>
Signed-off-by: Mike Iovine <6158008+mikeiovine@users.noreply.github.com>
1 parent 1b9fbbb commit 7c5c913

File tree: 1 file changed (+3 −2 lines)


tensorrt_llm/_torch/modules/linear.py

Lines changed: 3 additions & 2 deletions
@@ -1366,7 +1366,7 @@ def create_weights(self, module: Linear, in_features: int,
         group_size = module.quant_config.group_size
         if in_features % group_size != 0:
             raise ValueError(
-                f"in_features ({self.in_features}) must be divisible by group_size ({group_size}) "
+                f"in_features ({in_features}) must be divisible by group_size ({group_size}) "
                 f"for INT4 per-group quantization scale dimensions.")

         module.weight_scale = Parameter(torch.empty(
@@ -1465,7 +1465,8 @@ def load_weights_fused_qkv_linear(self, module: Linear,

         copy_weight(module.weight, fused_weight)

-        weight_scales = self.load_weight_scales(weights)
+        weight_scales = self.load_weight_scales(weights, module.tp_size,
+                                                module.tp_rank, module.tp_mode)

         # Create concatenated weight scale tensor
         cat_weight_scale = torch.cat(weight_scales, dim=0).T.contiguous()
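
For context on the second hunk: load_weight_scales now receives module.tp_size, module.tp_rank, and module.tp_mode so each tensor-parallel rank loads only its shard of the per-group weight scales instead of the full, unsharded scales. Below is a minimal sketch, assuming a column-parallel (output-dimension) split of the fused QKV weight; the shard_fused_qkv_scales helper and the tensor shapes are hypothetical illustrations, not TensorRT-LLM's actual API.

# Minimal sketch, not TensorRT-LLM code. With column-parallel TP the fused QKV
# weight is split along out_features per rank, so the per-group weight scales
# (assumed shape [out_features, in_features // group_size]) must be split the
# same way before concatenation; otherwise each rank would apply the full,
# unsharded scales to its local weight shard.
import torch

def shard_fused_qkv_scales(q_scale: torch.Tensor,
                           k_scale: torch.Tensor,
                           v_scale: torch.Tensor,
                           tp_size: int,
                           tp_rank: int) -> torch.Tensor:
    shards = []
    for scale in (q_scale, k_scale, v_scale):
        # Take this rank's slice along the output dimension, mirroring how the
        # fused QKV weight itself is sharded in column-parallel TP.
        shards.append(torch.chunk(scale, tp_size, dim=0)[tp_rank])
    # Concatenate Q/K/V shards in the same order as the fused weight layout.
    return torch.cat(shards, dim=0)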
