diff --git a/modelopt/torch/quantization/backends/fp8_per_tensor_gemm.py b/modelopt/torch/quantization/backends/fp8_per_tensor_gemm.py
index 1b61864c0..8f232163e 100644
--- a/modelopt/torch/quantization/backends/fp8_per_tensor_gemm.py
+++ b/modelopt/torch/quantization/backends/fp8_per_tensor_gemm.py
@@ -24,6 +24,7 @@
 from modelopt.torch.quantization.config import FP8_DEFAULT_CFG
 from modelopt.torch.quantization.nn.modules.quant_linear import RealQuantLinear
 from modelopt.torch.quantization.qtensor import FP8QTensor, QTensorWrapper
+from modelopt.torch.quantization.triton.fp8_kernel import fp8_gemm
 from modelopt.torch.quantization.utils import reduce_amax
 
 from .utils import fp8_compatible
@@ -32,62 +33,37 @@
 FP8_MAX = torch.finfo(torch.float8_e4m3fn).max
 
 
-def fp8_per_tensor_gemm(quant_module, input, bias=None):
-    """GEMM function for fp8 per tensor quantization."""
+@torch.compile(dynamic=True)
+def _to_fp8(x, amax):
+    return (x.to(torch.float32) / amax * 448.0).clamp(FP8_MIN, FP8_MAX).to(torch.float8_e4m3fn)
 
-    @torch.compile(dynamic=True)
-    def _to_fp8(x, scale):
-        return (x / scale).clamp(FP8_MIN, FP8_MAX).to(torch.float8_e4m3fn)
-
-    @torch.compile(dynamic=True)
-    def _fp8_gemm_impl(input, weight_fp8, scale_a, scale_b, bias=None):
-        input_shape = input.shape
-        input_fp8 = _to_fp8(input, scale_a).reshape(-1, input_shape[-1])
-        weight_fp8_t = weight_fp8.reshape(-1, weight_fp8.shape[-1]).t()
-        output = torch._scaled_mm(
-            input_fp8,
-            weight_fp8_t,
-            scale_a=scale_a,
-            scale_b=scale_b,
-            bias=bias,
-            out_dtype=input.dtype,
-            use_fast_accum=True,
-        )
-        return output.reshape(*input_shape[:-1], output.shape[-1])
 
-    cached_scale_a = (
-        hasattr(quant_module, "_scale_a") and quant_module.input_quantizer.amax is not None
+def fp8_per_tensor_gemm(quant_module, input, bias=None):
+    """GEMM function for fp8 per tensor quantization."""
+    input_amax = (
+        quant_module.input_quantizer.amax
+        if quant_module.input_quantizer.amax is not None
+        else reduce_amax(input)
     )
-
-    if not cached_scale_a:
-        input_amax = quant_module.input_quantizer.amax or reduce_amax(input)
-        assert input_amax != 0
-        quant_module._scale_a = (input_amax.float() / 448.0).to(device=input.device)
-
-    cached_scale_b = (
-        hasattr(quant_module, "_scale_b") and quant_module.weight_quantizer.amax is not None
+    weight_amax = (
+        quant_module.weight_quantizer.amax
+        if quant_module.weight_quantizer.amax is not None
+        else reduce_amax(quant_module.weight)
     )
-    if not cached_scale_b:
-        weight_amax = quant_module.weight_quantizer.amax or reduce_amax(quant_module.weight)
-        assert weight_amax != 0
-        quant_module._scale_b = (weight_amax.float() / 448.0).to(device=quant_module.weight.device)
 
     if quant_module.weight.dtype != torch.float8_e4m3fn:
-        weight_fp8 = _to_fp8(quant_module.weight, quant_module._scale_b)
+        with torch.cuda.nvtx.range("compress weight"):
+            weight_fp8 = _to_fp8(quant_module.weight.data, weight_amax)
     else:
         weight_fp8 = quant_module.weight.data
 
-    output = _fp8_gemm_impl(
+    output = fp8_gemm(
         input,
         weight_fp8,
-        scale_a=quant_module._scale_a,
-        scale_b=quant_module._scale_b,
-        bias=bias if input.dtype != torch.float32 else None,
+        input_amax,
+        weight_amax,
+        bias=bias,
     )
-    # _scaled_mm does not support bias for float32 input, so we add it manually
-    if input.dtype == torch.float32 and bias is not None:
-        output += bias
     return output
 
 
@@ -146,16 +122,16 @@ def forward(
         ctx.save_for_backward(
             input_tensor if weight.requires_grad else None,
             weight if input_tensor.requires_grad else None,
-            torch.empty(0, dtype=torch.uint8) if bias is not None and bias.requires_grad else None,
             getattr(quant_module.weight_quantizer, "_scale", None),
         )
+
+        ctx.compute_bias_grad = bias is not None and bias.requires_grad
         ctx.block_sizes = getattr(quant_module.weight_quantizer, "_block_sizes", None)
         ctx.allreduce_dgrad = allreduce_dgrad
         ctx.tp_group = tp_group
-        ret = fp8_per_tensor_gemm(quant_module, input_tensor, bias)
-        return ret
+        return fp8_per_tensor_gemm(quant_module, input_tensor, bias)
 
     @staticmethod
     def backward(ctx, grad_outputs):
@@ -166,7 +142,7 @@ def backward(ctx, grad_outputs):
         dequantize it to compute the input gradient. If the weight is not compressed, we will
         save the unquantized weight and use it directly to compute the input gradient.
         """
-        input_tensor, weight, compute_bias_grad, scale = ctx.saved_tensors
+        input_tensor, weight, scale = ctx.saved_tensors
         grad_input = grad_weight = grad_bias = None
         if weight is not None:
             if isinstance(weight, QTensorWrapper):
@@ -175,8 +151,10 @@
                 weight = weight.dequantize(scale=scale, block_sizes=ctx.block_sizes)
             grad_input = grad_outputs @ weight
         if input_tensor is not None:
-            grad_weight = grad_outputs.transpose(-2, 1) @ input_tensor
-            if compute_bias_grad is not None:
+            grad_weight = grad_outputs.reshape(-1, grad_outputs.shape[-1]).T @ input_tensor.reshape(
+                -1, input_tensor.shape[-1]
+            )
+            if ctx.compute_bias_grad:
                 # Sum all dimensions except the last one
                 grad_bias = grad_outputs.sum(dim=list(range(grad_outputs.dim() - 1)))
diff --git a/modelopt/torch/quantization/backends/nvfp4_gemm.py b/modelopt/torch/quantization/backends/nvfp4_gemm.py
index 4a67a3b93..e637a3e4a 100644
--- a/modelopt/torch/quantization/backends/nvfp4_gemm.py
+++ b/modelopt/torch/quantization/backends/nvfp4_gemm.py
@@ -173,7 +173,9 @@ def backward(ctx, grad_outputs):
             )
             grad_input = grad_outputs @ weight
         if input_tensor is not None:
-            grad_weight = grad_outputs.transpose(-2, -1) @ input_tensor
+            grad_weight = grad_outputs.reshape(-1, grad_outputs.shape[-1]).T @ input_tensor.reshape(
+                -1, input_tensor.shape[-1]
+            )
         if compute_bias_grad is not None:
             # Sum all dimensions except the last one
             grad_bias = grad_outputs.sum(dim=list(range(grad_outputs.dim() - 1)))
diff --git a/modelopt/torch/quantization/nn/modules/quant_linear.py b/modelopt/torch/quantization/nn/modules/quant_linear.py
index 584385d0a..badff878b 100644
--- a/modelopt/torch/quantization/nn/modules/quant_linear.py
+++ b/modelopt/torch/quantization/nn/modules/quant_linear.py
@@ -168,11 +168,9 @@ def forward(self, input, *args, **kwargs):
         # Check if real-quant GEMM is available
         if self._should_run_real_quant_gemm and input.numel() > 1:
             # If the input is not quantized, we use the default GEMM.
-            self.get_real_quant_gemm_impl(input, *args, **kwargs)
-
             # Note: We cache the real-quant GEMM function to avoid matching overhead.
             # This assumes that the function will not change after the first call.
-            if self._real_quant_gemm_impl:
+            if self.get_real_quant_gemm_impl(input, *args, **kwargs):
                 with torch.cuda.nvtx.range("RealQuantLinear gemm"):
                     output = self._real_quant_gemm_impl(
                         self, input, self.weight, self.bias, *args, **kwargs
diff --git a/modelopt/torch/quantization/plugins/megatron.py b/modelopt/torch/quantization/plugins/megatron.py
index ab64a795a..bf7f5152b 100644
--- a/modelopt/torch/quantization/plugins/megatron.py
+++ b/modelopt/torch/quantization/plugins/megatron.py
@@ -115,7 +115,7 @@ def real_quant_module_set_extra_state(self, state: Any):
     """
     q_tensor_state = state.get("modelopt_q_tensor_state", None)
-    if q_tensor_state is not None:
+    if q_tensor_state:
         q_tensor_metadata = q_tensor_state["metadata"]
         q_tensor_metadata["shape"] = self.weight.shape
         q_tensor_data_dtype = q_tensor_state["quantized_data.dtype"]
@@ -418,8 +418,10 @@ class forward(). This is not desired since _forward_impl introduces much more ar
         while the original forward only takes 1 positional argument. We must above the fallback path
         in RealQuantLinear.forward().
         """
-        if self._should_run_real_quant_gemm and self.get_real_quant_gemm_impl(
-            input, *args, **kwargs
+        if (
+            self._should_run_real_quant_gemm
+            and self.get_real_quant_gemm_impl(input, *args, **kwargs)
+            and input.numel() > 1
         ):
             allreduce_dgrad = kwargs.get("allreduce_dgrad", False)
             tp_group = kwargs.get("tp_group")
diff --git a/modelopt/torch/quantization/triton/fp8_kernel.py b/modelopt/torch/quantization/triton/fp8_kernel.py
new file mode 100644
index 000000000..0cef78129
--- /dev/null
+++ b/modelopt/torch/quantization/triton/fp8_kernel.py
@@ -0,0 +1,115 @@
+import torch
+import triton
+import triton.language as tl
+from triton import Config
+
+fp8_gemm_configs = [
+    Config(
+        {"BLOCK_SIZE_M": block_m, "BLOCK_SIZE_N": block_n, "BLOCK_SIZE_K": 128},
+        num_stages=num_stages,
+        num_warps=8,
+    )
+    for block_m in [16, 32, 64]
+    for block_n in [32, 64, 128]
+    for num_stages in [3, 4, 5, 6]
+]
+
+
+@triton.autotune(configs=fp8_gemm_configs, key=["N", "K"])
+@triton.jit
+def fp8_gemm_kernel(
+    a_ptr,
+    b_ptr,
+    c_ptr,
+    a_amax_ptr,
+    b_amax_ptr,
+    bias_ptr,
+    M,
+    N: tl.constexpr,
+    K: tl.constexpr,
+    BLOCK_SIZE_M: tl.constexpr,
+    BLOCK_SIZE_N: tl.constexpr,
+    BLOCK_SIZE_K: tl.constexpr,
+):
+    """Matrix multiplication where A is quantized to FP8 on the fly and B is already FP8.
+
+    Args:
+        a_ptr (tl.tensor): Pointer to the first input matrix A (high precision, row-major (M, K)).
+        b_ptr (tl.tensor): Pointer to the second input matrix B (float8_e4m3fn, row-major (N, K)).
+        c_ptr (tl.tensor): Pointer to the output matrix C of shape (M, N).
+        a_amax_ptr (tl.tensor): Pointer to the scalar per-tensor amax of matrix A.
+        b_amax_ptr (tl.tensor): Pointer to the scalar per-tensor amax of matrix B.
+        bias_ptr (tl.tensor): Pointer to the bias tensor of shape (N,), or None.
+        M (int): Number of rows in matrix A and C.
+        N (tl.constexpr): Number of columns in C (rows of the stored B).
+        K (tl.constexpr): Reduction dimension: columns of A and of the stored B.
+        BLOCK_SIZE_M (tl.constexpr): Block size for the M dimension.
+        BLOCK_SIZE_N (tl.constexpr): Block size for the N dimension.
+        BLOCK_SIZE_K (tl.constexpr): Block size for the K dimension.
+
+    Returns:
+        None
+    """
+    pid_m = tl.program_id(axis=0)
+    pid_n = tl.program_id(axis=1)
+    k = tl.cdiv(K, BLOCK_SIZE_K)
+    offs_m = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
+    offs_n = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
+    offs_k = tl.arange(0, BLOCK_SIZE_K)
+    a_ptrs = a_ptr + offs_m[:, None] * K + offs_k[None, :]
+    b_ptrs = b_ptr + offs_n[None, :] * K + offs_k[:, None]
+    a_amax = tl.load(a_amax_ptr)
+    b_amax = tl.load(b_amax_ptr)
+
+    a_scale = a_amax / 448.0
+    b_scale = b_amax / 448.0
+
+    output_dtype = c_ptr.dtype.element_ty
+    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.bfloat16)
+    for i in range(k):
+        a = tl.load(a_ptrs, mask=offs_k[None, :] < K - i * BLOCK_SIZE_K, other=0.0)
+        a_fp8 = tl.clamp((a / a_scale), -448.0, 448.0).to(tl.float8e4nv)
+        b_fp8 = tl.load(b_ptrs, mask=offs_k[:, None] < K - i * BLOCK_SIZE_K, other=0.0)
+        accumulator += tl.dot(a_fp8, b_fp8).to(tl.bfloat16)
+        a_ptrs += BLOCK_SIZE_K
+        b_ptrs += BLOCK_SIZE_K
+
+    c = (accumulator * a_scale * b_scale).to(output_dtype)
+    if bias_ptr is not None:
+        bias = tl.load(bias_ptr + offs_n, mask=offs_n < N, other=0.0)
+        c = c + bias
+    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
+    c_ptrs = c_ptr + offs_m[:, None] * N + offs_n[None, :]
+    mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
+    tl.store(c_ptrs, c, mask=mask)
+
+
+def fp8_gemm(
+    a: torch.Tensor,
+    b_fp8: torch.Tensor,
+    a_amax: torch.Tensor,
+    b_amax: torch.Tensor,
+    bias: torch.Tensor | None = None,
+):
+    """Perform a matrix multiplication using FP8 precision.
+
+    Args:
+        a (torch.Tensor): The first input matrix, must be contiguous; quantized to FP8 in the kernel.
+        b_fp8 (torch.Tensor): The second input matrix, already float8_e4m3fn, must be contiguous.
+        a_amax (torch.Tensor): The amax for the first input matrix, must be a scalar tensor.
+        b_amax (torch.Tensor): The amax for the second input matrix, must be a scalar tensor.
+        bias (torch.Tensor | None): The bias tensor, must be contiguous.
+
+    Returns:
+        torch.Tensor: The result of the matrix multiplication.
+ """ + assert a.is_contiguous() and b_fp8.is_contiguous(), "Input tensors must be contiguous" + K = a.size(-1) + M = a.numel() // K + N = b_fp8.size(0) + c = a.new_empty(*a.size()[:-1], N, dtype=a.dtype) + grid = lambda META: (triton.cdiv(M, META["BLOCK_SIZE_M"]), triton.cdiv(N, META["BLOCK_SIZE_N"])) + fp8_gemm_kernel[grid](a, b_fp8, c, a_amax, b_amax, bias, M, N, K) + + return c diff --git a/pyproject.toml b/pyproject.toml index 8ae14292d..302f072f7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -116,7 +116,7 @@ disable_error_code = ["attr-defined"] # Default additional options # Show a short test summary info for all except passed tests with -ra flag # print execution time for 20 slowest tests and generate coverage reports -addopts = "-ra --cov-report=term-missing --cov-report=html --cov-report=xml:coverage.xml --cov-config=pyproject.toml --durations=20 --strict-markers" +# addopts = "-ra --cov-report=term-missing --cov-report=html --cov-report=xml:coverage.xml --cov-config=pyproject.toml --durations=20 --strict-markers" pythonpath = ["tests/"] markers = ["manual: Only run when --run-manual is given"] diff --git a/tests/gpu/torch/quantization/triton/test_fp8_kernel.py b/tests/gpu/torch/quantization/triton/test_fp8_kernel.py new file mode 100644 index 000000000..e7e5dc076 --- /dev/null +++ b/tests/gpu/torch/quantization/triton/test_fp8_kernel.py @@ -0,0 +1,58 @@ +import pytest +import torch +from _test_utils.torch_misc import set_seed + +from modelopt.torch.quantization.triton.fp8_kernel import fp8_gemm + + +@pytest.fixture(autouse=True) +def setup_seed(): + """Set seed before each test function.""" + set_seed() + + +@pytest.mark.parametrize( + "M,N,K,dtype,with_bias", + [ + (16, 16, 16, torch.float16, False), + (32, 32, 32, torch.bfloat16, False), + (16, 32, 48, torch.float16, False), + (48, 32, 16, torch.bfloat16, False), + (16, 16, 16, torch.float16, True), + (32, 32, 32, torch.bfloat16, True), + ], +) +def test_fp8_gemm_basic(M, N, K, dtype, with_bias): + # Create random input matrices + a = torch.randn(M, K, dtype=dtype, device="cuda") + b = torch.randn(N, K, dtype=dtype, device="cuda") + # amax for scaling (simulate FP8 quantization) + a_amax = a.abs().max() + b_amax = b.abs().max() + + # Reference: simulate quantization/dequantization and matmul + # Quantize to FP8 (simulate with clamping and scaling) + a_scale = a_amax.to(torch.float32) / 448.0 + b_scale = b_amax.to(torch.float32) / 448.0 + a_fp8 = torch.clamp((a.to(torch.float32) / a_scale), -448.0, 448.0) + b_fp8 = torch.clamp((b.to(torch.float32) / b_scale), -448.0, 448.0) + + bias = None if not with_bias else torch.randn(N, dtype=dtype, device="cuda") + + # Run fp8_gemm + actual = fp8_gemm(a, b_fp8.to(torch.float8_e4m3fn), a_amax, b_amax, bias=bias).to(torch.float32) + + ref = torch._scaled_mm( + a_fp8.to(torch.float8_e4m3fn), + b_fp8.to(torch.float8_e4m3fn).T, + a_scale, + b_scale, + use_fast_accum=True, + out_dtype=a.dtype, + bias=bias, + ).to(torch.float32) + + # Compare + assert actual.shape == (M, N) + # Allow some tolerance due to quantization error + torch.testing.assert_close(actual, ref, atol=1e-1, rtol=1e-1)
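
Reviewer note (not part of the patch): `fp8_gemm_kernel` quantizes each tile of A with the per-tensor scale `a_amax / 448`, multiplies it against the pre-quantized FP8 weight B, rescales the accumulator by `a_scale * b_scale`, and then adds the bias. The sketch below is a minimal CPU-side emulation of that math using only public PyTorch ops; the function name `fp8_gemm_reference` and the shapes are illustrative, not part of the patch. Accumulation order and rounding differ from the Triton/Tensor Core path, which is why the test above uses loose atol/rtol of 1e-1.

```python
import torch


def fp8_gemm_reference(a: torch.Tensor, b: torch.Tensor, bias: torch.Tensor | None = None):
    """Emulate per-tensor FP8 GEMM: fake-quantize both operands, matmul in float32, add bias."""
    a_scale = a.abs().max().float() / 448.0  # per-tensor scale for A (448 = E4M3 max magnitude)
    b_scale = b.abs().max().float() / 448.0  # per-tensor scale for B
    # Round-trip through float8_e4m3fn, then dequantize back to float32.
    a_q = (a.float() / a_scale).clamp(-448.0, 448.0).to(torch.float8_e4m3fn).float() * a_scale
    b_q = (b.float() / b_scale).clamp(-448.0, 448.0).to(torch.float8_e4m3fn).float() * b_scale
    out = a_q @ b_q.t()  # b is stored as (N, K), matching fp8_gemm's weight layout
    if bias is not None:
        out = out + bias.float()
    return out.to(a.dtype)


a = torch.randn(8, 64, dtype=torch.bfloat16)
b = torch.randn(16, 64, dtype=torch.bfloat16)
print(fp8_gemm_reference(a, b).shape)  # torch.Size([8, 16])
```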
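Also for review context (not part of the patch): the `backward()` changes flatten `grad_outputs` and `input_tensor` to 2-D before the matmul so that batched activations (e.g. `[batch, seq, in_features]`) produce a single `[out_features, in_features]` weight gradient rather than a per-batch stack. A quick self-contained check of that identity against autograd, with illustrative shapes:

```python
import torch

x = torch.randn(4, 7, 32, requires_grad=True)  # (batch, seq, in_features)
w = torch.randn(16, 32, requires_grad=True)    # (out_features, in_features)
y = x @ w.t()                                  # same contraction a linear layer performs
g = torch.randn_like(y)                        # stand-in for grad_outputs
y.backward(g)

# The flattened form used in the updated backward():
grad_w_manual = g.reshape(-1, g.shape[-1]).T @ x.reshape(-1, x.shape[-1])
torch.testing.assert_close(grad_w_manual.detach(), w.grad)
print("flattened grad_weight matches autograd:", tuple(grad_w_manual.shape))
```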