Skip to content
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions modelopt/torch/quantization/plugins/megatron.py
Original file line number Diff line number Diff line change
Expand Up @@ -418,8 +418,10 @@ class forward(). This is not desired since _forward_impl introduces much more ar
while the original forward only takes 1 positional argument. We must avoid the fallback path
in RealQuantLinear.forward().
"""
if self._should_run_real_quant_gemm and self.get_real_quant_gemm_impl(
input, *args, **kwargs
if (
self._should_run_real_quant_gemm
and self.get_real_quant_gemm_impl(input, *args, **kwargs)
and input.numel() > 1
):
allreduce_dgrad = kwargs.get("allreduce_dgrad", False)
tp_group = kwargs.get("tp_group")
Expand Down
Loading