diff --git a/src/natten/_environment.py b/src/natten/_environment.py index 2477b876..b06e88b3 100644 --- a/src/natten/_environment.py +++ b/src/natten/_environment.py @@ -21,9 +21,11 @@ # ################################################################################################# + from natten._libnatten import HAS_LIBNATTEN # noqa: F401 from natten.utils.environment import ( _IS_CUDA_AVAILABLE, + _IS_XPU_AVAILABLE, _IS_TORCH_COMPILE_SUPPORTED, _TORCH_VERSION, parse_env_flag, @@ -52,6 +54,7 @@ __all__ = [ "HAS_LIBNATTEN", "_IS_CUDA_AVAILABLE", + "_IS_XPU_AVAILABLE", "_IS_TORCH_COMPILE_SUPPORTED", "DISABLE_TQDM", "_RUN_FLEX_TESTS", diff --git a/src/natten/backends/configs/checks.py b/src/natten/backends/configs/checks.py index e26b59c7..a8970e8c 100644 --- a/src/natten/backends/configs/checks.py +++ b/src/natten/backends/configs/checks.py @@ -30,7 +30,7 @@ from ..._libnatten import HAS_LIBNATTEN from ...context import is_flex_compile_allowed, is_flex_compile_backprop_allowed from ...utils.checks import fmha_tensor_checks, log_or_raise_error, na_tensor_checks -from ...utils.device import get_device_cc, is_cpu, is_cuda, is_rocm +from ...utils.device import get_device_cc, is_cpu, is_cuda, is_rocm, is_xpu from ...utils.dtype import is_fp8 @@ -597,6 +597,14 @@ def can_run_flex_attention( target_fn("Can't run NATTEN with Flex Attention with torch < 2.7.") return False + # XPU requires PyTorch 2.9+ + if is_xpu(query.device) and _TORCH_VERSION < [2, 9]: + target_fn( + f"Can't run Flex Attention on XPU; requires PyTorch >= 2.9, " + f"got {_TORCH_VERSION[0]}.{_TORCH_VERSION[1]}." + ) + return False + if torch_compile and not _FLEX_COMPILE_SUPPORTED: target_fn("Can't run NATTEN with Flex Attention (compiled).)") return False @@ -657,9 +665,9 @@ def can_run_flex_attention( return False if not is_cuda(query.device): - if not is_cpu(query.device) and not is_rocm(query.device): + if not is_cpu(query.device) and not is_rocm(query.device) and not is_xpu(query.device): target_fn( - "Can't run Flex Attention; tensor is not on a CUDA, ROCm, or CPU device: " + "Can't run Flex Attention; tensor is not on a CUDA, ROCm, XPU, or CPU device: " f"{query.device.type}" ) diff --git a/src/natten/utils/device.py b/src/natten/utils/device.py index a5e17124..b5e4b674 100644 --- a/src/natten/utils/device.py +++ b/src/natten/utils/device.py @@ -33,6 +33,8 @@ def is_cuda(device: torch.device) -> bool: def is_rocm(device: torch.device) -> bool: return torch.cuda.is_available() and torch.version.hip and device.type == "cuda" # type: ignore +def is_xpu(device: torch.device) -> bool: + return torch.xpu.is_available() and torch.version.xpu and device.type == "xpu" def is_cpu(device: torch.device) -> bool: return device.type == "cpu" diff --git a/src/natten/utils/environment.py b/src/natten/utils/environment.py index 60bba27e..c2653b82 100644 --- a/src/natten/utils/environment.py +++ b/src/natten/utils/environment.py @@ -55,6 +55,7 @@ def parse_env_str(env_var: str, default: str) -> str: _IS_CUDA_AVAILABLE = torch.cuda.is_available() +_IS_XPU_AVAILABLE = torch.xpu.is_available() if hasattr(torch, "xpu") else False _TORCH_VERSION = [int(x) for x in torch.__version__.split(".")[:2]] diff --git a/src/natten/utils/testing.py b/src/natten/utils/testing.py index 2f17dd6e..8eb99e1e 100644 --- a/src/natten/utils/testing.py +++ b/src/natten/utils/testing.py @@ -23,16 +23,17 @@ import torch -from .._environment import _IS_CUDA_AVAILABLE, _RUN_EXTENDED_TESTS, HAS_LIBNATTEN +from .._environment import _IS_CUDA_AVAILABLE, _IS_XPU_AVAILABLE, _RUN_EXTENDED_TESTS, HAS_LIBNATTEN + from ..backends.flex import _FLEX_COMPILE_SUPPORTED, _FLEX_SUPPORTED -from .device import get_device_cc, is_cuda +from .device import get_device_cc, is_cuda, is_xpu def skip_if_libnatten_is_not_supported(): def decorator(f): def wrapper(self, *args, **kwargs): - if not _IS_CUDA_AVAILABLE: - self.skipTest("CUDA is not available.") + if not _IS_CUDA_AVAILABLE and not _IS_XPU_AVAILABLE: + self.skipTest("CUDA or XPU is not available.") elif not HAS_LIBNATTEN: self.skipTest("Libnatten is not available.") else: @@ -59,7 +60,9 @@ def wrapper(self, *args, **kwargs): def skip_if_flex_is_not_supported(): def decorator(f): def wrapper(self, *args, **kwargs): - if not _FLEX_SUPPORTED or get_device_cc() < 70: + if not _FLEX_SUPPORTED: + self.skipTest("Flex backend is not supported.") + elif not _IS_XPU_AVAILABLE and get_device_cc() < 70: self.skipTest("Flex backend is not supported.") else: return f(self, *args, **kwargs) @@ -145,5 +148,8 @@ def supports_bfloat16(device: torch.device) -> bool: return True + if is_xpu(device): + return True + # TODO: return False diff --git a/tests/test_flex_xpu.py b/tests/test_flex_xpu.py new file mode 100644 index 00000000..23ba1017 --- /dev/null +++ b/tests/test_flex_xpu.py @@ -0,0 +1,618 @@ +################################################################################################# +# Copyright (c) 2022-2025 Ali Hassani. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +# +################################################################################################# + +import math +import os +import random +import time +import unittest +from itertools import product + +import torch + +from natten import allow_flex_compile +from natten._environment import _RUN_FLEX_TESTS as RUN_FLEX_TESTS +from natten.backends.configs.flex import FLEX_FORWARD_TILE_SHAPES +from natten.utils import log +from natten.utils.testing import ( + skip_if_flex_compile_is_not_supported, + skip_if_flex_is_not_supported, + skip_if_not_running_extended_tests, + supports_bfloat16, + supports_float16, +) + +from .utils import NattenBackendTester, reset_torch_compile + + +logger = log.get_logger(__name__) + + +# TODO: enable when Flex is stable / check with new PT releases +ENABLE_FLEX_COMPILE_TESTS = False +ENABLE_FLEX_COMPILE_BACKPROP_TESTS = False + + +def _reset_everything(): + # NOTE: It is important to ensure determinism in torch GEMMs since + # we don't write our own. + # PT's caching allocator should also be turned off in unit tests for + # when we run memcheck. + torch.use_deterministic_algorithms(True) + torch.manual_seed(42) + torch.xpu.empty_cache() + + reset_torch_compile(1024) + + allow_flex_compile( + ENABLE_FLEX_COMPILE_TESTS, backprop=ENABLE_FLEX_COMPILE_BACKPROP_TESTS + ) + + +@unittest.skipIf(not RUN_FLEX_TESTS, "Flex tests are disabled by environment variable") +class FlexBackendTest(unittest.TestCase): + def setUp(self): + _reset_everything() + + def tearDown(self): + _reset_everything() + + def _test_all_dtypes_against_flex_cpu_2x_fna( + self, + batch, + heads, + head_dim, + input_shape, + kernel_size, + stride, + dilation, + is_causal=None, + additional_kv_length=0, + torch_compile=False, + constrain_torch_compile_cache=True, + max_runs=None, + ): + torch.set_default_device("xpu") + assert isinstance(input_shape, tuple) + na_dim = len(input_shape) + assert na_dim in [1, 2, 3], "Only supports NA1D, 2D, 3D." + + test_backprop = ENABLE_FLEX_COMPILE_BACKPROP_TESTS if torch_compile else True + + if additional_kv_length > 0: + + reset_torch_compile(1) + + tester = NattenBackendTester( + batch=batch, + heads=heads, + head_dim=head_dim, + input_shape=input_shape, + kernel_size=kernel_size, + stride=stride, + dilation=dilation, + is_causal=is_causal, + additional_kv_length=additional_kv_length, + test_backprop=test_backprop, + reference_backend="flex-fna", + reference_fmha_backend="flex-fmha", + dtype=torch.float32, + device="cpu", + target_device="xpu", + ) + + # TODO: write note on why backprop eps is different when additional_kv_length > 0 + run_idx = 0 + no_token_permute_config = (None, None) + configs = FLEX_FORWARD_TILE_SHAPES[na_dim] + [no_token_permute_config] + for q_tile_shape, kv_tile_shape in configs: + if constrain_torch_compile_cache: + if torch_compile or additional_kv_length > 0: + reset_torch_compile(4 if additional_kv_length > 0 else 2) + else: + reset_torch_compile(0) + + assert supports_float16( + torch.get_default_device() + ), "Flex only supports SM70 and above, and it should have FP16!" + + tester.test( + eps=(1e-2, 1e-2 if additional_kv_length == 0 else 3e-1), + dtype=torch.float16, + target_backend="flex-fna", + target_fmha_backend="flex-fmha", + q_tile_shape=q_tile_shape, + kv_tile_shape=kv_tile_shape, + torch_compile=torch_compile, + ) + run_idx += 1 + if max_runs is not None and run_idx > max_runs: + return + + if supports_bfloat16(torch.get_default_device()): + tester.test( + eps=(1e-1, 1e-1 if additional_kv_length == 0 else 5e-1), + dtype=torch.bfloat16, + target_backend="flex-fna", + target_fmha_backend="flex-fmha", + q_tile_shape=q_tile_shape, + kv_tile_shape=kv_tile_shape, + torch_compile=torch_compile, + ) + run_idx += 1 + if max_runs is not None and run_idx > max_runs: + return + + @skip_if_flex_compile_is_not_supported() + def test_0_compile_caching(self): + if not ENABLE_FLEX_COMPILE_TESTS: + self.skipTest("Flex compile tests have been disabled.") + + # Verify torch compile caching works by rerunning use cases a second time, + # and checking runtimes. + # Might be a little flaky... + # Torch compile autotuner might also have a separate cache of its own... + + def run_tests(problem_sizes, max_runs_per_use_case): + for ( + batch, + heads, + head_dim, + input_shape, + kernel_size, + stride, + dilation, + ) in problem_sizes: + self._test_all_dtypes_against_flex_cpu_2x_fna( + batch=batch, + heads=heads, + head_dim=head_dim, + input_shape=input_shape, + kernel_size=kernel_size, + stride=stride, + dilation=dilation, + is_causal=False, + additional_kv_length=0, + torch_compile=True, + constrain_torch_compile_cache=False, + max_runs=max_runs_per_use_case, + ) + + max_runs_per_use_case = 10 + problem_sizes = [ + (1, 1, 32, (128,), (3,), (2,), (5)), + (1, 1, 32, (128,), (3,), (1,), (1)), + (1, 1, 32, (4, 4, 4), (2, 2, 2), (1, 1, 1), (1, 1, 1)), + (1, 1, 128, (16, 16), (4, 4), (2, 3), (1, 1)), + (4, 3, 32, (32, 32), (31, 31), (15, 15), (1, 1)), + (1, 1, 32, (32,), (32,), (1,), (1)), + ] + + reset_torch_compile(len(problem_sizes) * max_runs_per_use_case) + MAX_RUNTIME_S = 5 + # I.e. second run should take at most 10% of the first run + EXPECTED_PCT_OF_RUNTIME = 0.1 + + logger.debug("Testing torch compile cache.") + start_time = time.time() + run_tests(problem_sizes, max_runs_per_use_case) + elapsed_first_run = time.time() - start_time + logger.debug( + f"First run of {len(problem_sizes)} use cases finished in {elapsed_first_run:.1f} seconds." + ) + + logger.debug("Second run.") + start_time = time.time() + run_tests(problem_sizes, max_runs_per_use_case) + elapsed_second_run = time.time() - start_time + logger.debug( + f"Second run of {len(problem_sizes)} use cases finished in {elapsed_second_run:.1f} seconds." + ) + + assert elapsed_second_run <= min( + elapsed_first_run * EXPECTED_PCT_OF_RUNTIME, MAX_RUNTIME_S + ) + + @skip_if_flex_is_not_supported() + def test_1d_against_flex_cpu_2x(self): + problem_sizes = [ + (1, 1, 8, (128,), (3,), (2,), (5)), + (1, 1, 16, (128,), (8,), (7,), (5)), + (1, 1, 32, (125,), (3,), (1,), (1)), + (1, 2, 8, (125,), (15,), (1,), (1)), + (1, 1, 64, (256,), (3,), (2,), (10)), + ] + for ( + batch, + heads, + head_dim, + input_shape, + kernel_size, + stride, + dilation, + ) in problem_sizes: + for additional_kv_length in [0, 64]: + for causal in [True, False]: + is_causal = (causal,) + self._test_all_dtypes_against_flex_cpu_2x_fna( + batch=batch, + heads=heads, + head_dim=head_dim, + input_shape=input_shape, + kernel_size=kernel_size, + stride=stride, + dilation=dilation, + is_causal=is_causal, + additional_kv_length=additional_kv_length, + torch_compile=False, + ) + + @skip_if_flex_compile_is_not_supported() + def test_1d_against_flex_cpu_2x_compiled(self): + if not ENABLE_FLEX_COMPILE_TESTS: + self.skipTest("Flex compile tests have been disabled.") + + problem_sizes = [ + (1, 2, 32, (128,), (15,), (1,), (1)), + (1, 1, 32, (128,), (8,), (7,), (5)), + (2, 4, 128, (128,), (63,), (31,), (1)), + (4, 3, 128, (256,), (255,), (82,), (1)), + (1, 1, 128, (128,), (61,), (33,), (1)), + (1, 1, 32, (256,), (3,), (2,), (10)), + ] + for ( + batch, + heads, + head_dim, + input_shape, + kernel_size, + stride, + dilation, + ) in problem_sizes: + for additional_kv_length in [0, 64]: + for causal in [True, False]: + is_causal = (causal,) + self._test_all_dtypes_against_flex_cpu_2x_fna( + batch=batch, + heads=heads, + head_dim=head_dim, + input_shape=input_shape, + kernel_size=kernel_size, + stride=stride, + dilation=dilation, + is_causal=is_causal, + additional_kv_length=additional_kv_length, + torch_compile=True, + ) + + @skip_if_flex_is_not_supported() + def test_2d_against_flex_cpu_2x(self): + problem_sizes = [ + (1, 1, 8, (84, 69), (7, 20), (1, 6), (5, 1)), + (1, 1, 32, (19, 29), (8, 8), (1, 1), (2, 3)), + (1, 1, 16, (128, 128), (19, 24), (2, 3), (2, 2)), + (1, 1, 64, (56, 56), (17, 4), (2, 1), (3, 2)), + (2, 2, 64, (32, 64), (25, 31), (10, 20), (1, 2)), + (2, 4, 64, (64, 128), (21, 29), (10, 12), (3, 4)), + (4, 3, 128, (56, 56), (7, 7), (1, 1), (2, 4)), + ] + for ( + batch, + heads, + head_dim, + input_shape, + kernel_size, + stride, + dilation, + ) in problem_sizes: + for additional_kv_length in [0, 64]: + for causal_x, causal_y in product([True, False], [True, False]): + is_causal = (causal_x, causal_y) + self._test_all_dtypes_against_flex_cpu_2x_fna( + batch=batch, + heads=heads, + head_dim=head_dim, + input_shape=input_shape, + kernel_size=kernel_size, + stride=stride, + dilation=dilation, + is_causal=is_causal, + additional_kv_length=additional_kv_length, + torch_compile=False, + ) + + @skip_if_not_running_extended_tests() + @skip_if_flex_is_not_supported() + def test_2d_against_flex_cpu_2x_extended(self): + problem_sizes = [ + (1, 1, 32, (84, 69), (7, 68), (1, 1), (5, 1)), + (1, 1, 32, (84, 69), (7, 23), (1, 1), (5, 1)), + (1, 1, 32, (128, 128), (4, 4), (1, 1), (2, 2)), + (1, 1, 32, (128, 128), (4, 4), (1, 1), (1, 1)), + (1, 1, 64, (56, 56), (17, 4), (2, 1), (1, 2)), + (1, 1, 64, (128, 128), (93, 78), (2, 2), (1, 1)), + (1, 1, 128, (16, 16), (3, 3), (2, 2), (1, 1)), + (1, 1, 128, (16, 16), (4, 4), (2, 3), (1, 1)), + (4, 3, 32, (32, 32), (31, 31), (15, 15), (1, 1)), + (2, 2, 64, (32, 64), (26, 30), (1, 1), (1, 2)), + (2, 4, 64, (64, 128), (55, 101), (1, 1), (1, 1)), + (4, 3, 128, (28, 46), (11, 13), (1, 1), (1, 1)), + ] + for ( + batch, + heads, + head_dim, + input_shape, + kernel_size, + stride, + dilation, + ) in problem_sizes: + for additional_kv_length in [0, 64]: + for causal_x, causal_y in product([True, False], [True, False]): + is_causal = (causal_x, causal_y) + self._test_all_dtypes_against_flex_cpu_2x_fna( + batch=batch, + heads=heads, + head_dim=head_dim, + input_shape=input_shape, + kernel_size=kernel_size, + stride=stride, + dilation=dilation, + is_causal=is_causal, + additional_kv_length=additional_kv_length, + torch_compile=False, + ) + + @skip_if_flex_compile_is_not_supported() + def test_2d_against_flex_cpu_2x_compiled(self): + if not ENABLE_FLEX_COMPILE_TESTS: + self.skipTest("Flex compile tests have been disabled.") + + problem_sizes = [ + (1, 1, 32, (84, 69), (7, 68), (1, 1), (5, 1)), + (1, 1, 128, (19, 29), (8, 8), (1, 1), (2, 3)), + (1, 1, 32, (128, 128), (4, 4), (1, 1), (2, 2)), + (1, 1, 64, (56, 56), (17, 4), (2, 1), (3, 2)), + (1, 1, 64, (128, 128), (93, 78), (2, 2), (1, 1)), + (1, 1, 128, (16, 16), (3, 3), (2, 2), (1, 1)), + (2, 2, 64, (32, 64), (25, 31), (10, 20), (1, 2)), + (2, 4, 64, (64, 128), (55, 101), (1, 1), (1, 1)), + (2, 4, 64, (64, 128), (21, 29), (10, 12), (3, 4)), + (4, 3, 128, (56, 56), (7, 7), (1, 1), (2, 4)), + ] + for ( + batch, + heads, + head_dim, + input_shape, + kernel_size, + stride, + dilation, + ) in problem_sizes: + for causal_x, causal_y in product([True, False], [True, False]): + is_causal = (causal_x, causal_y) + self._test_all_dtypes_against_flex_cpu_2x_fna( + batch=batch, + heads=heads, + head_dim=head_dim, + input_shape=input_shape, + kernel_size=kernel_size, + stride=stride, + dilation=dilation, + is_causal=is_causal, + additional_kv_length=0, + torch_compile=True, + ) + + @skip_if_flex_compile_is_not_supported() + def test_3d_against_flex_cpu_2x_compiled(self): + if not ENABLE_FLEX_COMPILE_TESTS: + self.skipTest("Flex compile tests have been disabled.") + + problem_sizes = [ + (1, 1, 128, (24, 44, 80), (24, 12, 24), (1, 4, 8), (1, 1, 1)), + (1, 1, 128, (20, 40, 40), (20, 12, 16), (1, 4, 16), (1, 1, 1)), + (1, 1, 64, (18, 37, 12), (14, 16, 12), (12, 8, 1), (1, 2, 1)), + (1, 1, 128, (8, 8, 4), (3, 4, 3), (1, 1, 1), (1, 1, 1)), + (1, 2, 128, (8, 8, 12), (5, 8, 11), (2, 3, 4), (1, 1, 1)), + (4, 8, 64, (32, 10, 10), (7, 3, 3), (5, 1, 1), (1, 2, 3)), + (1, 4, 32, (8, 8, 16), (3, 3, 3), (2, 1, 2), (2, 2, 4)), + (2, 2, 32, (8, 8, 10), (3, 4, 3), (3, 4, 1), (1, 1, 1)), + (1, 12, 64, (32, 8, 8), (7, 5, 5), (2, 1, 3), (2, 1, 1)), + ] + for ( + batch, + heads, + head_dim, + input_shape, + kernel_size, + stride, + dilation, + ) in problem_sizes: + for causal_x, causal_y, causal_z in product( + [True, False], [True, False], [True, False] + ): + is_causal = (causal_x, causal_y, causal_z) + self._test_all_dtypes_against_flex_cpu_2x_fna( + batch=batch, + heads=heads, + head_dim=head_dim, + input_shape=input_shape, + kernel_size=kernel_size, + stride=stride, + dilation=dilation, + is_causal=is_causal, + additional_kv_length=0, + torch_compile=True, + ) + + + @skip_if_flex_is_not_supported() + def test_3d_against_flex_cpu_2x(self): + problem_sizes = [ + (1, 1, 8, (13, 11, 9), (3, 4, 3), (2, 3, 3), (3, 2, 2)), + (1, 1, 16, (13, 11, 9), (3, 4, 3), (1, 1, 1), (3, 2, 2)), + (1, 2, 128, (8, 8, 12), (5, 8, 11), (2, 3, 4), (1, 1, 1)), + (1, 1, 64, (32, 10, 10), (7, 3, 3), (5, 1, 1), (1, 2, 3)), + (1, 4, 32, (8, 8, 16), (3, 3, 3), (2, 1, 2), (2, 2, 4)), + ] + for ( + batch, + heads, + head_dim, + input_shape, + kernel_size, + stride, + dilation, + ) in problem_sizes: + for additional_kv_length in [0, 64]: + for causal_x, causal_y, causal_z in product( + [True, False], [True, False], [True, False] + ): + is_causal = (causal_x, causal_y, causal_z) + self._test_all_dtypes_against_flex_cpu_2x_fna( + batch=batch, + heads=heads, + head_dim=head_dim, + input_shape=input_shape, + kernel_size=kernel_size, + stride=stride, + dilation=dilation, + is_causal=is_causal, + additional_kv_length=additional_kv_length, + torch_compile=False, + ) + + @skip_if_not_running_extended_tests() + @skip_if_flex_is_not_supported() + def test_3d_against_flex_cpu_2x_extended(self): + problem_sizes = [ + (1, 1, 32, (4, 4, 4), (2, 2, 2), (1, 1, 1), (1, 1, 1)), + (1, 2, 32, (4, 4, 4), (2, 2, 2), (1, 1, 1), (1, 1, 1)), + (1, 2, 32, (4, 4, 4), (2, 2, 2), (2, 2, 1), (1, 1, 1)), + (1, 1, 128, (8, 8, 4), (3, 4, 3), (1, 1, 1), (1, 1, 1)), + (1, 2, 128, (8, 8, 12), (5, 8, 11), (2, 3, 4), (1, 1, 1)), + (2, 2, 32, (8, 8, 10), (3, 4, 3), (3, 4, 1), (1, 1, 1)), + (1, 12, 64, (32, 8, 8), (7, 5, 5), (2, 1, 3), (2, 1, 1)), + (4, 8, 64, (32, 10, 10), (7, 3, 3), (5, 1, 1), (1, 2, 3)), + (1, 1, 64, (18, 37, 12), (14, 16, 12), (12, 8, 1), (1, 2, 1)), + ] + for ( + batch, + heads, + head_dim, + input_shape, + kernel_size, + stride, + dilation, + ) in problem_sizes: + for additional_kv_length in [0, 64]: + for causal_x, causal_y, causal_z in product( + [True, False], [True, False], [True, False] + ): + is_causal = (causal_x, causal_y, causal_z) + self._test_all_dtypes_against_flex_cpu_2x_fna( + batch=batch, + heads=heads, + head_dim=head_dim, + input_shape=input_shape, + kernel_size=kernel_size, + stride=stride, + dilation=dilation, + is_causal=is_causal, + additional_kv_length=additional_kv_length, + torch_compile=False, + ) + + def _test_rand_sweep_against_flex_cpu_2x(self, na_dim, torch_compile: bool = False): + random.seed(42) + + max_tests = 1000 + max_seqlen = 2**17 if torch_compile else 2**13 + max_kernel_size = None if torch_compile else 2**10 + + for i in range(max_tests): + batch = random.choice(range(1, 4)) + heads = random.choice(range(1, 4)) + + head_dim_choices = [32, 64, 128] if torch_compile else [8, 16, 32, 64, 128] + head_dim = random.choice(head_dim_choices) + + input_shape_ = [] + for j in range(na_dim): + input_shape_.append(random.choice(range(4, 97))) + + while math.prod(input_shape_) > max_seqlen: + dim_to_cut = random.choice(range(na_dim)) + input_shape_[dim_to_cut] = max(4, int(input_shape_[dim_to_cut] * 0.1)) + + input_shape = tuple(input_shape_) + + kernel_size_ = [random.choice(range(2, x + 1)) for x in input_shape] + if max_kernel_size is not None: + while math.prod(kernel_size_) > max_kernel_size: + dim_to_cut = random.choice(range(na_dim)) + kernel_size_[dim_to_cut] = max( + 2, int(kernel_size_[dim_to_cut] * 0.1) + ) + + kernel_size = tuple(kernel_size_) + stride = tuple(random.choice(range(1, k + 1)) for k in kernel_size) + dilation = tuple( + random.choice(range(1, x // k + 1)) + for x, k in zip(input_shape, kernel_size) + ) + is_causal = tuple(random.choice([False, True]) for _ in range(na_dim)) + + self._test_all_dtypes_against_flex_cpu_2x_fna( + batch=batch, + heads=heads, + head_dim=head_dim, + input_shape=input_shape, + kernel_size=kernel_size, + stride=stride, + dilation=dilation, + is_causal=is_causal, + additional_kv_length=0, + torch_compile=torch_compile, + ) + + @skip_if_not_running_extended_tests() + @skip_if_flex_is_not_supported() + def test_rand_sweep_1d_against_flex_cpu_2x(self): + self._test_rand_sweep_against_flex_cpu_2x(1) + + @skip_if_not_running_extended_tests() + @skip_if_flex_is_not_supported() + def test_rand_sweep_2d_against_flex_cpu_2x(self): + self._test_rand_sweep_against_flex_cpu_2x(2) + + @skip_if_not_running_extended_tests() + @skip_if_flex_is_not_supported() + def test_rand_sweep_3d_against_flex_cpu_2x(self): + self._test_rand_sweep_against_flex_cpu_2x(3) + + +if __name__ == "__main__": + torch.manual_seed(42) + unittest.main() diff --git a/tests/utils.py b/tests/utils.py index 32b36f34..d31ab6bb 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -35,6 +35,13 @@ logger = log.get_logger(__name__) +def _synchronize(device: str): + if device == "cuda": + torch.cuda.synchronize() + elif device == "xpu": + torch.xpu.synchronize() + + def reset_torch_compile(cache_size_limit): # Torch compile reset and sensible settings for unit testing logger.debug( @@ -65,6 +72,8 @@ def __init__( reference_fmha_backend: str, dtype: torch.dtype, head_dim_v: Optional[int] = None, + device: str = "cuda", + target_device: Optional[str] = None, heads_kv: Optional[int] = None, ): assert isinstance(input_shape, tuple) @@ -85,6 +94,7 @@ def __init__( self.test_backprop = test_backprop self.reference_backend = reference_backend self.reference_fmha_backend = reference_fmha_backend + self.target_device = target_device or device with torch.no_grad(): orig_dtype = dtype @@ -94,22 +104,22 @@ def __init__( q_ref, k_ref, v_ref, d_out_ref = ( torch.randn( (self.batch, *self.input_shape, self.heads, self.head_dim), - device="cuda", + device=device, dtype=dtype, ), torch.randn( (self.batch, *self.input_shape, self.heads_kv, self.head_dim), - device="cuda", + device=device, dtype=dtype, ), torch.randn( (self.batch, *self.input_shape, self.heads_kv, self.head_dim_v), - device="cuda", + device=device, dtype=dtype, ), torch.randn( (self.batch, *self.input_shape, self.heads, self.head_dim_v), - device="cuda", + device=device, dtype=dtype, ) * 0.05, @@ -138,7 +148,7 @@ def __init__( self.heads_kv, self.head_dim, ), - device="cuda", + device=device, dtype=dtype, ) additional_v_ref = torch.randn( @@ -148,7 +158,7 @@ def __init__( self.heads_kv, self.head_dim_v, ), - device="cuda", + device=device, dtype=dtype, ) @@ -162,7 +172,7 @@ def __init__( self.additional_v = additional_v_ref.clone() # Reference - torch.cuda.synchronize() + _synchronize(device) start_time = time.time() q_ref.requires_grad_(True) @@ -229,7 +239,7 @@ def __init__( self.d_additional_k_ref = additional_k_ref.grad.clone().float() self.d_additional_v_ref = additional_v_ref.grad.clone().float() - torch.cuda.synchronize() + _synchronize(device) reference_time = time.time() - start_time logger.debug( f"Reference ({reference_backend}/{reference_fmha_backend}) ran in {reference_time:.2f} seconds." @@ -285,10 +295,10 @@ def test( ) q, k, v, d_out = ( - self.q.clone().to(dtype), - self.k.clone().to(dtype), - self.v.clone().to(dtype), - self.d_out.clone().to(dtype), + self.q.clone().to(device=self.target_device, dtype=dtype), + self.k.clone().to(device=self.target_device, dtype=dtype), + self.v.clone().to(device=self.target_device, dtype=dtype), + self.d_out.clone().to(device=self.target_device, dtype=dtype), ) q.requires_grad_(test_backprop_safe) k.requires_grad_(test_backprop_safe) @@ -299,13 +309,13 @@ def test( if additional_kv_length > 0: assert self.additional_k is not None assert self.additional_v is not None - additional_k = self.additional_k.clone().to(dtype) - additional_v = self.additional_v.clone().to(dtype) + additional_k = self.additional_k.clone().to(device=self.target_device, dtype=dtype) + additional_v = self.additional_v.clone().to(device=self.target_device, dtype=dtype) additional_k = additional_k.requires_grad_(test_backprop_safe) additional_v = additional_v.requires_grad_(test_backprop_safe) - torch.cuda.synchronize() + _synchronize(self.target_device) start_time = time.time() out_: torch.Tensor = ( @@ -361,13 +371,13 @@ def test( else: eps_forward, eps_backward = eps, eps - torch.cuda.synchronize() + _synchronize(self.target_device) runtime = time.time() - start_time logger.debug( f"Backend ({target_backend}/{target_fmha_backend}) ran in {runtime:.2f} seconds." ) - torch.testing.assert_close(out, self.out_ref, atol=eps_forward, rtol=0) + torch.testing.assert_close(out.to(self.out_ref.device), self.out_ref, atol=eps_forward, rtol=0) if test_backprop_safe: if isinstance(eps_backward, tuple): @@ -377,15 +387,15 @@ def test( assert isinstance(eps_backward, float) eps_dq, eps_dk, eps_dv = eps_backward, eps_backward, eps_backward - torch.testing.assert_close(dq, self.dq_ref, atol=eps_dq, rtol=0) - torch.testing.assert_close(dk, self.dk_ref, atol=eps_dk, rtol=0) - torch.testing.assert_close(dv, self.dv_ref, atol=eps_dv, rtol=0) + torch.testing.assert_close(dq.to(self.dq_ref.device), self.dq_ref, atol=eps_dq, rtol=0) + torch.testing.assert_close(dk.to(self.dk_ref.device), self.dk_ref, atol=eps_dk, rtol=0) + torch.testing.assert_close(dv.to(self.dv_ref.device), self.dv_ref, atol=eps_dv, rtol=0) if additional_kv_length > 0: torch.testing.assert_close( - d_additional_k, self.d_additional_k_ref, atol=eps_dk, rtol=0 + d_additional_k.to(self.d_additional_k_ref.device), self.d_additional_k_ref, atol=eps_dk, rtol=0 ) torch.testing.assert_close( - d_additional_v, self.d_additional_v_ref, atol=eps_dv, rtol=0 + d_additional_v.to(self.d_additional_v_ref.device), self.d_additional_v_ref, atol=eps_dv, rtol=0 )