-
Notifications
You must be signed in to change notification settings - Fork 532
Support checks PoC #1809
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. Weβll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Support checks PoC #1809
Changes from 2 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -18,6 +18,8 @@ | |
import math | ||
from enum import Enum | ||
from typing import Callable, Dict, Iterable, Optional, Sequence, Tuple, Union | ||
import inspect | ||
|
||
|
||
import torch | ||
import torch.version | ||
|
@@ -60,6 +62,12 @@ class LibraryError(Exception): | |
pass | ||
|
||
|
||
class BackendSupportedError(Exception): | ||
"""Custom exception for backend-related errors.""" | ||
|
||
pass | ||
|
||
|
||
def _expand_5d(x: torch.Tensor, kv_layout: str) -> torch.Tensor: | ||
if x.ndim not in [4, 5]: | ||
raise ValueError("x must be 4D or 5D") | ||
|
@@ -740,3 +748,64 @@ def get_shuffle_matrix_sf_a_row_indices( | |
row_indices = get_shuffle_matrix_a_row_indices(input_tensor, epilogue_tile_m) | ||
|
||
return row_indices | ||
|
||
|
||
def supports_backends( | ||
sricketts marked this conversation as resolved.
Show resolved
Hide resolved
|
||
backends, capabilities=None, anti_capabilities=None, capability_tensor_arg=None | ||
|
||
): | ||
def decorator(func): | ||
# Returns True if backend is supported; with capability, also checks if backend specifically supports it | ||
def is_backend_supported(backend, capability=None): | ||
if backend not in backends: | ||
return False | ||
if capability: | ||
# Anti-capabilities take precedence | ||
if anti_capabilities and backend in anti_capabilities: | ||
if capability in anti_capabilities[backend]: | ||
return False | ||
# Capabilities allow-list | ||
if capabilities and backend in capabilities: | ||
return capability in capabilities[backend] | ||
return True | ||
|
||
# Returns True if any backend supports this capability | ||
def is_supported(capability): | ||
for backend in backends: | ||
if capabilities and backend in capabilities: | ||
if capability in capabilities[backend]: | ||
return True | ||
elif anti_capabilities and backend in anti_capabilities: | ||
if capability in anti_capabilities[backend]: | ||
return False | ||
else: | ||
return True | ||
return False | ||
|
||
def wrapper(*args, **kwargs): | ||
backend = kwargs.get("backend") | ||
capability = None | ||
if capability_tensor_arg: | ||
tensor = kwargs.get(capability_tensor_arg) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is there a reason why we need |
||
# When it wasn't provided as a keyword argument, try to get it from the arguments | ||
if tensor is None: | ||
params = list(inspect.signature(func).parameters) | ||
idx = params.index(capability_tensor_arg) | ||
tensor = args[idx] | ||
if tensor is None: | ||
raise ValueError("Invalid tensor on capability support check") | ||
major, minor = get_compute_capability(tensor.device) | ||
capability = f"{major * 10 + minor}" | ||
if not is_backend_supported(backend, capability): | ||
extra = f" with capability {capability}" if capability else "" | ||
raise BackendSupportedError( | ||
f"{func.__name__} does not support backend '{backend}'{extra}" | ||
) | ||
return func(*args, **kwargs) | ||
|
||
wrapper.is_supported = is_supported | ||
wrapper.is_backend_supported = is_backend_supported | ||
wrapper.__name__ = func.__name__ | ||
wrapper.__doc__ = func.__doc__ | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can use functools.wraps for more robust wrapping and standardized interface. |
||
return wrapper | ||
|
||
return decorator |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this redundant with the declaration of the
backend
parameter?backend: Literal["cudnn", "trtllm", "cutlass"] = "cudnn",