Skip to content

Commit c220611

Browse files
committed
Add kernels-test-utils shared Python package
Create a shared test utilities package that consolidates duplicated device detection, tolerance tables, and allclose helpers across all kernel repos. The package is automatically available in all kernel dev/test shells via the default pythonCheckInputs. Modules: - device: get_device(), get_available_devices(), skip_if_no_gpu() - tolerances: DEFAULT_TOLERANCES dict, get_tolerances(dtype) - allclose: fp8_allclose() with MPS float64 workaround Wired into nix overlay and set as default pythonCheckInputs in genKernelFlakeOutputs so downstream repos get it automatically. Updated template test to use kernels_test_utils imports. Co-developed-by: Claude Code v2.1.58 (claude-opus-4-6)
1 parent 9e2b45c commit c220611

File tree

9 files changed

+163
-11
lines changed

9 files changed

+163
-11
lines changed

flake.nix

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@
9090
# fail in a GPU-less sandbox. Even in that case, it's better to lazily
9191
# load the part with this functionality.
9292
doGetKernelCheck ? true,
93-
pythonCheckInputs ? pkgs: [ ],
93+
pythonCheckInputs ? pkgs: [ pkgs.kernels-test-utils ],
9494
pythonNativeCheckInputs ? pkgs: [ ],
9595
torchVersions ? _: torchVersions',
9696
}:

kernels-test-utils/pyproject.toml

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
# Packaging metadata for the shared kernels-test-utils test-helper package.
[build-system]
requires = ["setuptools"]
build-backend = "setuptools.build_meta"

[project]
name = "kernels-test-utils"
version = "0.1.0"
requires-python = ">=3.10"
# pytest backs skip_if_no_gpu(); torch backs the device/tolerance/allclose helpers.
dependencies = ["pytest", "torch"]
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
"""Shared test utilities for kernel repos."""

from kernels_test_utils.allclose import fp8_allclose
from kernels_test_utils.device import get_available_devices, get_device, skip_if_no_gpu
from kernels_test_utils.tolerances import DEFAULT_TOLERANCES, get_tolerances

# Public re-exports, sorted alphabetically (uppercase names sort first).
__all__ = [
    "DEFAULT_TOLERANCES",
    "fp8_allclose",
    "get_available_devices",
    "get_device",
    "get_tolerances",
    "skip_if_no_gpu",
]
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
"""Allclose variants that work around device limitations."""
2+
3+
import torch
4+
from torch._prims_common import TensorLikeType
5+
6+
7+
def fp8_allclose(
    a: TensorLikeType,
    b: TensorLikeType,
    rtol: float = 1e-05,
    atol: float = 1e-08,
    equal_nan: bool = False,
) -> bool:
    """``torch.allclose`` replacement that handles FP8 types and MPS.

    The tensors are promoted to float64 before comparison, except on MPS
    (which has no float64 support), where float32 is used instead.
    """
    # Reuse torch's own argument validation for allclose-style comparisons.
    torch._refs._check_close_args(name="torch.allclose", a=a, b=b, rtol=rtol, atol=atol)

    # Pick the widest dtype the devices involved can represent.
    compare_dtype = (
        torch.float32
        if "mps" in (a.device.type, b.device.type)
        else torch.float64
    )

    close = torch.isclose(
        a.to(compare_dtype),
        b.to(compare_dtype),
        rtol=rtol,
        atol=atol,
        equal_nan=equal_nan,
    )
    return bool(close.all().item())
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
"""Device detection utilities for kernel tests."""
2+
3+
from typing import List
4+
5+
import pytest
6+
import torch
7+
8+
9+
def get_device() -> torch.device:
    """Return the best available compute device (MPS > CUDA > XPU > CPU)."""
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():
        return torch.device("mps")

    if torch.cuda.is_available():
        return torch.device("cuda")

    # torch.xpu only exists in builds with Intel accelerator support.
    xpu = getattr(torch, "xpu", None)
    if xpu is not None and xpu.is_available():
        return torch.device("xpu")

    return torch.device("cpu")
18+
19+
20+
def get_available_devices() -> List[str]:
    """Return device strings suitable for pytest parametrization.

    On MPS: ``["mps"]``
    On CUDA: ``["cuda:0", "cuda:1", ...]`` for each visible GPU.
    On XPU: ``["xpu:0", "xpu:1", ...]`` for each visible accelerator.
    Fallback: ``["cpu"]``
    """
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():
        return ["mps"]

    if torch.cuda.is_available():
        # Guard against a zero count even though availability implies >= 1.
        num_gpus = max(1, torch.cuda.device_count())
        return [f"cuda:{index}" for index in range(num_gpus)]

    if hasattr(torch, "xpu") and torch.xpu.is_available():
        num_accels = max(1, torch.xpu.device_count())
        return [f"xpu:{index}" for index in range(num_accels)]

    return ["cpu"]
35+
36+
37+
def skip_if_no_gpu() -> None:
    """Call inside a test to skip when no GPU is available."""
    # get_device() falls back to CPU exactly when no accelerator was found.
    if get_device().type == "cpu":
        pytest.skip("No GPU device available")
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
"""Default tolerance tables for kernel tests."""
2+
3+
from typing import Dict
4+
5+
import torch
6+
7+
# Per-dtype comparison tolerances shared by all kernel test suites.
DEFAULT_TOLERANCES: Dict[torch.dtype, Dict[str, float]] = {
    torch.float32: {"atol": 1e-5, "rtol": 1e-5},
    torch.float16: {"atol": 1e-3, "rtol": 1e-3},
    torch.bfloat16: {"atol": 1e-2, "rtol": 1.6e-2},
}

# Loose fallback for dtypes without a dedicated entry (e.g. FP8 formats).
_FALLBACK_TOLERANCES: Dict[str, float] = {"atol": 0.1, "rtol": 0.1}


def get_tolerances(dtype: torch.dtype) -> Dict[str, float]:
    """Return ``{"atol": ..., "rtol": ...}`` for *dtype*.

    Falls back to ``atol=0.1, rtol=0.1`` for unknown dtypes.

    A fresh dict is returned on every call so callers may adjust the
    tolerances in place without mutating the shared module-level tables.
    """
    return dict(DEFAULT_TOLERANCES.get(dtype, _FALLBACK_TOLERANCES))

nix/overlay.nix

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,8 @@ in
8383

8484
kernels = callPackage ./pkgs/python-modules/kernels { };
8585

86+
kernels-test-utils = callPackage ./pkgs/python-modules/kernels-test-utils { };
87+
8688
pyclibrary = python-self.callPackage ./pkgs/python-modules/pyclibrary { };
8789

8890
mkTorch = callPackage ./pkgs/python-modules/torch/binary { };
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
# Nix package for the shared kernels-test-utils Python test-helper package.
{
  lib,
  buildPythonPackage,
  setuptools,

  pytest,
  torch,
}:

let
  # Single source of truth: read the version out of the package's own
  # pyproject.toml instead of duplicating it here.
  version =
    (builtins.fromTOML (builtins.readFile ../../../../kernels-test-utils/pyproject.toml)).project.version;
in
buildPythonPackage {
  pname = "kernels-test-utils";
  inherit version;
  pyproject = true;

  # Restrict the source to .toml and .py files so unrelated files in the
  # package directory do not invalidate the build.
  src =
    let
      sourceFiles = file: file.hasExt "toml" || file.hasExt "py";
    in
    lib.fileset.toSource {
      root = ../../../../kernels-test-utils;
      fileset = lib.fileset.fileFilter sourceFiles ../../../../kernels-test-utils;
    };

  build-system = [ setuptools ];

  dependencies = [
    pytest
    torch
  ];

  # Smoke test: the installed package must at least be importable.
  pythonImportsCheck = [
    "kernels_test_utils"
  ];

  meta = with lib; {
    description = "Shared test utilities for kernel repos";
  };
}

template/tests/test___KERNEL_NAME_NORMALIZED__.py

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,12 @@
1-
import platform
2-
31
import torch
42

3+
from kernels_test_utils import get_device
4+
55
import __KERNEL_NAME_NORMALIZED__
66

77

88
def test___KERNEL_NAME_NORMALIZED__():
9-
if platform.system() == "Darwin":
10-
device = torch.device("mps")
11-
elif hasattr(torch, "xpu") and torch.xpu.is_available():
12-
device = torch.device("xpu")
13-
elif torch.version.cuda is not None and torch.cuda.is_available():
14-
device = torch.device("cuda")
15-
else:
16-
device = torch.device("cpu")
9+
device = get_device()
1710

1811
x = torch.randn(1024, 1024, dtype=torch.float32, device=device)
1912
expected = x + 1.0

0 commit comments

Comments
 (0)