From 867922a006720d02f4b75f97081d7dac10e2f66f Mon Sep 17 00:00:00 2001 From: Chenjie Luo <108829653+cjluo-nv@users.noreply.github.com> Date: Fri, 5 Sep 2025 10:47:56 -0700 Subject: [PATCH 1/3] Do not squeeze weights with leading dim 1 Signed-off-by: Chenjie Luo <108829653+cjluo-nv@users.noreply.github.com> --- modelopt/torch/export/quant_utils.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/modelopt/torch/export/quant_utils.py b/modelopt/torch/export/quant_utils.py index 14160be3e..2dea36dd7 100755 --- a/modelopt/torch/export/quant_utils.py +++ b/modelopt/torch/export/quant_utils.py @@ -869,11 +869,6 @@ def postprocess_state_dict(state_dict: dict, maxbound: float, quantization: str post_state_dict[prefix + new_suffix] = value break - # Squeeze tensors with a leading dimension of 1 - for key, value in post_state_dict.items(): - if isinstance(value, torch.Tensor) and value.dim() == 3 and value.shape[0] == 1: - post_state_dict[key] = value.squeeze(0) - # remove real quant parameters from the state dict keys_to_delete = [] for key, value in post_state_dict.items(): From 76ce70a084c1e1e2adc2dd684c3e02a206f5a55e Mon Sep 17 00:00:00 2001 From: Chenjie Luo <108829653+cjluo-nv@users.noreply.github.com> Date: Fri, 5 Sep 2025 13:12:30 -0700 Subject: [PATCH 2/3] Update quant_utils.py Signed-off-by: Chenjie Luo <108829653+cjluo-nv@users.noreply.github.com> --- modelopt/torch/export/quant_utils.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/modelopt/torch/export/quant_utils.py b/modelopt/torch/export/quant_utils.py index 2dea36dd7..c7d531604 100755 --- a/modelopt/torch/export/quant_utils.py +++ b/modelopt/torch/export/quant_utils.py @@ -869,6 +869,16 @@ def postprocess_state_dict(state_dict: dict, maxbound: float, quantization: str post_state_dict[prefix + new_suffix] = value break + # Squeeze scales with a leading dimension of 1 + for key, value in post_state_dict.items(): + if ( + "scale" in key + and isinstance(value, torch.Tensor) + and value.dim() == 3 + and value.shape[0] == 1 + ): + post_state_dict[key] = value.squeeze(0) + # remove real quant parameters from the state dict keys_to_delete = [] for key, value in post_state_dict.items(): From 501d8f0e12786e6971a58c499757781cae9e7c7c Mon Sep 17 00:00:00 2001 From: Chenjie Luo <108829653+cjluo-nv@users.noreply.github.com> Date: Fri, 5 Sep 2025 15:12:23 -0700 Subject: [PATCH 3/3] Update quant_utils.py Signed-off-by: Chenjie Luo <108829653+cjluo-nv@users.noreply.github.com> --- modelopt/torch/export/quant_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modelopt/torch/export/quant_utils.py b/modelopt/torch/export/quant_utils.py index c7d531604..17a110de1 100755 --- a/modelopt/torch/export/quant_utils.py +++ b/modelopt/torch/export/quant_utils.py @@ -878,7 +878,7 @@ def postprocess_state_dict(state_dict: dict, maxbound: float, quantization: str and value.shape[0] == 1 ): post_state_dict[key] = value.squeeze(0) - + # remove real quant parameters from the state dict keys_to_delete = [] for key, value in post_state_dict.items():