Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion kernels/src/kernels/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

from kernels._system import glibc_version
from kernels._versions import select_revision_or_version
from kernels.backends import _backend
from kernels.backends import _backend, _select_backend
from kernels.compat import has_torch, has_tvm_ffi
from kernels.deps import validate_dependencies
from kernels.lockfile import KernelLock, VariantLock
Expand Down Expand Up @@ -174,6 +174,12 @@ def _find_kernel_in_repo_path(

assert variant_path is not None

exact_backend_variant = _select_backend(backend).variant
if exact_backend_variant not in variant:
logging.info(
f"Exact build variant matching {exact_backend_variant} not found, resolved to {variant}"
)

if variant_locks is not None:
variant_lock = variant_locks.get(variant)
if variant_lock is None:
Expand Down
27 changes: 20 additions & 7 deletions kernels/src/kernels/variants.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,21 +3,30 @@

from packaging.version import parse

from kernels.backends import _select_backend
from kernels.backends import CUDA, Backend, _select_backend
from kernels.compat import has_torch, has_tvm_ffi

# Anchored at the start of the string, this matches one of three prefixes:
#   - "torch" followed by at least two digits,
#   - "torch-" followed by one of the listed backend names, or
#   - "tvm-ffi" followed by at least two digits.
# Anything after the matched prefix is not constrained by this pattern.
BUILD_VARIANT_REGEX = re.compile(
r"^(torch\d+\d+|torch-(cpu|cuda|metal|neuron|rocm|xpu)|tvm-ffi\d+\d+)"
)


def _compatible_backend_variants(backend: Backend) -> list[str]:
    """Return the backend variant tags to try for `backend`, in order.

    For a CUDA backend the list starts with the tag for the backend's own
    minor version and counts down to minor 0 of the same major version
    (for example 12.2 gives ``["cu122", "cu121", "cu120"]``). Any other
    backend yields a one-element list holding its own ``variant``.
    """
    if not isinstance(backend, CUDA):
        return [backend.variant]

    major = backend.version.major
    newest_minor = backend.version.minor
    # Descending order so the exact version is the first candidate.
    return [f"cu{major}{minor}" for minor in reversed(range(newest_minor + 1))]


def _torch_build_variant(backend: str | None) -> list[str]:
if not has_torch:
return []

selected_backend = _select_backend(backend)

backend_variant = selected_backend.variant
backend_variants = _compatible_backend_variants(selected_backend)

import torch

Expand All @@ -28,17 +37,20 @@ def _torch_build_variant(backend: str | None) -> list[str]:
if os == "darwin":
cpu = "aarch64" if cpu == "arm64" else cpu
return [
f"torch{torch_version.major}{torch_version.minor}-{backend_variant}-{cpu}-{os}"
f"torch{torch_version.major}{torch_version.minor}-{v}-{cpu}-{os}"
for v in backend_variants
]
elif os == "windows":
cpu = "x86_64" if cpu == "AMD64" else cpu
return [
f"torch{torch_version.major}{torch_version.minor}-{backend_variant}-{cpu}-{os}"
f"torch{torch_version.major}{torch_version.minor}-{v}-{cpu}-{os}"
for v in backend_variants
]

cxxabi = "cxx11" if torch.compiled_with_cxx11_abi() else "cxx98"
return [
f"torch{torch_version.major}{torch_version.minor}-{cxxabi}-{backend_variant}-{cpu}-{os}"
f"torch{torch_version.major}{torch_version.minor}-{cxxabi}-{v}-{cpu}-{os}"
for v in backend_variants
]


Expand All @@ -48,7 +60,7 @@ def _tvm_ffi_build_variant(backend: str | None) -> list[str]:

selected_backend = _select_backend(backend)

backend_variant = selected_backend.variant
backend_variants = _compatible_backend_variants(selected_backend)

import tvm_ffi

Expand All @@ -57,7 +69,8 @@ def _tvm_ffi_build_variant(backend: str | None) -> list[str]:
os = platform.system().lower()

return [
f"tvm-ffi{tvm_ffi_version.major}{tvm_ffi_version.minor}-{backend_variant}-{cpu}-{os}"
f"tvm-ffi{tvm_ffi_version.major}{tvm_ffi_version.minor}-{v}-{cpu}-{os}"
for v in backend_variants
]


Expand Down
131 changes: 131 additions & 0 deletions kernels/tests/test_variants.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
import logging
from unittest.mock import patch

from packaging.version import Version

from kernels.backends import CANN, CPU, CUDA, XPU, Metal, Neuron, ROCm
from kernels.utils import _find_kernel_in_repo_path
from kernels.variants import _compatible_backend_variants


class TestCompatibleBackendVariants:
    """Checks the candidate variant list produced for each backend type."""

    # Only CUDA for now.

    @staticmethod
    def _first_available(backend, available):
        """Return the first candidate for `backend` present in `available`, or None."""
        for candidate in _compatible_backend_variants(backend):
            if candidate in available:
                return candidate
        return None

    def test_cuda_single_variant_minor_zero(self):
        cuda_12_0 = CUDA(version=Version("12.0"))
        assert _compatible_backend_variants(cuda_12_0) == ["cu120"]

    def test_cuda_returns_descending_minor_versions(self):
        cuda_12_9 = CUDA(version=Version("12.9"))
        expected = [f"cu12{i}" for i in reversed(range(10))]
        assert _compatible_backend_variants(cuda_12_9) == expected

    def test_cuda_first_variant_is_exact_match(self):
        first, *_ = _compatible_backend_variants(CUDA(version=Version("12.6")))
        assert first == "cu126"

    def test_cuda_last_variant_is_major_minor_zero(self):
        *_, last = _compatible_backend_variants(CUDA(version=Version("12.6")))
        assert last == "cu120"

    def test_cuda_resolves_to_largest_available_minor(self):
        # System has CUDA 12.9; available builds: 12.6, 12.8, 13.0.
        # Expected: resolve to 12.8 (largest minor <= 9 that is available).
        cuda_12_9 = CUDA(version=Version("12.9"))
        picked = self._first_available(cuda_12_9, {"cu126", "cu128", "cu130"})
        assert picked == "cu128"

    def test_cuda_falls_back_when_exact_not_available(self):
        # System has CUDA 12.9; only 12.6 available for major 12.
        cuda_12_9 = CUDA(version=Version("12.9"))
        picked = self._first_available(cuda_12_9, {"cu126"})
        assert picked == "cu126"

    def test_cuda_no_match_when_only_higher_minor_available(self):
        # System has CUDA 12.6; only 12.8 available, which should not match.
        cuda_12_6 = CUDA(version=Version("12.6"))
        picked = self._first_available(cuda_12_6, {"cu128"})
        assert picked is None

    def test_cuda_no_match_for_different_major(self):
        # System has CUDA 12.9; only 13.x builds available, which should not match.
        cuda_12_9 = CUDA(version=Version("12.9"))
        picked = self._first_available(cuda_12_9, {"cu130", "cu131"})
        assert picked is None

    def test_cuda_major_version_preserved(self):
        for tag in _compatible_backend_variants(CUDA(version=Version("11.8"))):
            assert tag.startswith("cu11")

    # Non-CUDA

    def test_cpu_returns_single_variant(self):
        cpu = CPU()
        assert _compatible_backend_variants(cpu) == ["cpu"]

    def test_metal_returns_single_variant(self):
        metal = Metal()
        assert _compatible_backend_variants(metal) == ["metal"]

    def test_neuron_returns_single_variant(self):
        neuron = Neuron()
        assert _compatible_backend_variants(neuron) == ["neuron"]

    def test_rocm_returns_single_variant(self):
        rocm = ROCm(version=Version("6.2"))
        assert _compatible_backend_variants(rocm) == ["rocm62"]

    def test_xpu_returns_single_variant(self):
        xpu = XPU(version=Version("2024.2"))
        # NOTE(review): a reviewer suggested using a literal variant string on
        # the right-hand side here, as the other backend tests do.
        assert _compatible_backend_variants(xpu) == [xpu.variant]

    def test_cann_returns_single_variant(self):
        cann = CANN(version=Version("8.0"))
        assert _compatible_backend_variants(cann) == ["cann80"]


class TestVariantResolutionLogging:
    """Checks the info log emitted by `_find_kernel_in_repo_path` on variant fallback."""

    @staticmethod
    def _make_build(repo_path, variant):
        """Create `build/<variant>/__init__.py` (empty) under `repo_path`."""
        variant_dir = repo_path / "build" / variant
        variant_dir.mkdir(parents=True)
        (variant_dir / "__init__.py").touch()

    def test_logs_when_resolved_to_lower_minor(self, tmp_path, caplog):
        wanted = "torch28-cxx11-cu129-x86_64-linux"
        found = "torch28-cxx11-cu128-x86_64-linux"
        # Only the lower-minor build exists on disk.
        self._make_build(tmp_path, found)

        backend = CUDA(version=Version("12.9"))
        with patch("kernels.utils._build_variants", return_value=[wanted, found]):
            with patch("kernels.utils._select_backend", return_value=backend):
                with caplog.at_level(logging.INFO):
                    _find_kernel_in_repo_path(tmp_path, "test_kernel")

        messages = [record.message for record in caplog.records]
        assert any("cu129" in text and found in text for text in messages)

    def test_no_log_when_exact_match(self, tmp_path, caplog):
        wanted = "torch28-cxx11-cu129-x86_64-linux"
        self._make_build(tmp_path, wanted)

        backend = CUDA(version=Version("12.9"))
        with patch("kernels.utils._build_variants", return_value=[wanted]):
            with patch("kernels.utils._select_backend", return_value=backend):
                with caplog.at_level(logging.INFO):
                    _find_kernel_in_repo_path(tmp_path, "test_kernel")

        messages = [record.message for record in caplog.records]
        assert all("resolved to" not in text for text in messages)
Loading