diff --git a/modelopt/torch/export/unified_export_hf.py b/modelopt/torch/export/unified_export_hf.py
index f514e660d..2a69831e9 100644
--- a/modelopt/torch/export/unified_export_hf.py
+++ b/modelopt/torch/export/unified_export_hf.py
@@ -332,6 +332,10 @@ def _export_quantized_weight(
 
     setattr(sub_module, weight_name, nn.Parameter(quantized_weight, requires_grad=False))
 
+    # Register the corrected weight_scale as a buffer
+    if weight_scale is not None:
+        sub_module.register_buffer(quantizer_attrs.weight_scale, weight_scale)
+
 
 def _export_hf_checkpoint(
     model: nn.Module, dtype: torch.dtype | None = None
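
For context, the sketch below illustrates the general pattern the added lines rely on: registering a tensor with `register_buffer` makes it part of the module's `state_dict()`, so the corrected scale is serialized alongside the quantized weight. This is not the ModelOpt implementation; `sub_module`, the dummy weight/scale values, and the literal `"weight_scale"` buffer name are stand-ins chosen for illustration (the diff uses `quantizer_attrs.weight_scale` for the name).

```python
import torch
import torch.nn as nn

# Hypothetical stand-ins for the module and tensors produced during export.
sub_module = nn.Linear(4, 4, bias=False)
quantized_weight = torch.randint(-128, 127, (4, 4), dtype=torch.int8)
weight_scale = torch.tensor(0.02)

# Replace the floating-point weight with the quantized tensor ...
setattr(sub_module, "weight", nn.Parameter(quantized_weight, requires_grad=False))

# ... and attach the scale as a buffer so it lands in the state dict.
# "weight_scale" is an assumed attribute name for this sketch.
if weight_scale is not None:
    sub_module.register_buffer("weight_scale", weight_scale)

# Both tensors now appear in the serialized state dict.
print(sorted(sub_module.state_dict().keys()))  # ['weight', 'weight_scale']
```

Without the `register_buffer` call, the scale would only exist as a plain Python attribute and would be silently dropped when the checkpoint is saved.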