@@ -175,7 +175,9 @@ def backward(ctx, grad_outputs):
        weight = weight.dequantize(scale=scale, block_sizes=ctx.block_sizes)
        grad_input = grad_outputs @ weight
    if input_tensor is not None:
-       grad_weight = grad_outputs.transpose(-2, 1) @ input_tensor
+       grad_weight = grad_outputs.reshape(-1, grad_outputs.shape[-1]).T @ input_tensor.reshape(
+           -1, input_tensor.shape[-1]
+       )
    if compute_bias_grad is not None:
        # Sum all dimensions except the last one
        grad_bias = grad_outputs.sum(dim=list(range(grad_outputs.dim() - 1)))
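
The previous transpose(-2, 1) form only behaves correctly for 2D activations: on a 3D [batch, seq, features] tensor, transpose(-2, 1) swaps dimension 1 with itself, and the batched matmul no longer yields an [out_features, in_features] weight gradient. Below is a standalone sanity check (illustrative only, not part of the PR; the shapes and variable names are made up) showing that flattening every leading dimension before the matmul reproduces the weight gradient that autograd computes:

import torch

batch, seq, in_features, out_features = 2, 4, 8, 16
x = torch.randn(batch, seq, in_features)
w = torch.randn(out_features, in_features, requires_grad=True)

y = x @ w.T                       # forward pass: [batch, seq, out_features]
grad_out = torch.randn_like(y)    # stand-in for the upstream gradient
y.backward(grad_out)

# Same reduction as the new backward: collapse all leading dims, then one matmul.
grad_w_manual = grad_out.reshape(-1, out_features).T @ x.reshape(-1, in_features)
assert torch.allclose(grad_w_manual, w.grad, atol=1e-5)
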
modelopt/torch/quantization/plugins/megatron.py (6 changes: 4 additions & 2 deletions)
@@ -418,8 +418,10 @@ class forward(). This is not desired since _forward_impl introduces much more arguments
    while the original forward only takes 1 positional argument. We must avoid the fallback path
    in RealQuantLinear.forward().
    """
-   if self._should_run_real_quant_gemm and self.get_real_quant_gemm_impl(
-       input, *args, **kwargs
+   if (
+       self._should_run_real_quant_gemm
+       and self.get_real_quant_gemm_impl(input, *args, **kwargs)
+       and input.numel() > 1
    ):
        allreduce_dgrad = kwargs.get("allreduce_dgrad", False)
        tp_group = kwargs.get("tp_group")
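
For readers skimming the diff, here is a toy sketch of the dispatch shape this condition produces. It is not the modelopt class: _ToyRealQuantLinear, the class-level _should_run_real_quant_gemm flag, and the kernel-lookup stub are all invented for illustration; only the three-part condition mirrors the change above. The real-quant GEMM path is taken only when quantized weights are in place, a kernel lookup succeeds, and the input has more than one element; otherwise the call falls back to the original single-argument forward.

import torch
import torch.nn as nn

class _ToyRealQuantLinear(nn.Linear):
    """Toy stand-in, not the modelopt class; mimics only the dispatch of the patched forward."""

    _should_run_real_quant_gemm = True

    def get_real_quant_gemm_impl(self, input, *args, **kwargs):
        # Stand-in for the kernel lookup; pretend a kernel exists only for 2D inputs.
        return input.dim() == 2

    def forward(self, input, *args, **kwargs):
        if (
            self._should_run_real_quant_gemm
            and self.get_real_quant_gemm_impl(input, *args, **kwargs)
            and input.numel() > 1  # the guard added in this change
        ):
            # A real low-precision GEMM would run here; a plain matmul stands in for it.
            return input @ self.weight.T + self.bias
        # Fallback path: the original forward, which takes a single positional argument.
        return super().forward(input)

layer = _ToyRealQuantLinear(8, 16)
print(layer(torch.randn(4, 8)).shape)     # guarded GEMM path -> torch.Size([4, 16])
print(layer(torch.randn(1, 1, 8)).shape)  # kernel lookup fails for 3D -> fallback -> torch.Size([1, 1, 16])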