Skip to content

Commit 76ce70a

Browse files
authored
Update quant_utils.py
Signed-off-by: Chenjie Luo <[email protected]>
1 parent 867922a commit 76ce70a

File tree

1 file changed

+10
-0
lines changed

1 file changed

+10
-0
lines changed

modelopt/torch/export/quant_utils.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -869,6 +869,16 @@ def postprocess_state_dict(state_dict: dict, maxbound: float, quantization: str
869869
post_state_dict[prefix + new_suffix] = value
870870
break
871871

872+
# Squeeze scales with a leading dimension of 1
873+
for key, value in post_state_dict.items():
874+
if (
875+
"scale" in key
876+
and isinstance(value, torch.Tensor)
877+
and value.dim() == 3
878+
and value.shape[0] == 1
879+
):
880+
post_state_dict[key] = value.squeeze(0)
881+
872882
# remove real quant parameters from the state dict
873883
keys_to_delete = []
874884
for key, value in post_state_dict.items():

0 commit comments

Comments
 (0)