diff --git a/kernels/src/kernels/utils.py b/kernels/src/kernels/utils.py index 56c166f0..73755d17 100644 --- a/kernels/src/kernels/utils.py +++ b/kernels/src/kernels/utils.py @@ -17,7 +17,7 @@ from kernels._system import glibc_version from kernels._versions import select_revision_or_version -from kernels.backends import _backend +from kernels.backends import _backend, _select_backend from kernels.compat import has_torch, has_tvm_ffi from kernels.deps import validate_dependencies from kernels.lockfile import KernelLock, VariantLock @@ -174,6 +174,12 @@ def _find_kernel_in_repo_path( assert variant_path is not None + exact_backend_variant = _select_backend(backend).variant + if exact_backend_variant not in variant: + logging.info( + f"Exact build variant matching {exact_backend_variant} not found, resolved to {variant}" + ) + if variant_locks is not None: variant_lock = variant_locks.get(variant) if variant_lock is None: diff --git a/kernels/src/kernels/variants.py b/kernels/src/kernels/variants.py index 33f6ee76..24d95541 100644 --- a/kernels/src/kernels/variants.py +++ b/kernels/src/kernels/variants.py @@ -3,7 +3,7 @@ from packaging.version import parse -from kernels.backends import _select_backend +from kernels.backends import CUDA, Backend, _select_backend from kernels.compat import has_torch, has_tvm_ffi BUILD_VARIANT_REGEX = re.compile( @@ -11,13 +11,22 @@ ) +def _compatible_backend_variants(backend: Backend) -> list[str]: + if isinstance(backend, CUDA): + return [ + f"cu{backend.version.major}{minor}" + for minor in range(backend.version.minor, -1, -1) + ] + return [backend.variant] + + def _torch_build_variant(backend: str | None) -> list[str]: if not has_torch: return [] selected_backend = _select_backend(backend) - backend_variant = selected_backend.variant + backend_variants = _compatible_backend_variants(selected_backend) import torch @@ -28,17 +37,20 @@ def _torch_build_variant(backend: str | None) -> list[str]: if os == "darwin": cpu = "aarch64" if cpu 
== "arm64" else cpu return [ - f"torch{torch_version.major}{torch_version.minor}-{backend_variant}-{cpu}-{os}" + f"torch{torch_version.major}{torch_version.minor}-{v}-{cpu}-{os}" + for v in backend_variants ] elif os == "windows": cpu = "x86_64" if cpu == "AMD64" else cpu return [ - f"torch{torch_version.major}{torch_version.minor}-{backend_variant}-{cpu}-{os}" + f"torch{torch_version.major}{torch_version.minor}-{v}-{cpu}-{os}" + for v in backend_variants ] cxxabi = "cxx11" if torch.compiled_with_cxx11_abi() else "cxx98" return [ - f"torch{torch_version.major}{torch_version.minor}-{cxxabi}-{backend_variant}-{cpu}-{os}" + f"torch{torch_version.major}{torch_version.minor}-{cxxabi}-{v}-{cpu}-{os}" + for v in backend_variants ] @@ -48,7 +60,7 @@ def _tvm_ffi_build_variant(backend: str | None) -> list[str]: selected_backend = _select_backend(backend) - backend_variant = selected_backend.variant + backend_variants = _compatible_backend_variants(selected_backend) import tvm_ffi @@ -57,7 +69,8 @@ def _tvm_ffi_build_variant(backend: str | None) -> list[str]: os = platform.system().lower() return [ - f"tvm-ffi{tvm_ffi_version.major}{tvm_ffi_version.minor}-{backend_variant}-{cpu}-{os}" + f"tvm-ffi{tvm_ffi_version.major}{tvm_ffi_version.minor}-{v}-{cpu}-{os}" + for v in backend_variants ] diff --git a/kernels/tests/test_variants.py b/kernels/tests/test_variants.py new file mode 100644 index 00000000..f4667391 --- /dev/null +++ b/kernels/tests/test_variants.py @@ -0,0 +1,131 @@ +import logging +from unittest.mock import patch + +from packaging.version import Version + +from kernels.backends import CANN, CPU, CUDA, XPU, Metal, Neuron, ROCm +from kernels.utils import _find_kernel_in_repo_path +from kernels.variants import _compatible_backend_variants + + +class TestCompatibleBackendVariants: + # CUDA is the only backend with minor-version fallback for now. 
+ + def test_cuda_single_variant_minor_zero(self): + backend = CUDA(version=Version("12.0")) + assert _compatible_backend_variants(backend) == ["cu120"] + + def test_cuda_returns_descending_minor_versions(self): + backend = CUDA(version=Version("12.9")) + variants = _compatible_backend_variants(backend) + assert variants == [f"cu12{i}" for i in range(9, -1, -1)] + + def test_cuda_first_variant_is_exact_match(self): + backend = CUDA(version=Version("12.6")) + variants = _compatible_backend_variants(backend) + assert variants[0] == "cu126" + + def test_cuda_last_variant_is_major_minor_zero(self): + backend = CUDA(version=Version("12.6")) + variants = _compatible_backend_variants(backend) + assert variants[-1] == "cu120" + + def test_cuda_resolves_to_largest_available_minor(self): + # System has CUDA 12.9; available builds: 12.6, 12.8, 13.0. + # Expected: resolve to 12.8 (largest available minor version <= 9 within major 12). + backend = CUDA(version=Version("12.9")) + candidates = _compatible_backend_variants(backend) + available = {"cu126", "cu128", "cu130"} + resolved = next((v for v in candidates if v in available), None) + assert resolved == "cu128" + + def test_cuda_falls_back_when_exact_not_available(self): + # System has CUDA 12.9; only 12.6 available for major 12. + backend = CUDA(version=Version("12.9")) + candidates = _compatible_backend_variants(backend) + available = {"cu126"} + resolved = next((v for v in candidates if v in available), None) + assert resolved == "cu126" + + def test_cuda_no_match_when_only_higher_minor_available(self): + # System has CUDA 12.6; only 12.8 available — should not match. + backend = CUDA(version=Version("12.6")) + candidates = _compatible_backend_variants(backend) + available = {"cu128"} + resolved = next((v for v in candidates if v in available), None) + assert resolved is None + + def test_cuda_no_match_for_different_major(self): + # System has CUDA 12.9; only 13.x builds available — should not match. 
+ backend = CUDA(version=Version("12.9")) + candidates = _compatible_backend_variants(backend) + available = {"cu130", "cu131"} + resolved = next((v for v in candidates if v in available), None) + assert resolved is None + + def test_cuda_major_version_preserved(self): + backend = CUDA(version=Version("11.8")) + variants = _compatible_backend_variants(backend) + assert all(v.startswith("cu11") for v in variants) + + # Non-CUDA backends: no fallback, only the exact backend variant is returned. + + def test_cpu_returns_single_variant(self): + assert _compatible_backend_variants(CPU()) == ["cpu"] + + def test_metal_returns_single_variant(self): + assert _compatible_backend_variants(Metal()) == ["metal"] + + def test_neuron_returns_single_variant(self): + assert _compatible_backend_variants(Neuron()) == ["neuron"] + + def test_rocm_returns_single_variant(self): + backend = ROCm(version=Version("6.2")) + assert _compatible_backend_variants(backend) == ["rocm62"] + + def test_xpu_returns_single_variant(self): + backend = XPU(version=Version("2024.2")) + assert _compatible_backend_variants(backend) == ["xpu20242"] + + def test_cann_returns_single_variant(self): + backend = CANN(version=Version("8.0")) + assert _compatible_backend_variants(backend) == ["cann80"] + + +class TestVariantResolutionLogging: + def test_logs_when_resolved_to_lower_minor(self, tmp_path, caplog): + exact = "torch28-cxx11-cu129-x86_64-linux" + resolved = "torch28-cxx11-cu128-x86_64-linux" + + build_dir = tmp_path / "build" / resolved + build_dir.mkdir(parents=True) + (build_dir / "__init__.py").touch() + + cuda_129 = CUDA(version=Version("12.9")) + with ( + patch("kernels.utils._build_variants", return_value=[exact, resolved]), + patch("kernels.utils._select_backend", return_value=cuda_129), + caplog.at_level(logging.INFO), + ): + _find_kernel_in_repo_path(tmp_path, "test_kernel") + + assert any( + "cu129" in r.message and resolved in r.message for r in caplog.records + ) + + def 
test_no_log_when_exact_match(self, tmp_path, caplog): + exact = 
"torch28-cxx11-cu129-x86_64-linux" + + build_dir = tmp_path / "build" / exact + build_dir.mkdir(parents=True) + (build_dir / "__init__.py").touch() + + cuda_129 = CUDA(version=Version("12.9")) + with ( + patch("kernels.utils._build_variants", return_value=[exact]), + patch("kernels.utils._select_backend", return_value=cuda_129), + caplog.at_level(logging.INFO), + ): + _find_kernel_in_repo_path(tmp_path, "test_kernel") + + assert not any("resolved to" in r.message for r in caplog.records)