
Commit cbdd3e0

Improve realquant gemm impl
1 parent 26c203a commit cbdd3e0

4 files changed: +56 -46 lines changed

modelopt/torch/quantization/backends/fp8_per_tensor_gemm.py

Lines changed: 34 additions & 27 deletions
@@ -32,35 +32,38 @@
 FP8_MAX = torch.finfo(torch.float8_e4m3fn).max


-def fp8_per_tensor_gemm(quant_module, input, bias=None):
-    """GEMM function for fp8 per tensor quantization."""
+@torch.compile(dynamic=True)
+def _to_fp8(x, scale):
+    return (x / scale).clamp(FP8_MIN, FP8_MAX).to(torch.float8_e4m3fn)
+
+
+@torch.compile(dynamic=True)
+def _fp8_gemm_impl(input, weight_fp8, scale_a, scale_b, bias=None):
+    input_shape = input.shape
+    input_fp8 = _to_fp8(input, scale_a).reshape(-1, input_shape[-1])
+    weight_fp8_t = weight_fp8.reshape(-1, weight_fp8.shape[-1]).t()
+    output = torch._scaled_mm(
+        input_fp8,
+        weight_fp8_t,
+        scale_a=scale_a,
+        scale_b=scale_b,
+        bias=bias,
+        out_dtype=input.dtype,
+        use_fast_accum=True,
+    )
+    return output.reshape(*input_shape[:-1], output.shape[-1])

-    @torch.compile(dynamic=True)
-    def _to_fp8(x, scale):
-        return (x / scale).clamp(FP8_MIN, FP8_MAX).to(torch.float8_e4m3fn)
-
-    @torch.compile(dynamic=True)
-    def _fp8_gemm_impl(input, weight_fp8, scale_a, scale_b, bias=None):
-        input_shape = input.shape
-        input_fp8 = _to_fp8(input, scale_a).reshape(-1, input_shape[-1])
-        weight_fp8_t = weight_fp8.reshape(-1, weight_fp8.shape[-1]).t()
-        output = torch._scaled_mm(
-            input_fp8,
-            weight_fp8_t,
-            scale_a=scale_a,
-            scale_b=scale_b,
-            bias=bias,
-            out_dtype=input.dtype,
-            use_fast_accum=True,
-        )
-        return output.reshape(*input_shape[:-1], output.shape[-1])

+def fp8_per_tensor_gemm(quant_module, input, bias=None):
+    """GEMM function for fp8 per tensor quantization."""
     cached_scale_a = (
         hasattr(quant_module, "_scale_a") and quant_module.input_quantizer.amax is not None
     )

     if not cached_scale_a:
-        input_amax = quant_module.input_quantizer.amax or reduce_amax(input)
+        input_amax = quant_module.input_quantizer.amax
+        if input_amax is None:
+            input_amax = reduce_amax(input)
         assert input_amax != 0
         quant_module._scale_a = (input_amax.float() / 448.0).to(device=input.device)

@@ -69,7 +72,9 @@ def _fp8_gemm_impl(input, weight_fp8, scale_a, scale_b, bias=None):
     )

     if not cached_scale_b:
-        weight_amax = quant_module.weight_quantizer.amax or reduce_amax(quant_module.weight)
+        weight_amax = quant_module.weight_quantizer.amax
+        if weight_amax is None:
+            weight_amax = reduce_amax(quant_module.weight)
         assert weight_amax != 0
         quant_module._scale_b = (weight_amax.float() / 448.0).to(device=quant_module.weight.device)

@@ -146,9 +151,9 @@ def forward(
         ctx.save_for_backward(
             input_tensor if weight.requires_grad else None,
             weight if input_tensor.requires_grad else None,
-            torch.empty(0, dtype=torch.uint8) if bias is not None and bias.requires_grad else None,
             getattr(quant_module.weight_quantizer, "_scale", None),
         )
+        ctx.compute_bias_grad = bias is not None and bias.requires_grad
         ctx.block_sizes = getattr(quant_module.weight_quantizer, "_block_sizes", None)

         ctx.allreduce_dgrad = allreduce_dgrad
@@ -166,7 +171,7 @@ def backward(ctx, grad_outputs):
         dequantize it to compute the input gradient. If the weight is not compressed, we will save
         the unquantized weight and use it directly to compute the input gradient.
         """
-        input_tensor, weight, compute_bias_grad, scale = ctx.saved_tensors
+        input_tensor, weight, scale = ctx.saved_tensors
         grad_input = grad_weight = grad_bias = None
         if weight is not None:
             if isinstance(weight, QTensorWrapper):
@@ -175,8 +180,10 @@ def backward(ctx, grad_outputs):
                 weight = weight.dequantize(scale=scale, block_sizes=ctx.block_sizes)
             grad_input = grad_outputs @ weight
         if input_tensor is not None:
-            grad_weight = grad_outputs.transpose(-2, 1) @ input_tensor
-        if compute_bias_grad is not None:
+            grad_weight = grad_outputs.reshape(-1, grad_outputs.shape[-1]).T @ input_tensor.reshape(
+                -1, input_tensor.shape[-1]
+            )
+        if ctx.compute_bias_grad:
             # Sum all dimensions except the last one
             grad_bias = grad_outputs.sum(dim=list(range(grad_outputs.dim() - 1)))

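Note on the backward() change above: the old grad_weight used grad_outputs.transpose(-2, 1) @ input_tensor, which only handles the plain 2-D case, while the new code flattens both tensors to 2-D views before the matmul. A minimal standalone check (not part of the commit, shapes picked arbitrarily) showing that the flattened matmul equals the sum of per-token outer products:

import torch

# grad_outputs: (B, S, out_features); input_tensor: (B, S, in_features)
B, S, in_features, out_features = 2, 3, 4, 5
input_tensor = torch.randn(B, S, in_features)
grad_outputs = torch.randn(B, S, out_features)

# Flattened form used in the new backward(): a single (out, N) @ (N, in) matmul
grad_weight = grad_outputs.reshape(-1, out_features).T @ input_tensor.reshape(-1, in_features)

# Reference: accumulate the per-(batch, token) outer products explicitly
grad_weight_ref = torch.einsum("bso,bsi->oi", grad_outputs, input_tensor)

assert torch.allclose(grad_weight, grad_weight_ref, atol=1e-6)
print(grad_weight.shape)  # torch.Size([5, 4]) == (out_features, in_features)
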
modelopt/torch/quantization/backends/nvfp4_gemm.py

Lines changed: 6 additions & 4 deletions
@@ -139,11 +139,11 @@ def forward(
         ctx.save_for_backward(
             input_tensor if weight.requires_grad else None,
             weight if input_tensor.requires_grad else None,
-            torch.empty(0, dtype=torch.uint8) if bias is not None and bias.requires_grad else None,
             getattr(quant_module.weight_quantizer, "_scale", None),
             getattr(quant_module.weight_quantizer, "_double_scale", None),
         )

+        ctx.compute_bias_grad = bias is not None and bias.requires_grad
         ctx.allreduce_dgrad = allreduce_dgrad
         ctx.tp_group = tp_group
         ret = nvfp4_gemm(quant_module, input_tensor, bias)
@@ -158,7 +158,7 @@ def backward(ctx, grad_outputs):
         dequantize it to compute the input gradient. If the weight is not compressed, we will save
         the unquantized weight and use it directly to compute the input gradient.
         """
-        input_tensor, weight, compute_bias_grad, scale, double_scale = ctx.saved_tensors
+        input_tensor, weight, scale, double_scale = ctx.saved_tensors
         grad_input = grad_weight = grad_bias = None
         if weight is not None:
             if isinstance(weight, QTensorWrapper):
@@ -173,8 +173,10 @@ def backward(ctx, grad_outputs):
                 )
             grad_input = grad_outputs @ weight
         if input_tensor is not None:
-            grad_weight = grad_outputs.transpose(-2, -1) @ input_tensor
-        if compute_bias_grad is not None:
+            grad_weight = grad_outputs.reshape(-1, grad_outputs.shape[-1]).T @ input_tensor.reshape(
+                -1, input_tensor.shape[-1]
+            )
+        if ctx.compute_bias_grad is not None:
             # Sum all dimensions except the last one
             grad_bias = grad_outputs.sum(dim=list(range(grad_outputs.dim() - 1)))

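Same pattern in this file: the bias-gradient decision now travels as a plain attribute on ctx instead of a placeholder uint8 tensor pushed through ctx.save_for_backward(), and grad_weight is likewise computed on flattened 2-D views. A rough sketch of that ctx-flag pattern on a toy autograd.Function (illustrative only, not the classes in this repo):

import torch

class _BiasedMatmul(torch.autograd.Function):
    """Toy example: y = x @ w.T + bias, using the ctx-flag pattern from this commit."""

    @staticmethod
    def forward(ctx, x, w, bias=None):
        ctx.save_for_backward(x, w)  # only real tensors go through save_for_backward
        ctx.compute_bias_grad = bias is not None and bias.requires_grad  # plain bool on ctx
        out = x @ w.T
        return out + bias if bias is not None else out

    @staticmethod
    def backward(ctx, grad_out):
        x, w = ctx.saved_tensors
        grad_x = grad_out @ w
        grad_w = grad_out.reshape(-1, grad_out.shape[-1]).T @ x.reshape(-1, x.shape[-1])
        grad_b = None
        if ctx.compute_bias_grad:
            # Sum all dimensions except the last one
            grad_b = grad_out.sum(dim=list(range(grad_out.dim() - 1)))
        return grad_x, grad_w, grad_b

x = torch.randn(2, 3, 4, requires_grad=True)
w = torch.randn(5, 4, requires_grad=True)
b = torch.randn(5, requires_grad=True)
_BiasedMatmul.apply(x, w, b).sum().backward()
print(b.grad.shape)  # torch.Size([5])
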
modelopt/torch/quantization/nn/modules/quant_linear.py

Lines changed: 11 additions & 12 deletions
@@ -148,7 +148,7 @@ def _should_run_real_quant_gemm(self):
             and self.allow_real_quant_gemm
         )

-    def get_real_quant_gemm_impl(self, input, *args, **kwargs) -> bool:
+    def has_real_quant_gemm_impl(self, input, *args, **kwargs) -> bool:
         """Get the real quant GEMM implementation base on input arguments."""
         if not hasattr(self, "_real_quant_gemm_impl"):
             self._real_quant_gemm_impl = backends.gemm_registry.find_match(
@@ -166,20 +166,19 @@ def forward(self, input, *args, **kwargs):
             return super().forward(input, *args, **kwargs)

         # Check if real-quant GEMM is available
-        if self._should_run_real_quant_gemm and input.numel() > 1:
-            # If the input is not quantized, we use the default GEMM.
-            self.get_real_quant_gemm_impl(input, *args, **kwargs)
-
+        if (
+            self._should_run_real_quant_gemm
+            and input.numel() > 1
+            and self.has_real_quant_gemm_impl(input, *args, **kwargs)
+        ):
             # Note: We cache the real-quant GEMM function to avoid matching overhead.
             # This assumes that the function will not change after the first call.
-            if self._real_quant_gemm_impl:
-                with torch.cuda.nvtx.range("RealQuantLinear gemm"):
-                    output = self._real_quant_gemm_impl(
-                        self, input, self.weight, self.bias, *args, **kwargs
-                    )
-                return (
-                    self.output_quantizer(output) if hasattr(self, "output_quantizer") else output
+            assert self._real_quant_gemm_impl is not None
+            with torch.cuda.nvtx.range("RealQuantLinear gemm"):
+                output = self._real_quant_gemm_impl(
+                    self, input, self.weight, self.bias, *args, **kwargs
                 )
+            return self.output_quantizer(output) if hasattr(self, "output_quantizer") else output

         # Otherwise, fallback to the default GEMM
         return super().forward(input, *args, **kwargs)

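With the rename to has_real_quant_gemm_impl, a single boolean predicate both caches the registry lookup and gates the fast path, so forward() no longer needs a separate None-check on the cached attribute. A stripped-down sketch of that control flow (hypothetical names and registry, not the module's real API):

from typing import Callable, Optional

# Hypothetical registry mapping a key to a GEMM-like callable.
_GEMM_REGISTRY: dict[str, Callable[[float], float]] = {"fp8_per_tensor": lambda x: 2.0 * x}

class _Linear:
    def __init__(self, kind: str):
        self.kind = kind

    def has_real_quant_gemm_impl(self) -> bool:
        # Look up the implementation once and cache it on the module;
        # later calls only test the cached result.
        if not hasattr(self, "_real_quant_gemm_impl"):
            self._real_quant_gemm_impl: Optional[Callable[[float], float]] = _GEMM_REGISTRY.get(self.kind)
        return self._real_quant_gemm_impl is not None

    def forward(self, x: float) -> float:
        # One predicate gates the fast path; no second None-check needed inside.
        if self.has_real_quant_gemm_impl():
            return self._real_quant_gemm_impl(x)
        return x  # fallback: default GEMM

print(_Linear("fp8_per_tensor").forward(3.0))  # 6.0, fast path
print(_Linear("unknown").forward(3.0))         # 3.0, fallback
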
modelopt/torch/quantization/plugins/megatron.py

Lines changed: 5 additions & 3 deletions
@@ -115,7 +115,7 @@ def real_quant_module_set_extra_state(self, state: Any):
    """
    q_tensor_state = state.get("modelopt_q_tensor_state", None)

-    if q_tensor_state is not None:
+    if q_tensor_state:
        q_tensor_metadata = q_tensor_state["metadata"]
        q_tensor_metadata["shape"] = self.weight.shape
        q_tensor_data_dtype = q_tensor_state["quantized_data.dtype"]
@@ -418,8 +418,10 @@ class forward(). This is not desired since _forward_impl introduces much more ar
         while the original forward only takes 1 positional argument. We must above the fallback path
         in RealQuantLinear.forward().
         """
-        if self._should_run_real_quant_gemm and self.get_real_quant_gemm_impl(
-            input, *args, **kwargs
+        if (
+            self._should_run_real_quant_gemm
+            and input.numel() > 1
+            and self.has_real_quant_gemm_impl(input, *args, **kwargs)
         ):
             allreduce_dgrad = kwargs.get("allreduce_dgrad", False)
             tp_group = kwargs.get("tp_group")

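On the set_extra_state() change: the guard now tests truthiness, so an empty q_tensor_state dict is skipped the same way a missing one is, instead of falling through to the metadata lookup. A tiny illustration of the difference (not repo code):

# `if q_tensor_state:` skips both None and {}.
# The old `is not None` check would let {} through and raise KeyError on "metadata".
for q_tensor_state in (None, {}, {"metadata": {"num_bits": 4}}):
    if q_tensor_state:
        print("restore from", q_tensor_state["metadata"])
    else:
        print("nothing to restore:", q_tensor_state)
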