From 05a2e78367efbcd8af00807878565d98a8e73da2 Mon Sep 17 00:00:00 2001 From: Chenjie Luo Date: Wed, 24 Sep 2025 23:18:52 +0000 Subject: [PATCH 1/6] Improve realquant gemm impl Signed-off-by: Chenjie Luo --- .../backends/fp8_per_tensor_gemm.py | 61 +++++++++++-------- .../torch/quantization/backends/nvfp4_gemm.py | 10 +-- .../quantization/nn/modules/quant_linear.py | 23 ++++--- .../torch/quantization/plugins/megatron.py | 8 ++- 4 files changed, 56 insertions(+), 46 deletions(-) diff --git a/modelopt/torch/quantization/backends/fp8_per_tensor_gemm.py b/modelopt/torch/quantization/backends/fp8_per_tensor_gemm.py index 1b61864c0..cc5be9d56 100644 --- a/modelopt/torch/quantization/backends/fp8_per_tensor_gemm.py +++ b/modelopt/torch/quantization/backends/fp8_per_tensor_gemm.py @@ -32,35 +32,38 @@ FP8_MAX = torch.finfo(torch.float8_e4m3fn).max -def fp8_per_tensor_gemm(quant_module, input, bias=None): - """GEMM function for fp8 per tensor quantization.""" +@torch.compile(dynamic=True) +def _to_fp8(x, scale): + return (x / scale).clamp(FP8_MIN, FP8_MAX).to(torch.float8_e4m3fn) + + +@torch.compile(dynamic=True) +def _fp8_gemm_impl(input, weight_fp8, scale_a, scale_b, bias=None): + input_shape = input.shape + input_fp8 = _to_fp8(input, scale_a).reshape(-1, input_shape[-1]) + weight_fp8_t = weight_fp8.reshape(-1, weight_fp8.shape[-1]).t() + output = torch._scaled_mm( + input_fp8, + weight_fp8_t, + scale_a=scale_a, + scale_b=scale_b, + bias=bias, + out_dtype=input.dtype, + use_fast_accum=True, + ) + return output.reshape(*input_shape[:-1], output.shape[-1]) - @torch.compile(dynamic=True) - def _to_fp8(x, scale): - return (x / scale).clamp(FP8_MIN, FP8_MAX).to(torch.float8_e4m3fn) - - @torch.compile(dynamic=True) - def _fp8_gemm_impl(input, weight_fp8, scale_a, scale_b, bias=None): - input_shape = input.shape - input_fp8 = _to_fp8(input, scale_a).reshape(-1, input_shape[-1]) - weight_fp8_t = weight_fp8.reshape(-1, weight_fp8.shape[-1]).t() - output = torch._scaled_mm( - input_fp8, - weight_fp8_t, - scale_a=scale_a, - scale_b=scale_b, - bias=bias, - out_dtype=input.dtype, - use_fast_accum=True, - ) - return output.reshape(*input_shape[:-1], output.shape[-1]) +def fp8_per_tensor_gemm(quant_module, input, bias=None): + """GEMM function for fp8 per tensor quantization.""" cached_scale_a = ( hasattr(quant_module, "_scale_a") and quant_module.input_quantizer.amax is not None ) if not cached_scale_a: - input_amax = quant_module.input_quantizer.amax or reduce_amax(input) + input_amax = quant_module.input_quantizer.amax + if input_amax is None: + input_amax = reduce_amax(input) assert input_amax != 0 quant_module._scale_a = (input_amax.float() / 448.0).to(device=input.device) @@ -69,7 +72,9 @@ def _fp8_gemm_impl(input, weight_fp8, scale_a, scale_b, bias=None): ) if not cached_scale_b: - weight_amax = quant_module.weight_quantizer.amax or reduce_amax(quant_module.weight) + weight_amax = quant_module.weight_quantizer.amax + if weight_amax is None: + weight_amax = reduce_amax(quant_module.weight) assert weight_amax != 0 quant_module._scale_b = (weight_amax.float() / 448.0).to(device=quant_module.weight.device) @@ -146,9 +151,9 @@ def forward( ctx.save_for_backward( input_tensor if weight.requires_grad else None, weight if input_tensor.requires_grad else None, - torch.empty(0, dtype=torch.uint8) if bias is not None and bias.requires_grad else None, getattr(quant_module.weight_quantizer, "_scale", None), ) + ctx.compute_bias_grad = bias is not None and bias.requires_grad ctx.block_sizes = 
getattr(quant_module.weight_quantizer, "_block_sizes", None) ctx.allreduce_dgrad = allreduce_dgrad @@ -166,7 +171,7 @@ def backward(ctx, grad_outputs): dequantize it to compute the input gradient. If the weight is not compressed, we will save the unquantized weight and use it directly to compute the input gradient. """ - input_tensor, weight, compute_bias_grad, scale = ctx.saved_tensors + input_tensor, weight, scale = ctx.saved_tensors grad_input = grad_weight = grad_bias = None if weight is not None: if isinstance(weight, QTensorWrapper): @@ -175,8 +180,10 @@ def backward(ctx, grad_outputs): weight = weight.dequantize(scale=scale, block_sizes=ctx.block_sizes) grad_input = grad_outputs @ weight if input_tensor is not None: - grad_weight = grad_outputs.transpose(-2, 1) @ input_tensor - if compute_bias_grad is not None: + grad_weight = grad_outputs.reshape(-1, grad_outputs.shape[-1]).T @ input_tensor.reshape( + -1, input_tensor.shape[-1] + ) + if ctx.compute_bias_grad: # Sum all dimensions except the last one grad_bias = grad_outputs.sum(dim=list(range(grad_outputs.dim() - 1))) diff --git a/modelopt/torch/quantization/backends/nvfp4_gemm.py b/modelopt/torch/quantization/backends/nvfp4_gemm.py index 4a67a3b93..ffc18fea3 100644 --- a/modelopt/torch/quantization/backends/nvfp4_gemm.py +++ b/modelopt/torch/quantization/backends/nvfp4_gemm.py @@ -139,11 +139,11 @@ def forward( ctx.save_for_backward( input_tensor if weight.requires_grad else None, weight if input_tensor.requires_grad else None, - torch.empty(0, dtype=torch.uint8) if bias is not None and bias.requires_grad else None, getattr(quant_module.weight_quantizer, "_scale", None), getattr(quant_module.weight_quantizer, "_double_scale", None), ) + ctx.compute_bias_grad = bias is not None and bias.requires_grad ctx.allreduce_dgrad = allreduce_dgrad ctx.tp_group = tp_group ret = nvfp4_gemm(quant_module, input_tensor, bias) @@ -158,7 +158,7 @@ def backward(ctx, grad_outputs): dequantize it to compute the input gradient. If the weight is not compressed, we will save the unquantized weight and use it directly to compute the input gradient. 
""" - input_tensor, weight, compute_bias_grad, scale, double_scale = ctx.saved_tensors + input_tensor, weight, scale, double_scale = ctx.saved_tensors grad_input = grad_weight = grad_bias = None if weight is not None: if isinstance(weight, QTensorWrapper): @@ -173,8 +173,10 @@ def backward(ctx, grad_outputs): ) grad_input = grad_outputs @ weight if input_tensor is not None: - grad_weight = grad_outputs.transpose(-2, -1) @ input_tensor - if compute_bias_grad is not None: + grad_weight = grad_outputs.reshape(-1, grad_outputs.shape[-1]).T @ input_tensor.reshape( + -1, input_tensor.shape[-1] + ) + if ctx.compute_bias_grad is not None: # Sum all dimensions except the last one grad_bias = grad_outputs.sum(dim=list(range(grad_outputs.dim() - 1))) diff --git a/modelopt/torch/quantization/nn/modules/quant_linear.py b/modelopt/torch/quantization/nn/modules/quant_linear.py index 584385d0a..44655f7e2 100644 --- a/modelopt/torch/quantization/nn/modules/quant_linear.py +++ b/modelopt/torch/quantization/nn/modules/quant_linear.py @@ -148,7 +148,7 @@ def _should_run_real_quant_gemm(self): and self.allow_real_quant_gemm ) - def get_real_quant_gemm_impl(self, input, *args, **kwargs) -> bool: + def has_real_quant_gemm_impl(self, input, *args, **kwargs) -> bool: """Get the real quant GEMM implementation base on input arguments.""" if not hasattr(self, "_real_quant_gemm_impl"): self._real_quant_gemm_impl = backends.gemm_registry.find_match( @@ -166,20 +166,19 @@ def forward(self, input, *args, **kwargs): return super().forward(input, *args, **kwargs) # Check if real-quant GEMM is available - if self._should_run_real_quant_gemm and input.numel() > 1: - # If the input is not quantized, we use the default GEMM. - self.get_real_quant_gemm_impl(input, *args, **kwargs) - + if ( + self._should_run_real_quant_gemm + and input.numel() > 1 + and self.has_real_quant_gemm_impl(input, *args, **kwargs) + ): # Note: We cache the real-quant GEMM function to avoid matching overhead. # This assumes that the function will not change after the first call. - if self._real_quant_gemm_impl: - with torch.cuda.nvtx.range("RealQuantLinear gemm"): - output = self._real_quant_gemm_impl( - self, input, self.weight, self.bias, *args, **kwargs - ) - return ( - self.output_quantizer(output) if hasattr(self, "output_quantizer") else output + assert self._real_quant_gemm_impl is not None + with torch.cuda.nvtx.range("RealQuantLinear gemm"): + output = self._real_quant_gemm_impl( + self, input, self.weight, self.bias, *args, **kwargs ) + return self.output_quantizer(output) if hasattr(self, "output_quantizer") else output # Otherwise, fallback to the default GEMM return super().forward(input, *args, **kwargs) diff --git a/modelopt/torch/quantization/plugins/megatron.py b/modelopt/torch/quantization/plugins/megatron.py index ab64a795a..1cf9416ec 100644 --- a/modelopt/torch/quantization/plugins/megatron.py +++ b/modelopt/torch/quantization/plugins/megatron.py @@ -115,7 +115,7 @@ def real_quant_module_set_extra_state(self, state: Any): """ q_tensor_state = state.get("modelopt_q_tensor_state", None) - if q_tensor_state is not None: + if q_tensor_state: q_tensor_metadata = q_tensor_state["metadata"] q_tensor_metadata["shape"] = self.weight.shape q_tensor_data_dtype = q_tensor_state["quantized_data.dtype"] @@ -418,8 +418,10 @@ class forward(). This is not desired since _forward_impl introduces much more ar while the original forward only takes 1 positional argument. We must above the fallback path in RealQuantLinear.forward(). 
""" - if self._should_run_real_quant_gemm and self.get_real_quant_gemm_impl( - input, *args, **kwargs + if ( + self._should_run_real_quant_gemm + and input.numel() > 1 + and self.has_real_quant_gemm_impl(input, *args, **kwargs) ): allreduce_dgrad = kwargs.get("allreduce_dgrad", False) tp_group = kwargs.get("tp_group") From 231f7bda4b2a5aea1de4787870207ee3fd8c071e Mon Sep 17 00:00:00 2001 From: Chenjie Luo Date: Thu, 25 Sep 2025 19:16:01 +0000 Subject: [PATCH 2/6] Update doc Signed-off-by: Chenjie Luo --- docs/source/_templates/autosummary/module.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/_templates/autosummary/module.rst b/docs/source/_templates/autosummary/module.rst index 44d00656d..647a88898 100644 --- a/docs/source/_templates/autosummary/module.rst +++ b/docs/source/_templates/autosummary/module.rst @@ -11,7 +11,7 @@ :recursive: {% for item in modules %} {% set full_item = fullname + '.' + item.split('.')[-1] %} -{% if '.plugins.' not in full_item or full_item == 'modelopt.torch.opt.plugins.huggingface' %} +{% if '.plugins.' not in full_item or full_item == 'modelopt.torch.opt.plugins.huggingface' or full_item == 'modelopt.torch.quantization.backends.fp8_per_tensor_gemm' %} {{ full_item }} {% endif %} {%- endfor %} From b82be66d84a4aa17023709d9184e93850d20a767 Mon Sep 17 00:00:00 2001 From: Chenjie Luo Date: Thu, 25 Sep 2025 20:26:03 +0000 Subject: [PATCH 3/6] Fix windows Signed-off-by: Chenjie Luo --- modelopt/torch/quantization/backends/__init__.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/modelopt/torch/quantization/backends/__init__.py b/modelopt/torch/quantization/backends/__init__.py index 92662c474..a2b690d80 100644 --- a/modelopt/torch/quantization/backends/__init__.py +++ b/modelopt/torch/quantization/backends/__init__.py @@ -15,6 +15,9 @@ """Quantization backends.""" -from .fp8_per_tensor_gemm import * +import platform + +if platform.system() != "Windows": + from .fp8_per_tensor_gemm import * from .gemm_registry import * from .nvfp4_gemm import * From b5e6fc552739dd97f15e26d507c69f8399a3699f Mon Sep 17 00:00:00 2001 From: Chenjie Luo Date: Thu, 25 Sep 2025 20:27:17 +0000 Subject: [PATCH 4/6] Fix Signed-off-by: Chenjie Luo --- docs/source/_templates/autosummary/module.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/_templates/autosummary/module.rst b/docs/source/_templates/autosummary/module.rst index 647a88898..f6f30fadd 100644 --- a/docs/source/_templates/autosummary/module.rst +++ b/docs/source/_templates/autosummary/module.rst @@ -11,7 +11,7 @@ :recursive: {% for item in modules %} {% set full_item = fullname + '.' + item.split('.')[-1] %} -{% if '.plugins.' not in full_item or full_item == 'modelopt.torch.opt.plugins.huggingface' or full_item == 'modelopt.torch.quantization.backends.fp8_per_tensor_gemm' %} +{% if '.plugins.' 
not in full_item or full_item == 'modelopt.torch.opt.plugins.huggingface' or 'modelopt.torch.quantization.backends' not in full_item %} {{ full_item }} {% endif %} {%- endfor %} From 4b7cdd61f2aed44a8e2f0a1fb199556040217da1 Mon Sep 17 00:00:00 2001 From: Chenjie Luo Date: Thu, 25 Sep 2025 20:38:57 +0000 Subject: [PATCH 5/6] Fix Signed-off-by: Chenjie Luo --- docs/source/_templates/autosummary/module.rst | 2 +- modelopt/torch/quantization/backends/__init__.py | 5 ----- 2 files changed, 1 insertion(+), 6 deletions(-) diff --git a/docs/source/_templates/autosummary/module.rst b/docs/source/_templates/autosummary/module.rst index f6f30fadd..44d00656d 100644 --- a/docs/source/_templates/autosummary/module.rst +++ b/docs/source/_templates/autosummary/module.rst @@ -11,7 +11,7 @@ :recursive: {% for item in modules %} {% set full_item = fullname + '.' + item.split('.')[-1] %} -{% if '.plugins.' not in full_item or full_item == 'modelopt.torch.opt.plugins.huggingface' or 'modelopt.torch.quantization.backends' not in full_item %} +{% if '.plugins.' not in full_item or full_item == 'modelopt.torch.opt.plugins.huggingface' %} {{ full_item }} {% endif %} {%- endfor %} diff --git a/modelopt/torch/quantization/backends/__init__.py b/modelopt/torch/quantization/backends/__init__.py index a2b690d80..2a92690c6 100644 --- a/modelopt/torch/quantization/backends/__init__.py +++ b/modelopt/torch/quantization/backends/__init__.py @@ -15,9 +15,4 @@ """Quantization backends.""" -import platform - -if platform.system() != "Windows": - from .fp8_per_tensor_gemm import * from .gemm_registry import * -from .nvfp4_gemm import * From b8433efee912a3a66a0c5ca3055136ffd465bb2c Mon Sep 17 00:00:00 2001 From: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com> Date: Fri, 26 Sep 2025 02:41:01 +0530 Subject: [PATCH 6/6] Fix fp8_per_tensor_gemm doc build error Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com> --- docs/source/_templates/autosummary/module.rst | 2 +- modelopt/torch/quantization/backends/__init__.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/docs/source/_templates/autosummary/module.rst b/docs/source/_templates/autosummary/module.rst index 44d00656d..36eb087fb 100644 --- a/docs/source/_templates/autosummary/module.rst +++ b/docs/source/_templates/autosummary/module.rst @@ -11,7 +11,7 @@ :recursive: {% for item in modules %} {% set full_item = fullname + '.' + item.split('.')[-1] %} -{% if '.plugins.' not in full_item or full_item == 'modelopt.torch.opt.plugins.huggingface' %} +{% if ('.plugins.' not in full_item or full_item == 'modelopt.torch.opt.plugins.huggingface') and full_item != 'modelopt.torch.quantization.backends.fp8_per_tensor_gemm' %} {{ full_item }} {% endif %} {%- endfor %} diff --git a/modelopt/torch/quantization/backends/__init__.py b/modelopt/torch/quantization/backends/__init__.py index 2a92690c6..92317d92b 100644 --- a/modelopt/torch/quantization/backends/__init__.py +++ b/modelopt/torch/quantization/backends/__init__.py @@ -16,3 +16,4 @@ """Quantization backends.""" from .gemm_registry import * +from .nvfp4_gemm import *
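
A note on the grad_weight change in PATCH 1/6: both backward passes now flatten every leading (batch/sequence) dimension of `grad_outputs` and `input_tensor` before a single matmul, instead of the earlier `grad_outputs.transpose(...) @ input_tensor`, so the weight gradient is accumulated over all tokens and keeps the 2-D `(out_features, in_features)` shape. The following is a minimal standalone sketch of that formulation checked against autograd; it uses plain PyTorch with illustrative shapes and no modelopt imports, so it is an approximation of the patched code path rather than the code path itself.

```python
import torch

# Illustrative shapes: (batch, seq, in_features) -> (batch, seq, out_features).
batch, seq, in_features, out_features = 2, 5, 8, 4
x = torch.randn(batch, seq, in_features, requires_grad=True)
w = torch.randn(out_features, in_features, requires_grad=True)

# Reference: let autograd compute the weight gradient for y = x @ w.T.
y = x @ w.t()
grad_out = torch.randn_like(y)
y.backward(grad_out)

# Flattened formulation from the patch: collapse all leading dims so one
# (out_features, N) @ (N, in_features) matmul sums the contribution of
# every token, yielding a 2-D gradient.
grad_w_manual = grad_out.reshape(-1, grad_out.shape[-1]).T @ x.detach().reshape(
    -1, x.shape[-1]
)

torch.testing.assert_close(grad_w_manual, w.grad)
```

With a 3-D input, the previous `transpose(-2, -1)`-style product yields a per-batch `(batch, out_features, in_features)` tensor rather than the accumulated 2-D gradient, which is what this part of the patch corrects.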
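Similarly, the per-tensor FP8 path in PATCH 1/6 derives one scale per tensor as `amax / 448.0` (448 being the largest finite value of `torch.float8_e4m3fn`), quantizes with clamp-and-cast, and then calls `torch._scaled_mm`. Below is a small sketch of just the quantize/dequantize round trip, assuming a PyTorch build with float8 dtypes; it mirrors the patch's `_to_fp8` helper but is a standalone illustration (the actual `torch._scaled_mm` GEMM additionally requires FP8-capable hardware, which this sketch deliberately avoids).

```python
import torch

FP8_MIN = torch.finfo(torch.float8_e4m3fn).min  # -448.0
FP8_MAX = torch.finfo(torch.float8_e4m3fn).max  # +448.0


def to_fp8(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # Same scheme as the patch's _to_fp8: divide by the per-tensor scale
    # (amax / 448) so |x| <= amax lands inside the e4m3 range, clamp, cast.
    return (x / scale).clamp(FP8_MIN, FP8_MAX).to(torch.float8_e4m3fn)


x = torch.randn(4, 8)
scale = x.abs().amax().float() / FP8_MAX
x_fp8 = to_fp8(x, scale)

# Dequantize to inspect the rounding error introduced by the e4m3 format.
x_roundtrip = x_fp8.to(torch.float32) * scale
print("max abs error:", (x - x_roundtrip).abs().max().item())
```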