We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent cf6f1d4 commit 760c583Copy full SHA for 760c583
modelopt/torch/quantization/plugins/megatron.py
@@ -418,8 +418,10 @@ class forward(). This is not desired since _forward_impl introduces much more ar
418
while the original forward only takes 1 positional argument. We must above the fallback path
419
in RealQuantLinear.forward().
420
"""
421
- if self._should_run_real_quant_gemm and self.get_real_quant_gemm_impl(
422
- input, *args, **kwargs
+ if (
+ self._should_run_real_quant_gemm
423
+ and self.get_real_quant_gemm_impl(input, *args, **kwargs)
424
+ and input.numel() > 1
425
):
426
allreduce_dgrad = kwargs.get("allreduce_dgrad", False)
427
tp_group = kwargs.get("tp_group")
0 commit comments