Merged
Changes from 1 commit
61 changes: 34 additions & 27 deletions modelopt/torch/quantization/backends/fp8_per_tensor_gemm.py
@@ -32,35 +32,38 @@
FP8_MAX = torch.finfo(torch.float8_e4m3fn).max


def fp8_per_tensor_gemm(quant_module, input, bias=None):
    """GEMM function for fp8 per tensor quantization."""
    @torch.compile(dynamic=True)
    def _to_fp8(x, scale):
        return (x / scale).clamp(FP8_MIN, FP8_MAX).to(torch.float8_e4m3fn)


    @torch.compile(dynamic=True)
    def _fp8_gemm_impl(input, weight_fp8, scale_a, scale_b, bias=None):
        input_shape = input.shape
        input_fp8 = _to_fp8(input, scale_a).reshape(-1, input_shape[-1])
        weight_fp8_t = weight_fp8.reshape(-1, weight_fp8.shape[-1]).t()
        output = torch._scaled_mm(
            input_fp8,
            weight_fp8_t,
            scale_a=scale_a,
            scale_b=scale_b,
            bias=bias,
            out_dtype=input.dtype,
            use_fast_accum=True,
        )
        return output.reshape(*input_shape[:-1], output.shape[-1])
Contributor:
Quick question: could you share why moving these functions to the module level reduces the CPU overhead?

Collaborator (Author):
Sure. My suspicion is that if they are not at the module level, torch.compile is invoked on every call, which shows up in the nsys trace.
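To make the question concrete, here is a minimal, self-contained sketch (not part of the PR; `_to_fp8_module_level` and `convert_with_nested_helper` are illustrative names) of the two placements being compared. In the nested variant, a fresh helper function is created and passed through `torch.compile(...)` on every call of the outer function, which is the per-call CPU cost the author suspects shows up in the nsys trace; in the module-level variant, the helper is wrapped exactly once at import time.

```python
import torch

FP8_MIN = torch.finfo(torch.float8_e4m3fn).min
FP8_MAX = torch.finfo(torch.float8_e4m3fn).max


# Module-level: torch.compile wraps this function once, at import time.
@torch.compile(dynamic=True)
def _to_fp8_module_level(x, scale):
    return (x / scale).clamp(FP8_MIN, FP8_MAX).to(torch.float8_e4m3fn)


def convert_with_nested_helper(x, scale):
    # Nested: a new function object is created and handed to
    # torch.compile on every call of convert_with_nested_helper.
    @torch.compile(dynamic=True)
    def _to_fp8_nested(inner_x, inner_scale):
        return (inner_x / inner_scale).clamp(FP8_MIN, FP8_MAX).to(torch.float8_e4m3fn)

    return _to_fp8_nested(x, scale)
```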


@torch.compile(dynamic=True)
def _to_fp8(x, scale):
    return (x / scale).clamp(FP8_MIN, FP8_MAX).to(torch.float8_e4m3fn)

@torch.compile(dynamic=True)
def _fp8_gemm_impl(input, weight_fp8, scale_a, scale_b, bias=None):
    input_shape = input.shape
    input_fp8 = _to_fp8(input, scale_a).reshape(-1, input_shape[-1])
    weight_fp8_t = weight_fp8.reshape(-1, weight_fp8.shape[-1]).t()
    output = torch._scaled_mm(
        input_fp8,
        weight_fp8_t,
        scale_a=scale_a,
        scale_b=scale_b,
        bias=bias,
        out_dtype=input.dtype,
        use_fast_accum=True,
    )
    return output.reshape(*input_shape[:-1], output.shape[-1])

def fp8_per_tensor_gemm(quant_module, input, bias=None):
    """GEMM function for fp8 per tensor quantization."""
    cached_scale_a = (
        hasattr(quant_module, "_scale_a") and quant_module.input_quantizer.amax is not None
    )

    if not cached_scale_a:
        input_amax = quant_module.input_quantizer.amax or reduce_amax(input)
        input_amax = quant_module.input_quantizer.amax
        if input_amax is None:
            input_amax = reduce_amax(input)
        assert input_amax != 0
        quant_module._scale_a = (input_amax.float() / 448.0).to(device=input.device)

@@ -69,7 +72,9 @@ def _fp8_gemm_impl(input, weight_fp8, scale_a, scale_b, bias=None):
    )

    if not cached_scale_b:
        weight_amax = quant_module.weight_quantizer.amax or reduce_amax(quant_module.weight)
        weight_amax = quant_module.weight_quantizer.amax
        if weight_amax is None:
            weight_amax = reduce_amax(quant_module.weight)
        assert weight_amax != 0
        quant_module._scale_b = (weight_amax.float() / 448.0).to(device=quant_module.weight.device)
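As an aside (my own reading, not a rationale stated in the PR), one behavioral difference between `amax or reduce_amax(...)` and the explicit `is None` check is how a recorded amax of zero is treated: a zero tensor is falsy, so the `or` form silently recomputes the amax from the current tensor, while the explicit form keeps the recorded value and lets the `assert` catch it. A hedged sketch of that difference:

```python
import torch

def amax_with_or(recorded, fallback):
    # `or` falls back whenever `recorded` is falsy, including a recorded amax of 0.
    return recorded or fallback

def amax_explicit(recorded, fallback):
    # Falls back only when nothing has been recorded at all.
    return fallback if recorded is None else recorded

recorded_zero = torch.tensor(0.0)
fallback = torch.tensor(3.5)
print(amax_with_or(recorded_zero, fallback))   # tensor(3.5000) -- the zero is discarded
print(amax_explicit(recorded_zero, fallback))  # tensor(0.)     -- the zero is kept
print(amax_explicit(None, fallback))           # tensor(3.5000)
```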

@@ -146,9 +151,9 @@ def forward(
        ctx.save_for_backward(
            input_tensor if weight.requires_grad else None,
            weight if input_tensor.requires_grad else None,
            torch.empty(0, dtype=torch.uint8) if bias is not None and bias.requires_grad else None,
            getattr(quant_module.weight_quantizer, "_scale", None),
        )
        ctx.compute_bias_grad = bias is not None and bias.requires_grad
        ctx.block_sizes = getattr(quant_module.weight_quantizer, "_block_sizes", None)

        ctx.allreduce_dgrad = allreduce_dgrad
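Replacing the `torch.empty(0, dtype=torch.uint8)` sentinel with `ctx.compute_bias_grad` follows a standard `torch.autograd.Function` pattern: `save_for_backward` is meant for tensors that autograd should track, while plain Python values can be stored directly as attributes on `ctx`. A minimal sketch of that pattern (my illustration, not the PR's code):

```python
import torch


class SquarePlusBias(torch.autograd.Function):
    """Toy custom op: y = x * x + bias."""

    @staticmethod
    def forward(ctx, x, bias=None):
        # Tensors needed in backward go through save_for_backward ...
        ctx.save_for_backward(x)
        # ... while non-tensor bookkeeping is stored as a plain attribute.
        ctx.compute_bias_grad = bias is not None and bias.requires_grad
        return x * x + (bias if bias is not None else 0)

    @staticmethod
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors
        grad_x = 2 * x * grad_out
        grad_bias = grad_out.sum(dim=0) if ctx.compute_bias_grad else None
        return grad_x, grad_bias


x = torch.randn(4, 3, requires_grad=True)
bias = torch.randn(3, requires_grad=True)
SquarePlusBias.apply(x, bias).sum().backward()
```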
@@ -166,7 +171,7 @@ def backward(ctx, grad_outputs):
        dequantize it to compute the input gradient. If the weight is not compressed, we will save
        the unquantized weight and use it directly to compute the input gradient.
        """
        input_tensor, weight, compute_bias_grad, scale = ctx.saved_tensors
        input_tensor, weight, scale = ctx.saved_tensors
        grad_input = grad_weight = grad_bias = None
        if weight is not None:
            if isinstance(weight, QTensorWrapper):
@@ -175,8 +180,10 @@
                weight = weight.dequantize(scale=scale, block_sizes=ctx.block_sizes)
            grad_input = grad_outputs @ weight
        if input_tensor is not None:
            grad_weight = grad_outputs.transpose(-2, 1) @ input_tensor
        if compute_bias_grad is not None:
            grad_weight = grad_outputs.reshape(-1, grad_outputs.shape[-1]).T @ input_tensor.reshape(
                -1, input_tensor.shape[-1]
            )
        if ctx.compute_bias_grad:
            # Sum all dimensions except the last one
            grad_bias = grad_outputs.sum(dim=list(range(grad_outputs.dim() - 1)))
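The `grad_weight` change generalizes the weight gradient to activations with more than two dimensions: flattening all leading dimensions before the matmul yields a single `(out_features, in_features)` gradient, which equals the sum of the per-batch `transpose(-2, -1)` products. A small numerical check of that equivalence (my own sketch, not from the PR):

```python
import torch

batch, seq, in_features, out_features = 2, 5, 4, 3
inputs = torch.randn(batch, seq, in_features)
grad_outputs = torch.randn(batch, seq, out_features)

# Batched form: one (out_features, in_features) gradient per leading index.
batched = grad_outputs.transpose(-2, -1) @ inputs          # (batch, out_features, in_features)

# Flattened form used in the new code: a single (out_features, in_features) gradient.
flattened = grad_outputs.reshape(-1, out_features).T @ inputs.reshape(-1, in_features)

assert torch.allclose(flattened, batched.sum(dim=0), atol=1e-5)
```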

10 changes: 6 additions & 4 deletions modelopt/torch/quantization/backends/nvfp4_gemm.py
@@ -139,11 +139,11 @@ def forward(
        ctx.save_for_backward(
            input_tensor if weight.requires_grad else None,
            weight if input_tensor.requires_grad else None,
            torch.empty(0, dtype=torch.uint8) if bias is not None and bias.requires_grad else None,
            getattr(quant_module.weight_quantizer, "_scale", None),
            getattr(quant_module.weight_quantizer, "_double_scale", None),
        )

        ctx.compute_bias_grad = bias is not None and bias.requires_grad
        ctx.allreduce_dgrad = allreduce_dgrad
        ctx.tp_group = tp_group
        ret = nvfp4_gemm(quant_module, input_tensor, bias)
@@ -158,7 +158,7 @@ def backward(ctx, grad_outputs):
        dequantize it to compute the input gradient. If the weight is not compressed, we will save
        the unquantized weight and use it directly to compute the input gradient.
        """
        input_tensor, weight, compute_bias_grad, scale, double_scale = ctx.saved_tensors
        input_tensor, weight, scale, double_scale = ctx.saved_tensors
        grad_input = grad_weight = grad_bias = None
        if weight is not None:
            if isinstance(weight, QTensorWrapper):
@@ -173,8 +173,8 @@
                )
            grad_input = grad_outputs @ weight
        if input_tensor is not None:
            grad_weight = grad_outputs.transpose(-2, -1) @ input_tensor
        if compute_bias_grad is not None:
            grad_weight = grad_outputs.reshape(-1, grad_outputs.shape[-1]).T @ input_tensor.reshape(
                -1, input_tensor.shape[-1]
            )
        if ctx.compute_bias_grad is not None:
            # Sum all dimensions except the last one
            grad_bias = grad_outputs.sum(dim=list(range(grad_outputs.dim() - 1)))

23 changes: 11 additions & 12 deletions modelopt/torch/quantization/nn/modules/quant_linear.py
@@ -148,7 +148,7 @@ def _should_run_real_quant_gemm(self):
            and self.allow_real_quant_gemm
        )

    def get_real_quant_gemm_impl(self, input, *args, **kwargs) -> bool:
    def has_real_quant_gemm_impl(self, input, *args, **kwargs) -> bool:
        """Get the real quant GEMM implementation base on input arguments."""
        if not hasattr(self, "_real_quant_gemm_impl"):
            self._real_quant_gemm_impl = backends.gemm_registry.find_match(
@@ -166,20 +166,19 @@ def forward(self, input, *args, **kwargs):
            return super().forward(input, *args, **kwargs)

        # Check if real-quant GEMM is available
        if self._should_run_real_quant_gemm and input.numel() > 1:
            # If the input is not quantized, we use the default GEMM.
            self.get_real_quant_gemm_impl(input, *args, **kwargs)

        if (
            self._should_run_real_quant_gemm
            and input.numel() > 1
            and self.has_real_quant_gemm_impl(input, *args, **kwargs)
        ):
            # Note: We cache the real-quant GEMM function to avoid matching overhead.
            # This assumes that the function will not change after the first call.
            if self._real_quant_gemm_impl:
                with torch.cuda.nvtx.range("RealQuantLinear gemm"):
                    output = self._real_quant_gemm_impl(
                        self, input, self.weight, self.bias, *args, **kwargs
                    )
                    return (
                        self.output_quantizer(output) if hasattr(self, "output_quantizer") else output
            assert self._real_quant_gemm_impl is not None
            with torch.cuda.nvtx.range("RealQuantLinear gemm"):
                output = self._real_quant_gemm_impl(
                    self, input, self.weight, self.bias, *args, **kwargs
                )
            return self.output_quantizer(output) if hasattr(self, "output_quantizer") else output

        # Otherwise, fallback to the default GEMM
        return super().forward(input, *args, **kwargs)
8 changes: 5 additions & 3 deletions modelopt/torch/quantization/plugins/megatron.py
@@ -115,7 +115,7 @@ def real_quant_module_set_extra_state(self, state: Any):
"""
q_tensor_state = state.get("modelopt_q_tensor_state", None)

if q_tensor_state is not None:
if q_tensor_state:
q_tensor_metadata = q_tensor_state["metadata"]
q_tensor_metadata["shape"] = self.weight.shape
q_tensor_data_dtype = q_tensor_state["quantized_data.dtype"]
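For what it's worth (my reading of the change, not a rationale stated in the PR), the truthiness guard also skips an empty dict, which the `is not None` guard would go on to index:

```python
for q_tensor_state in (None, {}, {"metadata": {"shape": (4, 4)}}):
    old_guard = q_tensor_state is not None   # True for {} as well
    new_guard = bool(q_tensor_state)         # False for both None and {}
    print(q_tensor_state, old_guard, new_guard)
```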
@@ -418,8 +418,10 @@ class forward(). This is not desired since _forward_impl introduces much more ar
        while the original forward only takes 1 positional argument. We must above the fallback path
        in RealQuantLinear.forward().
        """
        if self._should_run_real_quant_gemm and self.get_real_quant_gemm_impl(
            input, *args, **kwargs
        if (
            self._should_run_real_quant_gemm
            and input.numel() > 1
            and self.has_real_quant_gemm_impl(input, *args, **kwargs)
        ):
            allreduce_dgrad = kwargs.get("allreduce_dgrad", False)
            tp_group = kwargs.get("tp_group")