flashinfer/gemm.py (228 changes: 163 additions & 65 deletions)
@@ -43,6 +43,7 @@
is_sm120a_supported,
is_sm121a_supported,
LibraryError,
supports_backends,
)

CUDNN_AVAILABLE = False
@@ -1644,7 +1645,7 @@ def _validate_fp8_output_dtype(dtype: torch.dtype):


@functools.cache
def build_cudnn_gemm_block_scale_dequantize_graph(
def create_cudnn_execution_plans_fp4_gemm(
a_shape,
a_stride,
b_shape,
@@ -1741,12 +1742,49 @@ def build_cudnn_gemm_block_scale_dequantize_graph(
# in older cuDNN versions, so we deselect it.
if (alpha is not None) and (not _is_cublas_fp4_available_in_cudnn()):
graph.deselect_engines(["eng0"])
graph.check_support()
graph.build_plans()

return graph


@functools.cache
def build_plans_cudnn_fp4_gemm_graph(
a_shape,
a_stride,
b_shape,
b_stride,
a_descale_shape,
a_descale_stride,
b_descale_shape,
b_descale_stride,
ab_type,
o_type,
block_size,
device,
alpha,
use_nvfp4,
):
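    # Cached via functools.cache: the support check and plan build run once
    # per unique shape/stride/dtype/device signature; repeat calls with the
    # same signature reuse the finalized graph.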
graph = create_cudnn_execution_plans_fp4_gemm(
a_shape,
a_stride,
b_shape,
b_stride,
a_descale_shape,
a_descale_stride,
b_descale_shape,
b_descale_stride,
ab_type,
o_type,
block_size,
device,
alpha,
use_nvfp4,
)

graph.check_support()
graph.build_plans()
return graph


def execute_cudnn_gemm_fp4_graph(
graph,
a,
@@ -1999,6 +2037,127 @@ def _expand_block_scale_tensor_shape(block_scale_tensor, batch_size):
return (tuple(block_scale_shape), tuple(block_scale_stride))


def _check_mm_fp4_backend_supported(
a: torch.Tensor,
b: torch.Tensor,
a_descale: torch.Tensor,
b_descale: torch.Tensor,
alpha: Optional[torch.Tensor] = None,
out_dtype: torch.dtype = torch.bfloat16,
out: Optional[torch.Tensor] = None,
block_size: int = 16,
use_8x4_sf_layout: bool = False,
backend: Literal["cudnn", "trtllm", "cutlass"] = "cudnn",
use_nvfp4: bool = True,
):
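    """Validate shapes, dtypes, alpha, block size, and backend-specific
    constraints for mm_fp4; raises on failure, returns True otherwise."""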
# Generic checks
## pre-check the input tensor, block scale tensor and alpha tensor
if a.ndim != 2 or b.ndim != 2:
raise ValueError(f"mm_fp4 accepts 2d tensors, got {a.shape} and {b.shape}")
if a.shape[1] != b.shape[0]:
raise ValueError(
f"K dimension mismatch in mm_fp4. got a.shape[1] = {a.shape[1]}, b.shape[0] = {b.shape[0]}"
)
if a.dtype not in {torch.uint8, _get_native_fp4_dtype()} or b.dtype not in {
torch.uint8,
_get_native_fp4_dtype(),
}:
raise ValueError(
f"a and b must have float4_e2m1fn_x2 packed into uint8. "
f"Got {a.dtype} and {b.dtype}."
)
if a_descale.dtype not in {
torch.float8_e4m3fn,
torch.uint8,
} or b_descale.dtype not in {torch.float8_e4m3fn, torch.uint8}:
raise ValueError(
f"a_descale and b_descale must have float8_e4m3fnx2 packed into uint8. "
f"Got {a_descale.dtype} and {b_descale.dtype}."
)
if alpha is not None and alpha.dtype != torch.float:
raise ValueError(f"alpha must be a float tensor, got {alpha.dtype}")
if alpha is not None and alpha.numel() != 1:
raise ValueError(f"alpha must be a scalar, got {alpha.numel()}")

if out_dtype not in (torch.bfloat16, torch.float16):
raise ValueError(
f"Unsupported output dtype: {out_dtype}. "
f"Only torch.bfloat16 and torch.float16 are supported for FP4 GEMM operations."
)

if use_nvfp4 and block_size != 16:
raise ValueError("nvfp4 only supports block_size = 16.")
if not use_nvfp4 and block_size != 32:
raise ValueError("mxfp4 supports block_size = 32.")

if backend != "trtllm" and use_8x4_sf_layout:
raise ValueError("Only TRTLLM FP4 GEMM supports 8x4 scale factor layout.")
if backend != "cudnn" and not use_nvfp4:
raise ValueError("Only cudnn FP4 GEMM supports mxfp4 quantization.")

# Backend specific checks
if backend == "cudnn":
if (
not use_nvfp4
and _match_sm_version(a.device, ["120"])
and cudnn.backend_version() < 91400
):
raise LibraryError(
"cudnn FP4 GEMM with mxfp4 quantization is not supported on SM120 with cuDNN backend version < 9.14.0."
)

_check_cudnn_fp4_availability()

# the fp4 cudnn graph will be shared for both mm and bmm, so
# here we need to get the 3d shape and stride including the
# batch dimension for both input and block scale tensors.
real_a_shape, real_a_stride = _get_real_fp4_shape_from_packed_uint8(a)
real_b_shape, real_b_stride = _get_real_fp4_shape_from_packed_uint8(b)
batch = real_a_shape[0]
expanded_a_descale_shape, expanded_a_descale_stride = (
_expand_block_scale_tensor_shape(a_descale, batch)
)
expanded_b_descale_shape, expanded_b_descale_stride = (
_expand_block_scale_tensor_shape(b_descale, batch)
)

# build the fp4 cudnn graph
graph = create_cudnn_execution_plans_fp4_gemm(
real_a_shape,
real_a_stride,
real_b_shape,
real_b_stride,
expanded_a_descale_shape,
expanded_a_descale_stride,
expanded_b_descale_shape,
expanded_b_descale_stride,
cudnn.data_type.FP4_E2M1,
_torch_data_type_to_cudnn_data_type(out_dtype),
block_size,
a.device,
alpha,
use_nvfp4,
)
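        # Only engine support is verified here; mm_fp4 builds the execution
        # plans later through the cached build_plans_cudnn_fp4_gemm_graph.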
graph.check_support()

elif backend == "trtllm":
if out_dtype != torch.bfloat16:
raise ValueError(
f"Unsupported output dtype: {out_dtype}. "
f"Only torch.bfloat16 is supported for TRTLLM FP4 GEMM operations."
)
elif backend == "cutlass":
# No additional checks for cutlass
pass
return True
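
As an aside for reviewers, below is a minimal sketch of the dispatch pattern a decorator with the signature used here could implement. The parameter names mirror the call site that follows; the body is an assumption for illustration only, not flashinfer's actual supports_backends.

    import functools
    from typing import Callable, Dict, List, Optional

    import torch

    def supports_backends_sketch(
        backends: List[str],
        anti_capabilities: Optional[Dict[str, List[str]]] = None,
        capability_tensor_arg: Optional[str] = None,
        problem_size_check: Optional[Callable[..., bool]] = None,
    ):
        # Hypothetical sketch only; the real decorator may differ.
        def decorator(fn):
            @functools.wraps(fn)
            def wrapper(*args, **kwargs):
                backend = kwargs.get("backend", "cudnn")
                if backend not in backends:
                    raise ValueError(f"{fn.__name__}: unknown backend {backend!r}")
                # Reject backend/SM pairs listed in anti_capabilities, using
                # the device of the tensor named by capability_tensor_arg
                # (skipped here if that tensor was passed positionally).
                if anti_capabilities and capability_tensor_arg:
                    tensor = kwargs.get(capability_tensor_arg)
                    if tensor is not None and backend in anti_capabilities:
                        major, minor = torch.cuda.get_device_capability(tensor.device)
                        if f"{major}{minor}" in anti_capabilities[backend]:
                            raise ValueError(f"{backend} unsupported on SM{major}{minor}")
                # Run the problem-size validation before dispatching.
                if problem_size_check is not None:
                    problem_size_check(*args, **kwargs)
                return fn(*args, **kwargs)
            return wrapper
        return decorator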


@supports_backends(
["cudnn", "trtllm", "cutlass"],
Review comment (Member): Is this redundant with the declaration of the backend parameter?

    backend: Literal["cudnn", "trtllm", "cutlass"] = "cudnn",
anti_capabilities={"trtllm": ["110"]},
capability_tensor_arg="a",
problem_size_check=_check_mm_fp4_backend_supported,
)
def mm_fp4(
a: torch.Tensor,
b: torch.Tensor,
@@ -2073,59 +2232,6 @@ def mm_fp4(
>>> out.shape
torch.Size([48, 256])
"""
# pre-check the input tensor, block scale tensor and alpha tensor
if a.ndim != 2 or b.ndim != 2:
raise ValueError(f"mm_fp4 accepts 2d tensors, got {a.shape} and {b.shape}")
if a.shape[1] != b.shape[0]:
raise ValueError(
f"K dimension mismatch in mm_fp4. got a.shape[1] = {a.shape[1]}, b.shape[0] = {b.shape[0]}"
)
if a.dtype not in {torch.uint8, _get_native_fp4_dtype()} or b.dtype not in {
torch.uint8,
_get_native_fp4_dtype(),
}:
raise ValueError(
f"a and b must have float4_e2m1fn_x2 packed into uint8. "
f"Got {a.dtype} and {b.dtype}."
)
if a_descale.dtype not in {
torch.float8_e4m3fn,
torch.uint8,
} or b_descale.dtype not in {torch.float8_e4m3fn, torch.uint8}:
raise ValueError(
f"a_descale and b_descale must have float8_e4m3fnx2 packed into uint8. "
f"Got {a_descale.dtype} and {b_descale.dtype}."
)
if alpha is not None and alpha.dtype != torch.float:
raise ValueError(f"alpha must be a float tensor, got {alpha.dtype}")
if alpha is not None and alpha.numel() != 1:
raise ValueError(f"alpha must be a scalar, got {alpha.numel()}")

if out_dtype not in (torch.bfloat16, torch.float16):
raise ValueError(
f"Unsupported output dtype: {out_dtype}. "
f"Only torch.bfloat16 and torch.float16 are supported for FP4 GEMM operations."
)

if use_nvfp4 and block_size != 16:
raise ValueError("nvfp4 only supports block_size = 16.")
if not use_nvfp4 and block_size != 32:
raise ValueError("mxfp4 supports block_size = 32.")
if backend != "trtllm" and use_8x4_sf_layout:
raise ValueError("Only TRTLLM FP4 GEMM supports 8x4 scale factor layout.")
if backend == "trtllm" and _match_sm_version(a.device, ["110"]):
raise ValueError("TRTLLM FP4 GEMM is not supported on SM110.")
if backend != "cudnn" and not use_nvfp4:
raise ValueError("Only cudnn FP4 GEMM supports mxfp4 quantization.")
if (
backend == "cudnn"
and not use_nvfp4
and _match_sm_version(a.device, ["120"])
and cudnn.backend_version() < 91400
):
raise LibraryError(
"cudnn FP4 GEMM with mxfp4 quantization is not supported on SM120 with cuDNN backend version < 9.14.0."
)

# allocate the output tensor if not provided
if out is None:
@@ -2140,8 +2246,6 @@
)

if backend == "cudnn":
_check_cudnn_fp4_availability()

# the fp4 cudnn graph will be shared for both mm and bmm, so
# here we need to get the 3d shape and stride including the
# batch dimension for both input and block scale tensors.
@@ -2156,7 +2260,7 @@
)

# build the fp4 cudnn graph
graph = build_cudnn_gemm_block_scale_dequantize_graph(
graph = build_plans_cudnn_fp4_gemm_graph(
real_a_shape,
real_a_stride,
real_b_shape,
@@ -2178,12 +2282,6 @@
graph, a, b, a_descale, b_descale, alpha, out, workspace_buffer
)
elif backend == "trtllm":
if out_dtype != torch.bfloat16:
raise ValueError(
f"Unsupported output dtype: {out_dtype}. "
f"Only torch.bfloat16 is supported for TRTLLM FP4 GEMM operations."
)

get_trtllm_fp4_gemm_module().trtllm_fp4_gemm(
a,
b.T,