diff --git a/modelopt/torch/quantization/backends/fp8_per_tensor_gemm.py b/modelopt/torch/quantization/backends/fp8_per_tensor_gemm.py
index 1b61864c0..98c056a0b 100644
--- a/modelopt/torch/quantization/backends/fp8_per_tensor_gemm.py
+++ b/modelopt/torch/quantization/backends/fp8_per_tensor_gemm.py
@@ -175,7 +175,9 @@ def backward(ctx, grad_outputs):
             weight = weight.dequantize(scale=scale, block_sizes=ctx.block_sizes)
             grad_input = grad_outputs @ weight
         if input_tensor is not None:
-            grad_weight = grad_outputs.transpose(-2, -1) @ input_tensor
+            grad_weight = grad_outputs.reshape(-1, grad_outputs.shape[-1]).T @ input_tensor.reshape(
+                -1, input_tensor.shape[-1]
+            )
         if compute_bias_grad is not None:
             # Sum all dimensions except the last one
             grad_bias = grad_outputs.sum(dim=list(range(grad_outputs.dim() - 1)))
diff --git a/modelopt/torch/quantization/config.py b/modelopt/torch/quantization/config.py
index 6d492ffbf..7bdc0b4c6 100644
--- a/modelopt/torch/quantization/config.py
+++ b/modelopt/torch/quantization/config.py
@@ -1133,7 +1133,7 @@ class CompressConfig(ModeloptBaseConfig):
 
     compress: dict[str, bool] = ModeloptField(
         default={"*": True},
-        title="""Enable weight compression for the given pattern. Default is False for all weights.
+        title="""Enable weight compression for the given pattern. Default is True for all weights.
         Call `compress` function to compress the model weights.""",
     )
 
diff --git a/modelopt/torch/quantization/plugins/megatron.py b/modelopt/torch/quantization/plugins/megatron.py
index ab64a795a..3a442f24d 100644
--- a/modelopt/torch/quantization/plugins/megatron.py
+++ b/modelopt/torch/quantization/plugins/megatron.py
@@ -418,8 +418,10 @@ class
         forward(). This is not desired since _forward_impl introduces much more arguments
         while the original forward only takes 1 positional argument. We must avoid the fallback path in RealQuantLinear.forward().
         """
-        if self._should_run_real_quant_gemm and self.get_real_quant_gemm_impl(
-            input, *args, **kwargs
+        if (
+            self._should_run_real_quant_gemm
+            and self.get_real_quant_gemm_impl(input, *args, **kwargs)
+            and input.numel() > 1
         ):
            allreduce_dgrad = kwargs.get("allreduce_dgrad", False)
            tp_group = kwargs.get("tp_group")
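
Note on fp8_per_tensor_gemm.py: with a 3-D activation of shape (batch, seq, features), the old transpose-based formula performs a batched matmul and leaves a leading batch dimension on grad_weight, which cannot match the 2-D weight. Flattening all leading dimensions first accumulates every token into a single (out_features, in_features) gradient. A minimal standalone PyTorch sketch of the shape issue (the tensor sizes below are illustrative assumptions, not taken from the PR):

    import torch

    grad_outputs = torch.randn(2, 4, 16)  # (batch, seq, out_features)
    input_tensor = torch.randn(2, 4, 8)   # (batch, seq, in_features)

    # Old formula: batched matmul keeps the batch dim -> shape (2, 16, 8),
    # which does not match a 2-D weight of shape (out_features, in_features).
    old = grad_outputs.transpose(-2, -1) @ input_tensor
    assert old.shape == (2, 16, 8)

    # Fixed formula: collapse all leading dims so every token contributes to
    # one (out_features, in_features) gradient.
    new = grad_outputs.reshape(-1, grad_outputs.shape[-1]).T @ input_tensor.reshape(
        -1, input_tensor.shape[-1]
    )
    assert new.shape == (16, 8)
    # The fixed result is the sum of the per-batch gradients of the old formula.
    assert torch.allclose(new, old.sum(dim=0), atol=1e-5)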
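
Note on config.py: this hunk only corrects the docstring; the effective default was already {"*": True}, i.e. compression is on for every weight matching the wildcard pattern. A hedged sketch of overriding the default per pattern (assuming the config can be constructed directly with keyword overrides; the "*lm_head*" pattern is illustrative):

    from modelopt.torch.quantization.config import CompressConfig

    # Keep the global default (compress everything) but exempt one pattern.
    # Pattern semantics are wildcard-style per the field's docstring.
    cfg = CompressConfig(compress={"*": True, "*lm_head*": False})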
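
Note on plugins/megatron.py: besides restructuring the condition as a parenthesized `and` chain, the hunk adds an `input.numel() > 1` guard so the real-quant GEMM path is skipped for degenerate single-element inputs (presumably dummy tensors from warm-up or shape-inference passes; the diff itself does not state the motivation). A hedged sketch of the guard pattern, using illustrative names that are not modelopt APIs:

    import torch

    def pick_forward_path(module, input: torch.Tensor) -> str:
        """Illustrative dispatch mirroring the guarded condition in the hunk."""
        use_real_quant_gemm = (
            getattr(module, "_should_run_real_quant_gemm", False)
            # Short-circuiting `and`: the GEMM impl is only looked up when the
            # flag is set, and numel() is checked last.
            and getattr(module, "gemm_impl", None) is not None
            and input.numel() > 1  # skip degenerate single-element inputs
        )
        return "real_quant_gemm" if use_real_quant_gemm else "fallback_forward"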