Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,9 @@ def backward(ctx, grad_outputs):
weight = weight.dequantize(scale=scale, block_sizes=ctx.block_sizes)
grad_input = grad_outputs @ weight
if input_tensor is not None:
grad_weight = grad_outputs.transpose(-2, 1) @ input_tensor
grad_weight = grad_outputs.reshape(-1, grad_outputs.shape[-1]).T @ input_tensor.reshape(
-1, input_tensor.shape[-1]
)
if compute_bias_grad is not None:
# Sum all dimensions except the last one
grad_bias = grad_outputs.sum(dim=list(range(grad_outputs.dim() - 1)))
Expand Down
2 changes: 1 addition & 1 deletion modelopt/torch/quantization/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -1133,7 +1133,7 @@ class CompressConfig(ModeloptBaseConfig):

compress: dict[str, bool] = ModeloptField(
default={"*": True},
title="""Enable weight compression for the given pattern. Default is False for all weights.
title="""Enable weight compression for the given pattern. Default is True for all weights.
Call `compress` function to compress the model weights.""",
)

Expand Down
6 changes: 4 additions & 2 deletions modelopt/torch/quantization/plugins/megatron.py
Original file line number Diff line number Diff line change
Expand Up @@ -418,8 +418,10 @@ class forward(). This is not desired since _forward_impl introduces much more ar
while the original forward only takes 1 positional argument. We must avoid the fallback path
in RealQuantLinear.forward().
"""
if self._should_run_real_quant_gemm and self.get_real_quant_gemm_impl(
input, *args, **kwargs
if (
self._should_run_real_quant_gemm
and self.get_real_quant_gemm_impl(input, *args, **kwargs)
and input.numel() > 1
):
allreduce_dgrad = kwargs.get("allreduce_dgrad", False)
tp_group = kwargs.get("tp_group")
Expand Down
Loading