Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion kernels/src/kernels/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

from kernels._system import glibc_version
from kernels._versions import select_revision_or_version
from kernels.backends import _backend
from kernels.backends import _backend, _select_backend
from kernels.compat import has_torch, has_tvm_ffi
from kernels.deps import validate_dependencies
from kernels.lockfile import KernelLock, VariantLock
Expand Down Expand Up @@ -174,6 +174,12 @@ def _find_kernel_in_repo_path(

assert variant_path is not None

exact_backend_variant = _select_backend(backend).variant
if exact_backend_variant not in variant:
logging.info(
f"Exact build variant matching {exact_backend_variant} not found, resolved to {variant}"
)

if variant_locks is not None:
variant_lock = variant_locks.get(variant)
if variant_lock is None:
Expand Down
27 changes: 20 additions & 7 deletions kernels/src/kernels/variants.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,21 +3,30 @@

from packaging.version import parse

from kernels.backends import _select_backend
from kernels.backends import CUDA, Backend, _select_backend
from kernels.compat import has_torch, has_tvm_ffi

BUILD_VARIANT_REGEX = re.compile(
r"^(torch\d+\d+|torch-(cpu|cuda|metal|neuron|rocm|xpu)|tvm-ffi\d+\d+)"
)


def _compatible_backend_variants(backend: Backend) -> list[str]:
if isinstance(backend, CUDA):
return [
f"cu{backend.version.major}{minor}"
for minor in range(backend.version.minor, -1, -1)
]
return [backend.variant]


def _torch_build_variant(backend: str | None) -> list[str]:
if not has_torch:
return []

selected_backend = _select_backend(backend)

backend_variant = selected_backend.variant
backend_variants = _compatible_backend_variants(selected_backend)

import torch

Expand All @@ -28,17 +37,20 @@ def _torch_build_variant(backend: str | None) -> list[str]:
if os == "darwin":
cpu = "aarch64" if cpu == "arm64" else cpu
return [
f"torch{torch_version.major}{torch_version.minor}-{backend_variant}-{cpu}-{os}"
f"torch{torch_version.major}{torch_version.minor}-{v}-{cpu}-{os}"
for v in backend_variants
]
elif os == "windows":
cpu = "x86_64" if cpu == "AMD64" else cpu
return [
f"torch{torch_version.major}{torch_version.minor}-{backend_variant}-{cpu}-{os}"
f"torch{torch_version.major}{torch_version.minor}-{v}-{cpu}-{os}"
for v in backend_variants
]

cxxabi = "cxx11" if torch.compiled_with_cxx11_abi() else "cxx98"
return [
f"torch{torch_version.major}{torch_version.minor}-{cxxabi}-{backend_variant}-{cpu}-{os}"
f"torch{torch_version.major}{torch_version.minor}-{cxxabi}-{v}-{cpu}-{os}"
for v in backend_variants
]


Expand All @@ -48,7 +60,7 @@ def _tvm_ffi_build_variant(backend: str | None) -> list[str]:

selected_backend = _select_backend(backend)

backend_variant = selected_backend.variant
backend_variants = _compatible_backend_variants(selected_backend)

import tvm_ffi

Expand All @@ -57,7 +69,8 @@ def _tvm_ffi_build_variant(backend: str | None) -> list[str]:
os = platform.system().lower()

return [
f"tvm-ffi{tvm_ffi_version.major}{tvm_ffi_version.minor}-{backend_variant}-{cpu}-{os}"
f"tvm-ffi{tvm_ffi_version.major}{tvm_ffi_version.minor}-{v}-{cpu}-{os}"
for v in backend_variants
]


Expand Down
131 changes: 131 additions & 0 deletions kernels/tests/test_variants.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
import logging
from unittest.mock import patch

from packaging.version import Version

from kernels.backends import CANN, CPU, CUDA, XPU, Metal, Neuron, ROCm
from kernels.utils import _find_kernel_in_repo_path
from kernels.variants import _compatible_backend_variants


class TestCompatibleBackendVariants:
# Only CUDA for now.

def test_cuda_single_variant_minor_zero(self):
backend = CUDA(version=Version("12.0"))
assert _compatible_backend_variants(backend) == ["cu120"]

def test_cuda_returns_descending_minor_versions(self):
backend = CUDA(version=Version("12.9"))
variants = _compatible_backend_variants(backend)
assert variants == [f"cu12{i}" for i in range(9, -1, -1)]

def test_cuda_first_variant_is_exact_match(self):
backend = CUDA(version=Version("12.6"))
variants = _compatible_backend_variants(backend)
assert variants[0] == "cu126"

def test_cuda_last_variant_is_major_minor_zero(self):
backend = CUDA(version=Version("12.6"))
variants = _compatible_backend_variants(backend)
assert variants[-1] == "cu120"

def test_cuda_resolves_to_largest_available_minor(self):
# System has CUDA 12.9; available builds: 12.6, 12.8, 13.0.
# Expected: resolve to 12.8 (largest z <= 9 that is available).
backend = CUDA(version=Version("12.9"))
candidates = _compatible_backend_variants(backend)
available = {"cu126", "cu128", "cu130"}
resolved = next((v for v in candidates if v in available), None)
assert resolved == "cu128"

def test_cuda_falls_back_when_exact_not_available(self):
# System has CUDA 12.9; only 12.6 available for major 12.
backend = CUDA(version=Version("12.9"))
candidates = _compatible_backend_variants(backend)
available = {"cu126"}
resolved = next((v for v in candidates if v in available), None)
assert resolved == "cu126"

def test_cuda_no_match_when_only_higher_minor_available(self):
# System has CUDA 12.6; only 12.8 available — should not match.
backend = CUDA(version=Version("12.6"))
candidates = _compatible_backend_variants(backend)
available = {"cu128"}
resolved = next((v for v in candidates if v in available), None)
assert resolved is None

def test_cuda_no_match_for_different_major(self):
# System has CUDA 12.9; only 13.x builds available — should not match.
backend = CUDA(version=Version("12.9"))
candidates = _compatible_backend_variants(backend)
available = {"cu130", "cu131"}
resolved = next((v for v in candidates if v in available), None)
assert resolved is None

def test_cuda_major_version_preserved(self):
backend = CUDA(version=Version("11.8"))
variants = _compatible_backend_variants(backend)
assert all(v.startswith("cu11") for v in variants)

# Non-CUDA

def test_cpu_returns_single_variant(self):
assert _compatible_backend_variants(CPU()) == ["cpu"]

def test_metal_returns_single_variant(self):
assert _compatible_backend_variants(Metal()) == ["metal"]

def test_neuron_returns_single_variant(self):
assert _compatible_backend_variants(Neuron()) == ["neuron"]

def test_rocm_returns_single_variant(self):
backend = ROCm(version=Version("6.2"))
assert _compatible_backend_variants(backend) == ["rocm62"]

def test_xpu_returns_single_variant(self):
backend = XPU(version=Version("2024.2"))
assert _compatible_backend_variants(backend) == ["xpu20242"]

def test_cann_returns_single_variant(self):
backend = CANN(version=Version("8.0"))
assert _compatible_backend_variants(backend) == ["cann80"]


class TestVariantResolutionLogging:
def test_logs_when_resolved_to_lower_minor(self, tmp_path, caplog):
exact = "torch28-cxx11-cu129-x86_64-linux"
resolved = "torch28-cxx11-cu128-x86_64-linux"

build_dir = tmp_path / "build" / resolved
build_dir.mkdir(parents=True)
(build_dir / "__init__.py").touch()

cuda_129 = CUDA(version=Version("12.9"))
with (
patch("kernels.utils._build_variants", return_value=[exact, resolved]),
patch("kernels.utils._select_backend", return_value=cuda_129),
caplog.at_level(logging.INFO),
):
_find_kernel_in_repo_path(tmp_path, "test_kernel")

assert any(
"cu129" in r.message and resolved in r.message for r in caplog.records
)

def test_no_log_when_exact_match(self, tmp_path, caplog):
exact = "torch28-cxx11-cu129-x86_64-linux"

build_dir = tmp_path / "build" / exact
build_dir.mkdir(parents=True)
(build_dir / "__init__.py").touch()

cuda_129 = CUDA(version=Version("12.9"))
with (
patch("kernels.utils._build_variants", return_value=[exact]),
patch("kernels.utils._select_backend", return_value=cuda_129),
caplog.at_level(logging.INFO),
):
_find_kernel_in_repo_path(tmp_path, "test_kernel")

assert not any("resolved to" in r.message for r in caplog.records)
Loading