78 changes: 28 additions & 50 deletions modelopt/torch/quantization/backends/fp8_per_tensor_gemm.py
@@ -24,6 +24,7 @@
from modelopt.torch.quantization.config import FP8_DEFAULT_CFG
from modelopt.torch.quantization.nn.modules.quant_linear import RealQuantLinear
from modelopt.torch.quantization.qtensor import FP8QTensor, QTensorWrapper
from modelopt.torch.quantization.triton.fp8_kernel import fp8_gemm
from modelopt.torch.quantization.utils import reduce_amax

from .utils import fp8_compatible
@@ -32,62 +33,37 @@
FP8_MAX = torch.finfo(torch.float8_e4m3fn).max


def fp8_per_tensor_gemm(quant_module, input, bias=None):
"""GEMM function for fp8 per tensor quantization."""
@torch.compile(dynamic=True)
def _to_fp8(x, amax):
return (x.to(torch.float32) / amax * 448.0).clamp(FP8_MIN, FP8_MAX).to(torch.float8_e4m3fn)

@torch.compile(dynamic=True)
def _to_fp8(x, scale):
return (x / scale).clamp(FP8_MIN, FP8_MAX).to(torch.float8_e4m3fn)

@torch.compile(dynamic=True)
def _fp8_gemm_impl(input, weight_fp8, scale_a, scale_b, bias=None):
input_shape = input.shape
input_fp8 = _to_fp8(input, scale_a).reshape(-1, input_shape[-1])
weight_fp8_t = weight_fp8.reshape(-1, weight_fp8.shape[-1]).t()
output = torch._scaled_mm(
input_fp8,
weight_fp8_t,
scale_a=scale_a,
scale_b=scale_b,
bias=bias,
out_dtype=input.dtype,
use_fast_accum=True,
)
return output.reshape(*input_shape[:-1], output.shape[-1])

cached_scale_a = (
hasattr(quant_module, "_scale_a") and quant_module.input_quantizer.amax is not None
def fp8_per_tensor_gemm(quant_module, input, bias=None):
"""GEMM function for fp8 per tensor quantization."""
input_amax = (
quant_module.input_quantizer.amax
if quant_module.input_quantizer.amax is not None
else reduce_amax(input)
)

if not cached_scale_a:
input_amax = quant_module.input_quantizer.amax or reduce_amax(input)
assert input_amax != 0
quant_module._scale_a = (input_amax.float() / 448.0).to(device=input.device)

cached_scale_b = (
hasattr(quant_module, "_scale_b") and quant_module.weight_quantizer.amax is not None
weight_amax = (
quant_module.weight_quantizer.amax
if quant_module.weight_quantizer.amax is not None
else reduce_amax(quant_module.weight)
)

if not cached_scale_b:
weight_amax = quant_module.weight_quantizer.amax or reduce_amax(quant_module.weight)
assert weight_amax != 0
quant_module._scale_b = (weight_amax.float() / 448.0).to(device=quant_module.weight.device)

if quant_module.weight.dtype != torch.float8_e4m3fn:
weight_fp8 = _to_fp8(quant_module.weight, quant_module._scale_b)
with torch.cuda.nvtx.range("compress weight"):
weight_fp8 = _to_fp8(quant_module.weight.data, weight_amax)
else:
weight_fp8 = quant_module.weight.data

output = _fp8_gemm_impl(
output = fp8_gemm(
input,
weight_fp8,
scale_a=quant_module._scale_a,
scale_b=quant_module._scale_b,
bias=bias if input.dtype != torch.float32 else None,
input_amax,
weight_amax,
bias=bias,
)
# _scaled_mm does not support bias for float32 input, so we add it manually
if input.dtype == torch.float32 and bias is not None:
output += bias
return output
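For context on the new code path above: both activations and weights use a single per-tensor scale derived from their amax, scale = amax / 448, where 448 is the representable maximum of torch.float8_e4m3fn. A minimal sketch of that convention (illustrative only, not the library code; the helper names are made up):

```python
import torch

FP8_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448.0

def quantize_fp8_per_tensor(x: torch.Tensor, amax: torch.Tensor) -> torch.Tensor:
    # One scale for the whole tensor; clamp before casting, as _to_fp8 does.
    scale = amax.float() / FP8_MAX
    return (x.float() / scale).clamp(-FP8_MAX, FP8_MAX).to(torch.float8_e4m3fn)

def dequantize_fp8_per_tensor(x_fp8: torch.Tensor, amax: torch.Tensor, dtype=torch.bfloat16) -> torch.Tensor:
    return (x_fp8.float() * (amax.float() / FP8_MAX)).to(dtype)
```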


@@ -146,16 +122,16 @@ def forward(
ctx.save_for_backward(
input_tensor if weight.requires_grad else None,
weight if input_tensor.requires_grad else None,
torch.empty(0, dtype=torch.uint8) if bias is not None and bias.requires_grad else None,
getattr(quant_module.weight_quantizer, "_scale", None),
)

ctx.compute_bias_grad = bias is not None and bias.requires_grad
ctx.block_sizes = getattr(quant_module.weight_quantizer, "_block_sizes", None)

ctx.allreduce_dgrad = allreduce_dgrad
ctx.tp_group = tp_group

ret = fp8_per_tensor_gemm(quant_module, input_tensor, bias)
return ret
return fp8_per_tensor_gemm(quant_module, input_tensor, bias)

@staticmethod
def backward(ctx, grad_outputs):
@@ -166,7 +142,7 @@ def backward(ctx, grad_outputs):
dequantize it to compute the input gradient. If the weight is not compressed, we will save
the unquantized weight and use it directly to compute the input gradient.
"""
input_tensor, weight, compute_bias_grad, scale = ctx.saved_tensors
input_tensor, weight, scale = ctx.saved_tensors
grad_input = grad_weight = grad_bias = None
if weight is not None:
if isinstance(weight, QTensorWrapper):
@@ -175,8 +151,10 @@
weight = weight.dequantize(scale=scale, block_sizes=ctx.block_sizes)
grad_input = grad_outputs @ weight
if input_tensor is not None:
grad_weight = grad_outputs.transpose(-2, 1) @ input_tensor
if compute_bias_grad is not None:
grad_weight = grad_outputs.reshape(-1, grad_outputs.shape[-1]).T @ input_tensor.reshape(
-1, input_tensor.shape[-1]
)
if ctx.compute_bias_grad:
# Sum all dimensions except the last one
grad_bias = grad_outputs.sum(dim=list(range(grad_outputs.dim() - 1)))
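The flattened weight-gradient formula used here (and mirrored in the nvfp4 backward below) generalizes the old transpose(-2, -1) version to activations with arbitrary leading dimensions. A quick correctness check of the shape algebra, with made-up sizes:

```python
import torch

B, S, K, N = 2, 4, 8, 16
x = torch.randn(B, S, K, requires_grad=True)   # activation with batch and sequence dims
w = torch.randn(N, K, requires_grad=True)      # linear weight
g = torch.randn(B, S, N)                       # upstream gradient for y = x @ w.T

(x @ w.t()).backward(g)
grad_w = g.reshape(-1, N).t() @ x.reshape(-1, K)   # same formula as in backward() above
assert torch.allclose(w.grad, grad_w, atol=1e-4)
```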

4 changes: 3 additions & 1 deletion modelopt/torch/quantization/backends/nvfp4_gemm.py
@@ -173,7 +173,9 @@ def backward(ctx, grad_outputs):
)
grad_input = grad_outputs @ weight
if input_tensor is not None:
grad_weight = grad_outputs.transpose(-2, -1) @ input_tensor
grad_weight = grad_outputs.reshape(-1, grad_outputs.shape[-1]).T @ input_tensor.reshape(
-1, input_tensor.shape[-1]
)
if compute_bias_grad is not None:
# Sum all dimensions except the last one
grad_bias = grad_outputs.sum(dim=list(range(grad_outputs.dim() - 1)))
4 changes: 1 addition & 3 deletions modelopt/torch/quantization/nn/modules/quant_linear.py
@@ -168,11 +168,9 @@ def forward(self, input, *args, **kwargs):
# Check if real-quant GEMM is available
if self._should_run_real_quant_gemm and input.numel() > 1:
# If the input is not quantized, we use the default GEMM.
self.get_real_quant_gemm_impl(input, *args, **kwargs)

# Note: We cache the real-quant GEMM function to avoid matching overhead.
# This assumes that the function will not change after the first call.
if self._real_quant_gemm_impl:
if self.get_real_quant_gemm_impl(input, *args, **kwargs):
with torch.cuda.nvtx.range("RealQuantLinear gemm"):
output = self._real_quant_gemm_impl(
self, input, self.weight, self.bias, *args, **kwargs
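For readers unfamiliar with the dispatch being branched on here: get_real_quant_gemm_impl is assumed to match a backend GEMM once, cache it on the module, and report via its return value whether one was found, which is why forward() can now test the return value directly. A hedged sketch of that assumed contract (find_matching_gemm is a hypothetical helper, not an actual ModelOpt API):

```python
def get_real_quant_gemm_impl(self, input, *args, **kwargs) -> bool:
    # Sketch only: match a real-quant GEMM (e.g. fp8_per_tensor_gemm) on the first
    # call and cache it; the match is assumed not to change on later calls.
    if getattr(self, "_real_quant_gemm_impl", None) is None:
        self._real_quant_gemm_impl = find_matching_gemm(self, input, *args, **kwargs)  # hypothetical
    return self._real_quant_gemm_impl is not None
```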
8 changes: 5 additions & 3 deletions modelopt/torch/quantization/plugins/megatron.py
@@ -115,7 +115,7 @@ def real_quant_module_set_extra_state(self, state: Any):
"""
q_tensor_state = state.get("modelopt_q_tensor_state", None)

if q_tensor_state is not None:
if q_tensor_state:
q_tensor_metadata = q_tensor_state["metadata"]
q_tensor_metadata["shape"] = self.weight.shape
q_tensor_data_dtype = q_tensor_state["quantized_data.dtype"]
@@ -418,8 +418,10 @@ class forward(). This is not desired since _forward_impl introduces much more ar
while the original forward only takes 1 positional argument. We must avoid the fallback path
in RealQuantLinear.forward().
"""
if self._should_run_real_quant_gemm and self.get_real_quant_gemm_impl(
input, *args, **kwargs
if (
self._should_run_real_quant_gemm
and self.get_real_quant_gemm_impl(input, *args, **kwargs)
and input.numel() > 1
):
allreduce_dgrad = kwargs.get("allreduce_dgrad", False)
tp_group = kwargs.get("tp_group")
115 changes: 115 additions & 0 deletions modelopt/torch/quantization/triton/fp8_kernel.py
@@ -0,0 +1,115 @@
import torch
import triton
import triton.language as tl
from triton import Config

fp8_gemm_configs = [
Config(
{"BLOCK_SIZE_M": block_m, "BLOCK_SIZE_N": block_n, "BLOCK_SIZE_K": 128},
num_stages=num_stages,
num_warps=8,
)
for block_m in [16, 32, 64]
for block_n in [32, 64, 128]
for num_stages in [3, 4, 5, 6]
]


@triton.autotune(configs=fp8_gemm_configs, key=["N", "K"])
@triton.jit
def fp8_gemm_kernel(
a_ptr,
b_ptr,
c_ptr,
a_amax_ptr,
b_amax_ptr,
bias_ptr,
M,
N: tl.constexpr,
K: tl.constexpr,
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr,
):
"""Performs a matrix multiplication operation on FP8 matrices with scaling factors.

Args:
a_ptr (tl.tensor): Pointer to the first input matrix A.
b_ptr (tl.tensor): Pointer to the second input matrix B.
c_ptr (tl.tensor): Pointer to the output matrix C.
a_amax_ptr (tl.tensor): Pointer to the scalar amax (absolute maximum) of matrix A.
b_amax_ptr (tl.tensor): Pointer to the scalar amax (absolute maximum) of matrix B.
bias_ptr (tl.tensor): Pointer to the bias tensor.
M (int): Number of rows in matrix A and C.
N (tl.constexpr): Number of columns in matrix B and C.
K (tl.constexpr): Number of columns in matrix A and rows in matrix B.
BLOCK_SIZE_M (tl.constexpr): Block size for the M dimension.
BLOCK_SIZE_N (tl.constexpr): Block size for the N dimension.
BLOCK_SIZE_K (tl.constexpr): Block size for the K dimension.

Returns:
None
"""
pid_m = tl.program_id(axis=0)
pid_n = tl.program_id(axis=1)
k = tl.cdiv(K, BLOCK_SIZE_K)
offs_m = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
offs_n = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
offs_k = tl.arange(0, BLOCK_SIZE_K)
a_ptrs = a_ptr + offs_m[:, None] * K + offs_k[None, :]
b_ptrs = b_ptr + offs_n[None, :] * K + offs_k[:, None]
a_amax = tl.load(a_amax_ptr)
b_amax = tl.load(b_amax_ptr)

a_scale = a_amax / 448.0
b_scale = b_amax / 448.0

output_dtype = c_ptr.dtype.element_ty
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.bfloat16)
for i in range(k):
a = tl.load(a_ptrs, mask=offs_k[None, :] < K - i * BLOCK_SIZE_K, other=0.0)
a_fp8 = tl.clamp((a / a_scale), -448.0, 448.0).to(tl.float8e4nv)
b_fp8 = tl.load(b_ptrs, mask=offs_k[:, None] < K - i * BLOCK_SIZE_K, other=0.0)
accumulator += tl.dot(a_fp8, b_fp8).to(tl.bfloat16)
a_ptrs += BLOCK_SIZE_K
b_ptrs += BLOCK_SIZE_K

c = (accumulator * a_scale * b_scale).to(output_dtype)
if bias_ptr is not None:
bias = tl.load(bias_ptr + offs_n, mask=offs_n < N, other=0.0)
c = c + bias
# Recompute the offsets without the modulo wrap so the store mask below masks out-of-range rows/cols.
offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
c_ptrs = c_ptr + offs_m[:, None] * N + offs_n[None, :]
mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
tl.store(c_ptrs, c, mask=mask)


def fp8_gemm(
a: torch.Tensor,
b_fp8: torch.Tensor,
a_amax: torch.Tensor,
b_amax: torch.Tensor,
bias: torch.Tensor | None = None,
):
"""Perform a matrix multiplication using FP8 precision.

Args:
a (torch.Tensor): The first input matrix, must be contiguous.
b_fp8 (torch.Tensor): The second input matrix, already quantized to FP8 with shape (N, K); must be contiguous.
a_amax (torch.Tensor): The amax for the first input matrix, must be a scalar.
b_amax (torch.Tensor): The amax for the second input matrix, must be a scalar.
bias (torch.Tensor | None): The bias tensor, must be contiguous.

Returns:
torch.Tensor: The result of the matrix multiplication.
"""
assert a.is_contiguous() and b_fp8.is_contiguous(), "Input tensors must be contiguous"
K = a.size(-1)
M = a.numel() // K
N = b_fp8.size(0)
c = a.new_empty(*a.size()[:-1], N, dtype=a.dtype)
grid = lambda META: (triton.cdiv(M, META["BLOCK_SIZE_M"]), triton.cdiv(N, META["BLOCK_SIZE_N"]))
fp8_gemm_kernel[grid](a, b_fp8, c, a_amax, b_amax, bias, M, N, K)

return c
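A small usage sketch for the new wrapper, consistent with the unit test further down (requires a CUDA device with FP8 support): activations stay in high precision and are quantized inside the kernel, while the weight is pre-quantized to float8_e4m3fn in (N, K) layout.

```python
import torch
from modelopt.torch.quantization.triton.fp8_kernel import fp8_gemm

a = torch.randn(8, 64, dtype=torch.bfloat16, device="cuda")    # (M, K) activations
w = torch.randn(32, 64, dtype=torch.bfloat16, device="cuda")   # (N, K) weights
a_amax, w_amax = a.abs().max(), w.abs().max()

# Pre-quantize the weight with the per-tensor scale amax / 448.
w_fp8 = (w.float() / (w_amax.float() / 448.0)).clamp(-448.0, 448.0).to(torch.float8_e4m3fn)

out = fp8_gemm(a, w_fp8, a_amax, w_amax)   # shape (8, 32), dtype of a
```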
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -116,7 +116,7 @@ disable_error_code = ["attr-defined"]
# Default additional options
# Show a short test summary info for all except passed tests with -ra flag
# print execution time for 20 slowest tests and generate coverage reports
addopts = "-ra --cov-report=term-missing --cov-report=html --cov-report=xml:coverage.xml --cov-config=pyproject.toml --durations=20 --strict-markers"
# addopts = "-ra --cov-report=term-missing --cov-report=html --cov-report=xml:coverage.xml --cov-config=pyproject.toml --durations=20 --strict-markers"
pythonpath = ["tests/"]
markers = ["manual: Only run when --run-manual is given"]

58 changes: 58 additions & 0 deletions tests/gpu/torch/quantization/triton/test_fp8_kernel.py
@@ -0,0 +1,58 @@
import pytest
import torch
from _test_utils.torch_misc import set_seed

from modelopt.torch.quantization.triton.fp8_kernel import fp8_gemm


@pytest.fixture(autouse=True)
def setup_seed():
"""Set seed before each test function."""
set_seed()


@pytest.mark.parametrize(
"M,N,K,dtype,with_bias",
[
(16, 16, 16, torch.float16, False),
(32, 32, 32, torch.bfloat16, False),
(16, 32, 48, torch.float16, False),
(48, 32, 16, torch.bfloat16, False),
(16, 16, 16, torch.float16, True),
(32, 32, 32, torch.bfloat16, True),
],
)
def test_fp8_gemm_basic(M, N, K, dtype, with_bias):
# Create random input matrices
a = torch.randn(M, K, dtype=dtype, device="cuda")
b = torch.randn(N, K, dtype=dtype, device="cuda")
# amax for scaling (simulate FP8 quantization)
a_amax = a.abs().max()
b_amax = b.abs().max()

# Reference: simulate quantization/dequantization and matmul
# Quantize to FP8 (simulate with clamping and scaling)
a_scale = a_amax.to(torch.float32) / 448.0
b_scale = b_amax.to(torch.float32) / 448.0
a_fp8 = torch.clamp((a.to(torch.float32) / a_scale), -448.0, 448.0)
b_fp8 = torch.clamp((b.to(torch.float32) / b_scale), -448.0, 448.0)

bias = None if not with_bias else torch.randn(N, dtype=dtype, device="cuda")

# Run fp8_gemm
actual = fp8_gemm(a, b_fp8.to(torch.float8_e4m3fn), a_amax, b_amax, bias=bias).to(torch.float32)

ref = torch._scaled_mm(
a_fp8.to(torch.float8_e4m3fn),
b_fp8.to(torch.float8_e4m3fn).T,
a_scale,
b_scale,
use_fast_accum=True,
out_dtype=a.dtype,
bias=bias,
).to(torch.float32)

# Compare
assert actual.shape == (M, N)
# Allow some tolerance due to quantization error
torch.testing.assert_close(actual, ref, atol=1e-1, rtol=1e-1)