Skip to content

Commit 760c583

Browse files
committed
catch input.numel()<1 in _RealQuantMegatronParallelLinear
Signed-off-by: Ye Yu <[email protected]>
1 parent cf6f1d4 commit 760c583

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

modelopt/torch/quantization/plugins/megatron.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -418,8 +418,10 @@ class forward(). This is not desired since _forward_impl introduces much more ar
418418
while the original forward only takes 1 positional argument. We must above the fallback path
419419
in RealQuantLinear.forward().
420420
"""
421-
if self._should_run_real_quant_gemm and self.get_real_quant_gemm_impl(
422-
input, *args, **kwargs
421+
if (
422+
self._should_run_real_quant_gemm
423+
and self.get_real_quant_gemm_impl(input, *args, **kwargs)
424+
and input.numel() > 1
423425
):
424426
allreduce_dgrad = kwargs.get("allreduce_dgrad", False)
425427
tp_group = kwargs.get("tp_group")

0 commit comments

Comments
 (0)