From 760c583a3048d0d09e3ee6588343a2dd36960c71 Mon Sep 17 00:00:00 2001
From: Ye Yu
Date: Mon, 8 Sep 2025 14:47:59 -0700
Subject: [PATCH 1/3] catch input.numel()<1 in _RealQuantMegatronParallelLinear

Signed-off-by: Ye Yu
---
 modelopt/torch/quantization/plugins/megatron.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/modelopt/torch/quantization/plugins/megatron.py b/modelopt/torch/quantization/plugins/megatron.py
index ab64a795a..3a442f24d 100644
--- a/modelopt/torch/quantization/plugins/megatron.py
+++ b/modelopt/torch/quantization/plugins/megatron.py
@@ -418,9 +418,11 @@ class _RealQuantMegatronParallelLinear
         forward(). This is not desired since _forward_impl introduces many more
         arguments while the original forward only takes 1 positional argument. We
         must avoid the fallback path in RealQuantLinear.forward().
         """
-        if self._should_run_real_quant_gemm and self.get_real_quant_gemm_impl(
-            input, *args, **kwargs
+        if (
+            self._should_run_real_quant_gemm
+            and self.get_real_quant_gemm_impl(input, *args, **kwargs)
+            and input.numel() > 1
         ):
             allreduce_dgrad = kwargs.get("allreduce_dgrad", False)
             tp_group = kwargs.get("tp_group")

From b4cb2d55ca17bbbbfc1107d1415394aa4321af36 Mon Sep 17 00:00:00 2001
From: Ye Yu
Date: Tue, 9 Sep 2025 09:20:01 -0700
Subject: [PATCH 2/3] fix fp8 gemm bug

Signed-off-by: Ye Yu
---
 modelopt/torch/quantization/backends/fp8_per_tensor_gemm.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/modelopt/torch/quantization/backends/fp8_per_tensor_gemm.py b/modelopt/torch/quantization/backends/fp8_per_tensor_gemm.py
index 1b61864c0..98c056a0b 100644
--- a/modelopt/torch/quantization/backends/fp8_per_tensor_gemm.py
+++ b/modelopt/torch/quantization/backends/fp8_per_tensor_gemm.py
@@ -175,7 +175,9 @@ def backward(ctx, grad_outputs):
             weight = weight.dequantize(scale=scale, block_sizes=ctx.block_sizes)
         grad_input = grad_outputs @ weight
         if input_tensor is not None:
-            grad_weight = grad_outputs.transpose(-2, 1) @ input_tensor
+            grad_weight = grad_outputs.reshape(-1, grad_outputs.shape[-1]).T @ input_tensor.reshape(
+                -1, input_tensor.shape[-1]
+            )
         if compute_bias_grad is not None:
             # Sum all dimensions except the last one
             grad_bias = grad_outputs.sum(dim=list(range(grad_outputs.dim() - 1)))

From fc445959f62f6bc60215ed11679ae3bdbd93337d Mon Sep 17 00:00:00 2001
From: Ye Yu
Date: Tue, 9 Sep 2025 09:37:20 -0700
Subject: [PATCH 3/3] fix a typo

Signed-off-by: Ye Yu
---
 modelopt/torch/quantization/config.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/modelopt/torch/quantization/config.py b/modelopt/torch/quantization/config.py
index 6d492ffbf..7bdc0b4c6 100644
--- a/modelopt/torch/quantization/config.py
+++ b/modelopt/torch/quantization/config.py
@@ -1133,7 +1133,7 @@ class CompressConfig(ModeloptBaseConfig):
 
     compress: dict[str, bool] = ModeloptField(
         default={"*": True},
-        title="""Enable weight compression for the given pattern. Default is False for all weights.
+        title="""Enable weight compression for the given pattern. Default is True for all weights.
         Call `compress` function to compress the model weights.""",
     )
 
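A note on patch 1/3: the new `input.numel() > 1` clause keeps degenerate activations out of
the real-quant GEMM fast path and lets them fall through to the ordinary forward (as written,
the guard also routes single-element inputs to the fallback, not only empty ones). Below is a
minimal sketch of the dispatch pattern, not ModelOpt's actual class; `_lookup_fused_kernel`
is an invented stand-in for `get_real_quant_gemm_impl`:

    import torch
    import torch.nn as nn

    class GuardedLinear(nn.Linear):
        """Illustration only: dispatch to a fused low-precision GEMM only when
        a kernel exists and the activation actually carries elements."""

        def _lookup_fused_kernel(self, input: torch.Tensor):
            # Hypothetical stand-in for get_real_quant_gemm_impl(); returns a
            # callable when a specialized kernel applies, else None.
            return None

        def forward(self, input: torch.Tensor) -> torch.Tensor:
            kernel = self._lookup_fused_kernel(input)
            if kernel is not None and input.numel() > 1:
                return kernel(input, self.weight, self.bias)
            # Fallback path: degenerate activations (for example, an expert
            # that was routed zero tokens) take the ordinary matmul, which
            # handles zero-size tensors without a specialized kernel.
            return super().forward(input)

    layer = GuardedLinear(8, 16)
    empty = torch.empty(0, 8)      # zero tokens
    print(layer(empty).shape)      # torch.Size([0, 16]) via the fallback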
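A note on patch 2/3: the old expression `grad_outputs.transpose(-2, 1)` only behaves as a
transpose for 2-D tensors, where dim -2 is dim 0. On a 3-D tensor, dim -2 *is* dim 1, so the
call is a no-op and the subsequent batched matmul either raises a shape error or, when seq
happens to equal the output width, silently computes the wrong quantity. Flattening both
operands to 2-D accumulates the weight gradient over all leading dimensions in one matmul.
A self-contained check against autograd, with shapes chosen only for illustration:

    import torch

    # Shapes mimic a transformer activation: (batch, seq, features).
    batch, seq, d_in, d_out = 2, 4, 8, 16
    x = torch.randn(batch, seq, d_in, requires_grad=True)
    w = torch.randn(d_out, d_in, requires_grad=True)

    y = x @ w.T                    # (batch, seq, d_out)
    grad_y = torch.randn_like(y)
    y.backward(grad_y)             # autograd reference for w.grad

    # Patched formula: flatten every leading dim, then one 2-D matmul
    # accumulates the weight gradient over all batch and sequence positions.
    grad_w = grad_y.reshape(-1, grad_y.shape[-1]).T @ x.detach().reshape(-1, x.shape[-1])
    assert torch.allclose(grad_w, w.grad, atol=1e-5)

    # Pre-patch expression: for a 3-D tensor, dim -2 and dim 1 are the same
    # axis, so transpose(-2, 1) does nothing and the batched matmul fails
    # (inner dimensions d_out vs. seq do not match).
    try:
        _ = grad_y.transpose(-2, 1) @ x
    except RuntimeError as e:
        print("old formula fails on 3-D input:", e)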
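A note on patch 3/3: the docstring now agrees with the field's actual default, `{"*": True}`.
A small usage sketch, assuming `CompressConfig` can be default-constructed and that
per-pattern keys override the global `"*"` entry (the override pattern below is invented
for illustration):

    from modelopt.torch.quantization.config import CompressConfig

    cfg = CompressConfig()
    assert cfg.compress == {"*": True}   # compression on for all weights by default

    # Hypothetical per-pattern override: keep the global default but exclude
    # layers matching a pattern (pattern string invented for illustration).
    cfg = CompressConfig(compress={"*": True, "*output_layer*": False})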