diff --git a/modelopt/torch/export/quant_utils.py b/modelopt/torch/export/quant_utils.py index 14160be3e..17a110de1 100755 --- a/modelopt/torch/export/quant_utils.py +++ b/modelopt/torch/export/quant_utils.py @@ -869,9 +869,14 @@ def postprocess_state_dict(state_dict: dict, maxbound: float, quantization: str post_state_dict[prefix + new_suffix] = value break - # Squeeze tensors with a leading dimension of 1 + # Squeeze scales with a leading dimension of 1 for key, value in post_state_dict.items(): - if isinstance(value, torch.Tensor) and value.dim() == 3 and value.shape[0] == 1: + if ( + "scale" in key + and isinstance(value, torch.Tensor) + and value.dim() == 3 + and value.shape[0] == 1 + ): post_state_dict[key] = value.squeeze(0) # remove real quant parameters from the state dict