Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
# Claude Code
.claude/

# Byte-compiled / optimized / DLL files
__pycache__/
text_generation_server/__pycache__/
Expand Down
10 changes: 10 additions & 0 deletions build2cmake/src/config/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,7 @@ pub enum Kernel {
cxx_flags: Option<Vec<String>>,
depends: Vec<Dependency>,
include: Option<Vec<String>>,
metal_std_version: Option<String>,
src: Vec<String>,
},
Rocm {
Expand Down Expand Up @@ -234,6 +235,15 @@ impl Kernel {
| Kernel::Xpu { src, .. } => src,
}
}

/// Returns the configured Metal language standard (e.g. `"metal3.1"`) for
/// Metal kernels; `None` for every other backend or when unset.
pub fn metal_std_version(&self) -> Option<&str> {
    if let Kernel::Metal {
        metal_std_version, ..
    } = self
    {
        metal_std_version.as_deref()
    } else {
        None
    }
}
}

#[derive(Clone, Copy, Debug, Deserialize, Eq, Hash, Ord, PartialEq, PartialOrd, Serialize)]
Expand Down
3 changes: 3 additions & 0 deletions build2cmake/src/config/v2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ pub enum Kernel {
cxx_flags: Option<Vec<String>>,
depends: Vec<Dependency>,
include: Option<Vec<String>>,
metal_std_version: Option<String>,
src: Vec<String>,
},
#[serde(rename_all = "kebab-case")]
Expand Down Expand Up @@ -232,11 +233,13 @@ impl From<Kernel> for super::Kernel {
cxx_flags,
depends,
include,
metal_std_version,
src,
} => super::Kernel::Metal {
cxx_flags,
depends,
include,
metal_std_version,
src,
},
Kernel::Rocm {
Expand Down
5 changes: 5 additions & 0 deletions build2cmake/src/config/v3.rs
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ pub enum Kernel {
cxx_flags: Option<Vec<String>>,
depends: Vec<Dependency>,
include: Option<Vec<String>>,
metal_std_version: Option<String>,
src: Vec<String>,
},
#[serde(rename_all = "kebab-case")]
Expand Down Expand Up @@ -261,11 +262,13 @@ impl From<Kernel> for super::Kernel {
cxx_flags,
depends,
include,
metal_std_version,
src,
} => super::Kernel::Metal {
cxx_flags,
depends,
include,
metal_std_version,
src,
},
Kernel::Rocm {
Expand Down Expand Up @@ -425,11 +428,13 @@ impl From<super::Kernel> for Kernel {
cxx_flags,
depends,
include,
metal_std_version,
src,
} => Kernel::Metal {
cxx_flags,
depends,
include,
metal_std_version,
src,
},
super::Kernel::Rocm {
Expand Down
7 changes: 6 additions & 1 deletion build2cmake/src/templates/kernel.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,7 @@ endfunction()

function(metal_kernel_component SRC_VAR)
set(options)
set(oneValueArgs)
set(oneValueArgs METAL_STD_VERSION)
set(multiValueArgs SOURCES INCLUDES CXX_FLAGS)
cmake_parse_arguments(KERNEL "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})

Expand Down Expand Up @@ -293,4 +293,9 @@ function(metal_kernel_component SRC_VAR)
list(APPEND _TMP_METAL_INCLUDES ${KERNEL_INCLUDES})
set(METAL_INCLUDE_DIRS ${_TMP_METAL_INCLUDES} PARENT_SCOPE)
endif()

# Propagate Metal std version to parent scope for compile_metal_shaders
if(KERNEL_METAL_STD_VERSION)
set(METAL_STD_VERSION ${KERNEL_METAL_STD_VERSION} PARENT_SCOPE)
endif()
endfunction()
10 changes: 8 additions & 2 deletions build2cmake/src/templates/metal/compile-metal.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,14 @@ function(compile_metal_shaders TARGET_NAME METAL_SOURCES EXTRA_INCLUDE_DIRS)
set(METAL_TOOLCHAIN "${CMAKE_MATCH_1}/Metal.xctoolchain")
endif()

# Set Metal compiler flags
set(METAL_FLAGS "-std=metal4.0" "-O2")
# Set Metal compiler flags.
# metal3.1 → air64_v26, macOS 14+
# metal3.2 → air64_v27, macOS 15+
# metal4.0 → air64_v28, macOS 26+
if(NOT DEFINED METAL_STD_VERSION)
set(METAL_STD_VERSION "metal4.0")
endif()
set(METAL_FLAGS "-std=${METAL_STD_VERSION}" "-O2")

# Output directory for compiled metallib
set(METALLIB_OUTPUT_DIR "${CMAKE_BINARY_DIR}/metallib")
Expand Down
1 change: 1 addition & 0 deletions build2cmake/src/templates/metal/kernel.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,6 @@ if(GPU_LANG STREQUAL "METAL")
SOURCES {{ sources }}
{% if includes %}INCLUDES "{{ includes }}"{% endif %}
{% if cxx_flags %}CXX_FLAGS "{{ cxx_flags }}"{% endif %}
{% if metal_std_version %}METAL_STD_VERSION "{{ metal_std_version }}"{% endif %}
)
endif()
1 change: 1 addition & 0 deletions build2cmake/src/torch/kernel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,7 @@ fn render_kernel_component_metal(
cxx_flags => kernel.cxx_flags().map(|flags| flags.join(";")),
includes => kernel.include().map(prefix_and_join_includes),
kernel_name => kernel_name,
metal_std_version => kernel.metal_std_version(),
sources => sources,
},
&mut *write,
Expand Down
1 change: 1 addition & 0 deletions builder/examples/relu-metal-cpp/build.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ src = [

[kernel.relu_metal]
backend = "metal"
metal-std-version = "metal3.1"
src = [
"relu/relu.cpp",
"relu/metallib_loader.mm",
Expand Down
2 changes: 1 addition & 1 deletion flake.nix
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@
# fail in a GPU-less sandbox. Even in that case, it's better to lazily
# load the part with this functionality.
doGetKernelCheck ? true,
pythonCheckInputs ? pkgs: [ ],
pythonCheckInputs ? pkgs: [ pkgs.kernels-test-utils ],
pythonNativeCheckInputs ? pkgs: [ ],
torchVersions ? _: torchVersions',
}:
Expand Down
9 changes: 9 additions & 0 deletions kernels-test-utils/pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
[build-system]
requires = ["setuptools"]
build-backend = "setuptools.build_meta"

[project]
name = "kernels-test-utils"
version = "0.1.0"
requires-python = ">=3.10"
dependencies = ["pytest", "torch"]
14 changes: 14 additions & 0 deletions kernels-test-utils/src/kernels_test_utils/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
"""Shared test utilities for kernel repos."""

from kernels_test_utils.allclose import fp8_allclose
from kernels_test_utils.device import get_available_devices, get_device, skip_if_no_gpu
from kernels_test_utils.tolerances import DEFAULT_TOLERANCES, get_tolerances

__all__ = [
"fp8_allclose",
"get_available_devices",
"get_device",
"get_tolerances",
"skip_if_no_gpu",
"DEFAULT_TOLERANCES",
]
32 changes: 32 additions & 0 deletions kernels-test-utils/src/kernels_test_utils/allclose.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
"""Allclose variants that work around device limitations."""

import torch
from torch._prims_common import TensorLikeType


def fp8_allclose(
    a: TensorLikeType,
    b: TensorLikeType,
    rtol: float = 1e-05,
    atol: float = 1e-08,
    equal_nan: bool = False,
) -> bool:
    """Drop-in ``torch.allclose`` replacement usable with FP8 and MPS tensors.

    The comparison dtype is float32 when either tensor lives on MPS (which
    lacks float64 support); everywhere else both tensors are promoted to
    float64 before comparing.
    """
    # Reuse torch's own argument validation.
    # NOTE(review): private API — may change between torch releases.
    torch._refs._check_close_args(name="torch.allclose", a=a, b=b, rtol=rtol, atol=atol)

    # MPS has no float64, so compare in float32 there.
    on_mps = "mps" in (a.device.type, b.device.type)
    cmp_dtype = torch.float32 if on_mps else torch.float64

    close = torch.isclose(
        a.to(cmp_dtype), b.to(cmp_dtype), rtol=rtol, atol=atol, equal_nan=equal_nan
    )
    return bool(close.all().item())
41 changes: 41 additions & 0 deletions kernels-test-utils/src/kernels_test_utils/device.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
"""Device detection utilities for kernel tests."""

from typing import List

import pytest
import torch


def get_device() -> torch.device:
"""Return the best available compute device (MPS > CUDA > XPU > CPU)."""
if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
return torch.device("mps")
if torch.cuda.is_available():
return torch.device("cuda")
if hasattr(torch, "xpu") and torch.xpu.is_available():
return torch.device("xpu")
return torch.device("cpu")


def get_available_devices() -> List[str]:
"""Return device strings suitable for pytest parametrization.

On MPS: ``["mps"]``
On CUDA: ``["cuda:0", "cuda:1", ...]`` for each visible GPU.
On XPU: ``["xpu:0", "xpu:1", ...]`` for each visible accelerator.
Fallback: ``["cpu"]``
"""
if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
return ["mps"]
if torch.cuda.is_available():
return [f"cuda:{i}" for i in range(max(1, torch.cuda.device_count()))]
if hasattr(torch, "xpu") and torch.xpu.is_available():
return [f"xpu:{i}" for i in range(max(1, torch.xpu.device_count()))]
return ["cpu"]


def skip_if_no_gpu() -> None:
    """Skip the calling test when only the CPU backend is available."""
    if get_device().type == "cpu":
        pytest.skip("No GPU device available")
19 changes: 19 additions & 0 deletions kernels-test-utils/src/kernels_test_utils/tolerances.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
"""Default tolerance tables for kernel tests."""

from typing import Dict

import torch

# Per-dtype numeric tolerances used when comparing kernel outputs against a
# reference implementation. Lower-precision dtypes get looser bounds.
DEFAULT_TOLERANCES: Dict[torch.dtype, Dict[str, float]] = {
    torch.float32: {"atol": 1e-5, "rtol": 1e-5},
    torch.float16: {"atol": 1e-3, "rtol": 1e-3},
    torch.bfloat16: {"atol": 1e-2, "rtol": 1.6e-2},
}


def get_tolerances(dtype: torch.dtype) -> Dict[str, float]:
    """Look up ``{"atol": ..., "rtol": ...}`` for *dtype*.

    Unknown dtypes fall back to the loose ``atol=0.1, rtol=0.1`` pair.
    """
    try:
        return DEFAULT_TOLERANCES[dtype]
    except KeyError:
        return {"atol": 0.1, "rtol": 0.1}
2 changes: 2 additions & 0 deletions nix/overlay.nix
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,8 @@ in

kernels = callPackage ./pkgs/python-modules/kernels { };

kernels-test-utils = callPackage ./pkgs/python-modules/kernels-test-utils { };

pyclibrary = python-self.callPackage ./pkgs/python-modules/pyclibrary { };

mkTorch = callPackage ./pkgs/python-modules/torch/binary { };
Expand Down
42 changes: 42 additions & 0 deletions nix/pkgs/python-modules/kernels-test-utils/default.nix
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
# Nix package for the shared kernels-test-utils Python helpers.
{
lib,
buildPythonPackage,
setuptools,

# Runtime dependencies declared in pyproject.toml.
pytest,
torch,
}:

let
# Keep the derivation version in lockstep with the repo's pyproject.toml
# instead of duplicating it here.
version =
(builtins.fromTOML (builtins.readFile ../../../../kernels-test-utils/pyproject.toml)).project.version;
in
buildPythonPackage {
pname = "kernels-test-utils";
inherit version;
pyproject = true;

# Only ship .py and .toml files; anything else in the source directory is
# filtered out of the build input.
src =
let
sourceFiles = file: file.hasExt "toml" || file.hasExt "py";
in
lib.fileset.toSource {
root = ../../../../kernels-test-utils;
fileset = lib.fileset.fileFilter sourceFiles ../../../../kernels-test-utils;
};

build-system = [ setuptools ];

dependencies = [
pytest
torch
];

# Sanity check: the installed wheel must be importable.
pythonImportsCheck = [
"kernels_test_utils"
];

meta = with lib; {
description = "Shared test utilities for kernel repos";
};
}
13 changes: 3 additions & 10 deletions template/tests/test___KERNEL_NAME_NORMALIZED__.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,12 @@
import platform

import torch

from kernels_test_utils import get_device

import __KERNEL_NAME_NORMALIZED__


def test___KERNEL_NAME_NORMALIZED__():
if platform.system() == "Darwin":
device = torch.device("mps")
elif hasattr(torch, "xpu") and torch.xpu.is_available():
device = torch.device("xpu")
elif torch.version.cuda is not None and torch.cuda.is_available():
device = torch.device("cuda")
else:
device = torch.device("cpu")
device = get_device()

x = torch.randn(1024, 1024, dtype=torch.float32, device=device)
expected = x + 1.0
Expand Down