Skip to content
3 changes: 3 additions & 0 deletions src/natten/_environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,11 @@
#
#################################################################################################


from natten._libnatten import HAS_LIBNATTEN # noqa: F401
from natten.utils.environment import (
_IS_CUDA_AVAILABLE,
_IS_XPU_AVAILABLE,
_IS_TORCH_COMPILE_SUPPORTED,
_TORCH_VERSION,
parse_env_flag,
Expand Down Expand Up @@ -52,6 +54,7 @@
__all__ = [
"HAS_LIBNATTEN",
"_IS_CUDA_AVAILABLE",
"_IS_XPU_AVAILABLE",
"_IS_TORCH_COMPILE_SUPPORTED",
"DISABLE_TQDM",
"_RUN_FLEX_TESTS",
Expand Down
14 changes: 11 additions & 3 deletions src/natten/backends/configs/checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
from ..._libnatten import HAS_LIBNATTEN
from ...context import is_flex_compile_allowed, is_flex_compile_backprop_allowed
from ...utils.checks import fmha_tensor_checks, log_or_raise_error, na_tensor_checks
from ...utils.device import get_device_cc, is_cpu, is_cuda, is_rocm
from ...utils.device import get_device_cc, is_cpu, is_cuda, is_rocm, is_xpu
from ...utils.dtype import is_fp8


Expand Down Expand Up @@ -597,6 +597,14 @@ def can_run_flex_attention(
target_fn("Can't run NATTEN with Flex Attention with torch < 2.7.")
return False

# XPU requires PyTorch 2.9+
if is_xpu(query.device) and _TORCH_VERSION < [2, 9]:
target_fn(
f"Can't run Flex Attention on XPU; requires PyTorch >= 2.9, "
f"got {_TORCH_VERSION[0]}.{_TORCH_VERSION[1]}."
)
return False

if torch_compile and not _FLEX_COMPILE_SUPPORTED:
target_fn("Can't run NATTEN with Flex Attention (compiled).)")
return False
Expand Down Expand Up @@ -657,9 +665,9 @@ def can_run_flex_attention(
return False

if not is_cuda(query.device):
if not is_cpu(query.device) and not is_rocm(query.device):
if not is_cpu(query.device) and not is_rocm(query.device) and not is_xpu(query.device):
target_fn(
"Can't run Flex Attention; tensor is not on a CUDA, ROCm, or CPU device: "
"Can't run Flex Attention; tensor is not on a CUDA, ROCm, XPU, or CPU device: "
f"{query.device.type}"
)

Expand Down
2 changes: 2 additions & 0 deletions src/natten/utils/device.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ def is_cuda(device: torch.device) -> bool:
def is_rocm(device: torch.device) -> bool:
return torch.cuda.is_available() and torch.version.hip and device.type == "cuda" # type: ignore

def is_xpu(device: torch.device) -> bool:
    """Return True iff *device* is an Intel XPU device usable by this torch build.

    Mirrors the ``hasattr(torch, "xpu")`` guard used in ``utils/environment.py``
    so older PyTorch builds without the ``torch.xpu`` module don't raise an
    ``AttributeError``. ``torch.version.xpu`` is read with ``getattr`` for the
    same reason, and the whole expression is wrapped in ``bool()`` so the
    function actually returns a ``bool`` (the original could return the
    truthy ``torch.version.xpu`` string), matching the annotated return type.
    """
    return bool(
        hasattr(torch, "xpu")
        and torch.xpu.is_available()
        and getattr(torch.version, "xpu", None)
        and device.type == "xpu"
    )

def is_cpu(device: torch.device) -> bool:
    """Check whether the given torch device refers to the host CPU."""
    device_type = device.type
    return device_type == "cpu"
Expand Down
1 change: 1 addition & 0 deletions src/natten/utils/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ def parse_env_str(env_var: str, default: str) -> str:


_IS_CUDA_AVAILABLE = torch.cuda.is_available()
_IS_XPU_AVAILABLE = torch.xpu.is_available() if hasattr(torch, "xpu") else False

_TORCH_VERSION = [int(x) for x in torch.__version__.split(".")[:2]]

Expand Down
16 changes: 11 additions & 5 deletions src/natten/utils/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,16 +23,17 @@

import torch

from .._environment import _IS_CUDA_AVAILABLE, _RUN_EXTENDED_TESTS, HAS_LIBNATTEN
from .._environment import _IS_CUDA_AVAILABLE, _IS_XPU_AVAILABLE, _RUN_EXTENDED_TESTS, HAS_LIBNATTEN

from ..backends.flex import _FLEX_COMPILE_SUPPORTED, _FLEX_SUPPORTED
from .device import get_device_cc, is_cuda
from .device import get_device_cc, is_cuda, is_xpu


def skip_if_libnatten_is_not_supported():
def decorator(f):
def wrapper(self, *args, **kwargs):
if not _IS_CUDA_AVAILABLE:
self.skipTest("CUDA is not available.")
if not _IS_CUDA_AVAILABLE and not _IS_XPU_AVAILABLE:
self.skipTest("CUDA or XPU is not available.")
elif not HAS_LIBNATTEN:
self.skipTest("Libnatten is not available.")
else:
Expand All @@ -59,7 +60,9 @@ def wrapper(self, *args, **kwargs):
def skip_if_flex_is_not_supported():
def decorator(f):
def wrapper(self, *args, **kwargs):
    # Skip when the Flex Attention backend is unavailable outright, or when
    # a non-XPU accelerator is older than SM70; otherwise run the test.
    if not _FLEX_SUPPORTED:
        self.skipTest("Flex backend is not supported.")
    elif not _IS_XPU_AVAILABLE and get_device_cc() < 70:
        # Distinct message from the branch above so a skipped run reports
        # *why* it was skipped (old GPU vs. backend missing entirely).
        self.skipTest("Flex backend requires device compute capability >= 7.0.")
    else:
        return f(self, *args, **kwargs)
Expand Down Expand Up @@ -145,5 +148,8 @@ def supports_bfloat16(device: torch.device) -> bool:

return True

if is_xpu(device):
return True

# TODO:
return False
Loading