Commit 512dbb7

authored
Avoid squeezing original weight tensors with leading dim 1 (#294)
Signed-off-by: Chenjie Luo <[email protected]>
1 parent cf6f1d4 commit 512dbb7

File tree

1 file changed: +7 additions, −2 deletions

modelopt/torch/export/quant_utils.py

@@ -869,9 +869,14 @@ def postprocess_state_dict(state_dict: dict, maxbound: float, quantization: str
                 post_state_dict[prefix + new_suffix] = value
                 break

-    # Squeeze tensors with a leading dimension of 1
+    # Squeeze scales with a leading dimension of 1
     for key, value in post_state_dict.items():
-        if isinstance(value, torch.Tensor) and value.dim() == 3 and value.shape[0] == 1:
+        if (
+            "scale" in key
+            and isinstance(value, torch.Tensor)
+            and value.dim() == 3
+            and value.shape[0] == 1
+        ):
             post_state_dict[key] = value.squeeze(0)

     # remove real quant parameters from the state dict
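The effect of the new guard can be illustrated with a minimal sketch. This is an assumption for illustration, not the library's code: NumPy arrays stand in for torch tensors (`ndim`/`isinstance(..., np.ndarray)` replace `dim()`/`isinstance(..., torch.Tensor)`), and the `squeeze_scales` helper name is hypothetical. Only entries whose key contains `"scale"` and whose value is a rank-3 array with a leading dimension of 1 are squeezed; a weight tensor of the same shape is left untouched, which is the regression the commit fixes.

```python
import numpy as np

def squeeze_scales(post_state_dict: dict) -> dict:
    """Squeeze only scale entries with a leading dimension of 1 (sketch)."""
    for key, value in post_state_dict.items():
        if (
            "scale" in key                      # new guard: restrict to scales
            and isinstance(value, np.ndarray)   # stand-in for torch.Tensor
            and value.ndim == 3                 # stand-in for value.dim() == 3
            and value.shape[0] == 1
        ):
            post_state_dict[key] = value.squeeze(0)
    return post_state_dict

state = {
    "layer.weight_scale": np.ones((1, 4, 8)),  # squeezed to (4, 8)
    "layer.weight": np.ones((1, 4, 8)),        # kept as (1, 4, 8): not a scale
}
state = squeeze_scales(state)
```

Before this change, any rank-3 tensor with a leading 1 was squeezed, including original weights; the added `"scale" in key` check narrows the squeeze to quantization scales.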
