8 changes: 6 additions & 2 deletions flashinfer/gemm.py
@@ -43,6 +43,7 @@
    is_sm120a_supported,
    is_sm121a_supported,
    LibraryError,
    supports_backends,
)

CUDNN_AVAILABLE = False
@@ -1999,6 +2000,11 @@ def _expand_block_scale_tensor_shape(block_scale_tensor, batch_size):
    return (tuple(block_scale_shape), tuple(block_scale_stride))


@supports_backends(
["cudnn", "trtllm", "cutlass"],
Member: Is this redundant with the declaration of the backend parameter?

    backend: Literal["cudnn", "trtllm", "cutlass"] = "cudnn",
anti_capabilities={"trtllm": ["110"]},
capability_tensor_arg="a",
)
def mm_fp4(
    a: torch.Tensor,
    b: torch.Tensor,
@@ -2113,8 +2119,6 @@ def mm_fp4(
raise ValueError("mxfp4 supports block_size = 32.")
if backend != "trtllm" and use_8x4_sf_layout:
raise ValueError("Only TRTLLM FP4 GEMM supports 8x4 scale factor layout.")
if backend == "trtllm" and _match_sm_version(a.device, ["110"]):
raise ValueError("TRTLLM FP4 GEMM is not supported on SM110.")
if backend != "cudnn" and not use_nvfp4:
raise ValueError("Only cudnn FP4 GEMM supports mxfp4 quantization.")
if (
69 changes: 69 additions & 0 deletions flashinfer/utils.py
@@ -18,6 +18,8 @@
import math
from enum import Enum
from typing import Callable, Dict, Iterable, Optional, Sequence, Tuple, Union
import inspect


import torch
import torch.version
@@ -60,6 +62,12 @@ class LibraryError(Exception):
    pass


class BackendSupportedError(Exception):
"""Custom exception for backend-related errors."""

    pass


def _expand_5d(x: torch.Tensor, kv_layout: str) -> torch.Tensor:
    if x.ndim not in [4, 5]:
        raise ValueError("x must be 4D or 5D")
@@ -740,3 +748,64 @@ def get_shuffle_matrix_sf_a_row_indices(
    row_indices = get_shuffle_matrix_a_row_indices(input_tensor, epilogue_tile_m)

    return row_indices


def supports_backends(
    backends, capabilities=None, anti_capabilities=None, capability_tensor_arg=None
Member: nit: I wonder if "cc" or "compute_capabilities" would be more clear than "capabilities" -- or did you mean to signal something more generic than compute capabilities?
):
    def decorator(func):
        # Returns True if backend is supported; with capability, also checks if backend specifically supports it
        def is_backend_supported(backend, capability=None):
            if backend not in backends:
                return False
            if capability:
                # Anti-capabilities take precedence
                if anti_capabilities and backend in anti_capabilities:
                    if capability in anti_capabilities[backend]:
                        return False
                # Capabilities allow-list
                if capabilities and backend in capabilities:
                    return capability in capabilities[backend]
            return True

        # Returns True if any backend supports this capability
        def is_supported(capability):
            for backend in backends:
                if capabilities and backend in capabilities:
                    if capability in capabilities[backend]:
                        return True
                elif anti_capabilities and backend in anti_capabilities:
                    if capability in anti_capabilities[backend]:
                        return False
                else:
                    return True
            return False

        def wrapper(*args, **kwargs):
            backend = kwargs.get("backend")
            capability = None
            if capability_tensor_arg:
                tensor = kwargs.get(capability_tensor_arg)
Contributor: Is there a reason why we need capability_tensor_arg instead of finding torch.Tensors automatically and getting the capability from them? We can also assert that they're all on the same device.
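A rough sketch of what that could look like (hypothetical _infer_capability_device helper, not part of this PR): collect every torch.Tensor passed to the call, require them to share a device, and read the compute capability from that device instead of a named capability_tensor_arg.

```python
import torch


def _infer_capability_device(args, kwargs):
    # Hypothetical helper: gather every tensor argument the wrapped call received.
    tensors = [v for v in list(args) + list(kwargs.values()) if isinstance(v, torch.Tensor)]
    if not tensors:
        raise ValueError("No tensor arguments to infer compute capability from")
    # Assert they all live on the same device before querying its capability.
    devices = {t.device for t in tensors}
    if len(devices) != 1:
        raise ValueError(f"Tensor arguments span multiple devices: {devices}")
    return devices.pop()
```

The wrapper could then pass the returned device to get_compute_capability and drop the inspect.signature lookup.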

                # When it wasn't provided as a keyword argument, try to get it from the arguments
                if tensor is None:
                    params = list(inspect.signature(func).parameters)
                    idx = params.index(capability_tensor_arg)
                    tensor = args[idx]
                if tensor is None:
                    raise ValueError("Invalid tensor on capability support check")
                major, minor = get_compute_capability(tensor.device)
                capability = f"{major * 10 + minor}"
            if not is_backend_supported(backend, capability):
                extra = f" with capability {capability}" if capability else ""
                raise BackendSupportedError(
                    f"{func.__name__} does not support backend '{backend}'{extra}"
                )
            return func(*args, **kwargs)

        wrapper.is_supported = is_supported
        wrapper.is_backend_supported = is_backend_supported
        wrapper.__name__ = func.__name__
        wrapper.__doc__ = func.__doc__
Contributor: Can use functools.wraps for more robust wrapping and standardized interface.
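A minimal sketch of that suggestion, using a simplified stand-in for the decorator above (the backend check here is illustrative, not the PR's logic):

```python
import functools


def supports_backends_sketch(backends):
    def decorator(func):
        @functools.wraps(func)  # copies __name__, __doc__, __module__ and sets __wrapped__
        def wrapper(*args, **kwargs):
            backend = kwargs.get("backend")
            if backend is not None and backend not in backends:
                raise ValueError(f"{func.__name__} does not support backend '{backend}'")
            return func(*args, **kwargs)

        return wrapper

    return decorator
```

With functools.wraps the manual __name__/__doc__ copies below become unnecessary, and introspection tools can follow __wrapped__ back to the original function.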

        return wrapper

    return decorator
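For reference, a usage sketch of the helpers the decorator attaches, as applied to mm_fp4 in this diff; the expected results follow from the decorator logic above and the capability strings are illustrative.

```python
from flashinfer.gemm import mm_fp4

# mm_fp4 is decorated with anti_capabilities={"trtllm": ["110"]}:
mm_fp4.is_backend_supported("trtllm", "110")  # False: SM110 is excluded for trtllm
mm_fp4.is_backend_supported("cudnn", "110")   # True: no restriction declared for cudnn
mm_fp4.is_supported("110")                    # True: at least one backend still accepts it
```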