Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 26 additions & 5 deletions .github/workflows/build_kernel_macos.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,23 +9,44 @@ on:

jobs:
build:
name: Build kernel
runs-on: macos-26
name: Build and test kernel (${{ matrix.os }})
runs-on: ${{ matrix.os }}
strategy:
fail-fast: false
matrix:
include:
- os: macos-14-xlarge
xcode: "/Applications/Xcode_15.4.app"
# macOS 14 is best-effort: builds work but MPS tests may OOM
# on runners with limited unified memory.
allow-failure: true
- os: macos-15-xlarge
xcode: "/Applications/Xcode_16.2.app"
- os: macos-26-xlarge
xcode: "/Applications/Xcode_26.0.app"
continue-on-error: ${{ matrix.allow-failure || false }}
steps:
- name: "Select Xcode"
run: sudo xcrun xcode-select -s /Applications/Xcode_26.0.app
run: sudo xcrun xcode-select -s ${{ matrix.xcode }}
- name: "Install Metal Toolchain"
if: matrix.os == 'macos-26-xlarge'
run: xcodebuild -downloadComponent metalToolchain
- uses: actions/checkout@v6
- uses: cachix/install-nix-action@v31
with:
extra_nix_config: |
sandbox = relaxed
- uses: cachix/cachix-action@v16
with:
name: huggingface
#authToken: "${{ secrets.CACHIX_AUTH_TOKEN }}"
# For now we only test that there are no regressions in building macOS
# kernels. Also run tests once we have a macOS runner.

- name: Build relu kernel
run: ( cd builder/examples/relu && nix build .\#redistributable.torch29-metal-aarch64-darwin -L )
- name: Test relu kernel
run: ( cd builder/examples/relu && nix develop .\#test --command pytest tests/ -v )

- name: Build relu metal cpp kernel
run: ( cd builder/examples/relu-metal-cpp && nix build .\#redistributable.torch29-metal-aarch64-darwin -L )
- name: Test relu metal cpp kernel
run: ( cd builder/examples/relu-metal-cpp && nix develop .\#test --command pytest tests/ -v )
31 changes: 22 additions & 9 deletions build2cmake/src/templates/metal/compile-metal.cmake
Original file line number Diff line number Diff line change
@@ -1,24 +1,37 @@
# Metal shader compilation function
function(compile_metal_shaders TARGET_NAME METAL_SOURCES EXTRA_INCLUDE_DIRS)
if(NOT DEFINED METAL_TOOLCHAIN)
# Try the separate Metal toolchain first (macOS 26+ with downloadable component)
execute_process(
COMMAND "xcodebuild" "-showComponent" "MetalToolchain"
OUTPUT_VARIABLE FIND_METAL_OUT
RESULT_VARIABLE FIND_METAL_ERROR_CODE
ERROR_VARIABLE FIND_METAL_STDERR
OUTPUT_STRIP_TRAILING_WHITESPACE)

if(NOT FIND_METAL_ERROR_CODE EQUAL 0)
message(FATAL_ERROR "${ERR_MSG}: ${FIND_METAL_STDERR}")
if(FIND_METAL_ERROR_CODE EQUAL 0)
string(REGEX MATCH "Toolchain Search Path: ([^\n]+)" MATCH_RESULT "${FIND_METAL_OUT}")
set(METAL_TOOLCHAIN "${CMAKE_MATCH_1}/Metal.xctoolchain")
else()
# Fall back to the default Xcode toolchain (macOS 14/15 bundle metal in Xcode)
execute_process(
COMMAND "xcode-select" "-p"
OUTPUT_VARIABLE XCODE_DEV_DIR
RESULT_VARIABLE XCODE_SELECT_ERROR
OUTPUT_STRIP_TRAILING_WHITESPACE)

if(XCODE_SELECT_ERROR EQUAL 0)
set(METAL_TOOLCHAIN "${XCODE_DEV_DIR}/Toolchains/XcodeDefault.xctoolchain")
else()
message(FATAL_ERROR "Cannot find Metal toolchain. On macOS 26+, use: xcodebuild -downloadComponent metalToolchain")
endif()
endif()

# Extract the Toolchain Search Path value and append Metal.xctoolchain
string(REGEX MATCH "Toolchain Search Path: ([^\n]+)" MATCH_RESULT "${FIND_METAL_OUT}")
set(METAL_TOOLCHAIN "${CMAKE_MATCH_1}/Metal.xctoolchain")
endif()

# Set Metal compiler flags
set(METAL_FLAGS "-std=metal4.0" "-O2")
# Set Metal compiler flags.
# metal3.1 → air64_v26, macOS 14+
# metal3.2 → air64_v27, macOS 15+
# metal4.0 → air64_v28, macOS 26+
set(METAL_FLAGS "-std=metal3.1" "-O2")

# Output directory for compiled metallib
set(METALLIB_OUTPUT_DIR "${CMAKE_BINARY_DIR}/metallib")
Expand Down
25 changes: 12 additions & 13 deletions builder/examples/extra-data/relu_metal/relu.mm
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
#include <ATen/mps/MPSStream.h>
#include <torch/torch.h>

#import <Foundation/Foundation.h>
Expand All @@ -18,8 +19,10 @@
torch::Tensor &dispatchReluKernel(torch::Tensor const &input,
torch::Tensor &output) {
@autoreleasepool {
id<MTLDevice> device = MTLCreateSystemDefaultDevice();
at::mps::MPSStream *stream = at::mps::getCurrentMPSStream();
TORCH_CHECK(stream, "Failed to get MPS stream");

id<MTLDevice> device = stream->device();
int numThreads = input.numel();

// Load the embedded Metal library from memory
Expand All @@ -44,14 +47,12 @@
error:&error];
TORCH_CHECK(reluPSO, error.localizedDescription.UTF8String);

id<MTLCommandBuffer> commandBuffer = torch::mps::get_command_buffer();
TORCH_CHECK(commandBuffer, "Failed to retrieve command buffer reference");

dispatch_queue_t serialQueue = torch::mps::get_dispatch_queue();

dispatch_sync(serialQueue, ^() {
id<MTLComputeCommandEncoder> computeEncoder =
[commandBuffer computeCommandEncoder];
// Use stream->commandEncoder() to properly integrate with PyTorch's
// MPS encoder lifecycle (kernel coalescing). Creating encoders directly
// via [commandBuffer computeCommandEncoder] bypasses this and crashes
// when the kernel is called twice in sequence.
dispatch_sync(stream->queue(), ^() {
id<MTLComputeCommandEncoder> computeEncoder = stream->commandEncoder();
TORCH_CHECK(computeEncoder, "Failed to create compute command encoder");

[computeEncoder setComputePipelineState:reluPSO];
Expand All @@ -72,11 +73,9 @@

[computeEncoder dispatchThreads:gridSize
threadsPerThreadgroup:threadgroupSize];

[computeEncoder endEncoding];

torch::mps::commit();
});

stream->synchronize(at::mps::SyncType::COMMIT_AND_CONTINUE);
}

return output;
Expand Down
25 changes: 12 additions & 13 deletions builder/examples/relu/relu_metal/relu.mm
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
#include <ATen/mps/MPSStream.h>
#include <torch/torch.h>

#import <Foundation/Foundation.h>
Expand All @@ -18,8 +19,10 @@
torch::Tensor &dispatchReluKernel(torch::Tensor const &input,
torch::Tensor &output) {
@autoreleasepool {
id<MTLDevice> device = MTLCreateSystemDefaultDevice();
at::mps::MPSStream *stream = at::mps::getCurrentMPSStream();
TORCH_CHECK(stream, "Failed to get MPS stream");

id<MTLDevice> device = stream->device();
int numThreads = input.numel();

// Load the embedded Metal library from memory
Expand All @@ -44,14 +47,12 @@
error:&error];
TORCH_CHECK(reluPSO, error.localizedDescription.UTF8String);

id<MTLCommandBuffer> commandBuffer = torch::mps::get_command_buffer();
TORCH_CHECK(commandBuffer, "Failed to retrieve command buffer reference");

dispatch_queue_t serialQueue = torch::mps::get_dispatch_queue();

dispatch_sync(serialQueue, ^() {
id<MTLComputeCommandEncoder> computeEncoder =
[commandBuffer computeCommandEncoder];
// Use stream->commandEncoder() to properly integrate with PyTorch's
// MPS encoder lifecycle (kernel coalescing). Creating encoders directly
// via [commandBuffer computeCommandEncoder] bypasses this and crashes
// when the kernel is called twice in sequence.
dispatch_sync(stream->queue(), ^() {
id<MTLComputeCommandEncoder> computeEncoder = stream->commandEncoder();
TORCH_CHECK(computeEncoder, "Failed to create compute command encoder");

[computeEncoder setComputePipelineState:reluPSO];
Expand All @@ -72,11 +73,9 @@

[computeEncoder dispatchThreads:gridSize
threadsPerThreadgroup:threadgroupSize];

[computeEncoder endEncoding];

torch::mps::commit();
});

stream->synchronize(at::mps::SyncType::COMMIT_AND_CONTINUE);
}

return output;
Expand Down
52 changes: 44 additions & 8 deletions builder/lib/torch-extension/arch.nix
Original file line number Diff line number Diff line change
Expand Up @@ -86,9 +86,11 @@ let
# On Darwin, we need the host's xcrun for `xcrun metal` to compile Metal shaders.
# It's not supported by the nixpkgs shim.
xcrunHost = writeScriptBin "xcrunHost" ''
# Use system SDK for Metal files.
# Use system SDK for Metal files. Clear Nix-set variables that
# interfere with xcrun/xcodebuild's SDK and toolchain resolution.
unset DEVELOPER_DIR
/usr/bin/xcrun $@
unset SDKROOT
/usr/bin/xcrun "$@"
'';

metalSupport = buildConfig.metal or false;
Expand Down Expand Up @@ -123,13 +125,47 @@ stdenv.mkDerivation (prevAttrs: {
# instead, we'll use showComponent (which will emit a lot of warnings due
# to the above) to grab the path of the Metal toolchain.
lib.optionalString metalSupport ''
METAL_PATH=$(${xcrunHost}/bin/xcrunHost xcodebuild -showComponent MetalToolchain 2> /dev/null | sed -rn "s/Toolchain Search Path: (.*)/\1/p")
if [ ! -d "$METAL_PATH" ]; then
>&2 echo "Cannot find Metal toolchain, use: xcodebuild -downloadComponent MetalToolchain"
exit 1
# Try the separate Metal toolchain first (macOS 26+ with xcodebuild -downloadComponent).
# Use || true to prevent set -o pipefail from aborting on older macOS where
# -showComponent is unsupported.
METAL_PATH=$(${xcrunHost}/bin/xcrunHost xcodebuild -showComponent MetalToolchain 2> /dev/null | sed -rn "s/Toolchain Search Path: (.*)/\1/p" || true)

if [ -d "$METAL_PATH/Metal.xctoolchain" ]; then
cmakeFlagsArray+=("-DMETAL_TOOLCHAIN=$METAL_PATH/Metal.xctoolchain")
else
# On macOS 14/15, xcrun and xcode-select may not work inside the Nix
# build environment (sandbox restrictions). Try them, then fall back
# to scanning /Applications for Xcode installations.
XCODE_DEV=$(${xcrunHost}/bin/xcrunHost xcode-select -p 2>/dev/null || true)
XCODE_TOOLCHAIN="$XCODE_DEV/Toolchains/XcodeDefault.xctoolchain"

XCRUN_METAL=$(${xcrunHost}/bin/xcrunHost xcrun -find metal 2>/dev/null || true)

if [ -d "$XCODE_TOOLCHAIN/usr/bin" ] && [ -f "$XCODE_TOOLCHAIN/usr/bin/metal" ]; then
cmakeFlagsArray+=("-DMETAL_TOOLCHAIN=$XCODE_TOOLCHAIN")
elif [ -n "$XCRUN_METAL" ] && [ -f "$XCRUN_METAL" ]; then
# Derive toolchain path from xcrun result
METAL_BIN_DIR=$(dirname "$XCRUN_METAL")
METAL_TC_DIR=$(dirname $(dirname "$METAL_BIN_DIR"))
cmakeFlagsArray+=("-DMETAL_TOOLCHAIN=$METAL_TC_DIR")
else
# Last resort: scan /Applications/Xcode*.app for metal compiler
FOUND_TC=""
for xcode_app in /Applications/Xcode*.app; do
TC="$xcode_app/Contents/Developer/Toolchains/XcodeDefault.xctoolchain"
if [ -f "$TC/usr/bin/metal" ]; then
FOUND_TC="$TC"
break
fi
done
if [ -n "$FOUND_TC" ]; then
cmakeFlagsArray+=("-DMETAL_TOOLCHAIN=$FOUND_TC")
else
>&2 echo "Cannot find Metal toolchain. On macOS 26+, use: xcodebuild -downloadComponent metalToolchain"
exit 1
fi
fi
fi

cmakeFlagsArray+=("-DMETAL_TOOLCHAIN=$METAL_PATH/Metal.xctoolchain")
'';

# hipify copies files, but its target is run in the CMake build and install
Expand Down
26 changes: 15 additions & 11 deletions docs/source/builder/metal.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,24 +5,29 @@ Instructions on this page assume that you installed Nix with the

## Targeted macOS versions

Since new macOS versions get [adopted quickly](https://telemetrydeck.com/survey/apple/macOS/versions/),
we only support the latest major macOS version except for the first weeks
after a release, when we also support the previous major version.
Metal kernels are compiled with `-std=metal3.1` (AIR v26), which requires
macOS 15 or later on ARM64 (Apple Silicon).

We currently support macOS 26.0 and later on ARM64 (Apple silicon).
| macOS version | Support level |
|---------------|---------------|
| macOS 26+ | Fully supported and tested in CI |
| macOS 15 | Fully supported and tested in CI |
| macOS 14 | Best-effort (builds work, some tests may fail due to MPS memory limits) |

## Requirements

To build a Metal kernel, the following requirements must be met:

- Xcode 26.x must be available on the build machine.
- `xcode-select -p` must point to the Xcode 26 installation, typically
- An Xcode installation with the Metal compiler must be available. The build
system automatically detects the Metal toolchain from available Xcode
installations.
- On macOS 26+, the Metal Toolchain is a separate download from Xcode:
`xcodebuild -downloadComponent MetalToolchain`
- On macOS 14/15, Metal ships bundled with Xcode (no separate install needed).
- `xcode-select -p` must point to your Xcode installation, typically
`/Applications/Xcode.app/Contents/Developer`. If this is not the case,
you can set the path with:
`sudo xcode-select -s /path/to/Xcode.app/Contents/Developer`
- The Metal Toolchain must be installed. Starting with macOS 26, this is
a separate download from Xcode. You can install it with:
`xcodebuild -downloadComponent MetalToolchain`
- The Nix sandbox should be set to `relaxed`, because the Nix derivation
that builds the kernel must have access to Xcode and the Metal Toolchain.
You can verify this by checking that `/etc/nix/nix.custom.conf` contains
Expand All @@ -47,8 +52,7 @@ Xcode 26.1
Build version 17B55
```

The reported version must be 26.0 or newer. Then you can validate that the
Metal Toolchain is installed with:
On macOS 26+, you can validate that the Metal Toolchain is installed with:

```bash
$ xcodebuild -showComponent metalToolchain
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
#include <ATen/mps/MPSStream.h>
#include <torch/torch.h>

#import <Foundation/Foundation.h>
Expand Down Expand Up @@ -25,7 +26,10 @@ void __KERNEL_NAME_NORMALIZED__(torch::Tensor &out, torch::Tensor const &input)
"Tensors must be on same device");

@autoreleasepool {
id<MTLDevice> device = MTLCreateSystemDefaultDevice();
at::mps::MPSStream *stream = at::mps::getCurrentMPSStream();
TORCH_CHECK(stream, "Failed to get MPS stream");

id<MTLDevice> device = stream->device();
int numThreads = input.numel();

NSError *error = nil;
Expand All @@ -42,9 +46,12 @@ void __KERNEL_NAME_NORMALIZED__(torch::Tensor &out, torch::Tensor const &input)
[device newComputePipelineStateWithFunction:func error:&error];
TORCH_CHECK(pso, error.localizedDescription.UTF8String);

id<MTLCommandBuffer> cmdBuf = torch::mps::get_command_buffer();
dispatch_sync(torch::mps::get_dispatch_queue(), ^() {
id<MTLComputeCommandEncoder> encoder = [cmdBuf computeCommandEncoder];
// Use stream->commandEncoder() to properly integrate with PyTorch's
// MPS encoder lifecycle (kernel coalescing). Creating encoders directly
// via [commandBuffer computeCommandEncoder] bypasses this and crashes
// when the kernel is called twice in sequence.
dispatch_sync(stream->queue(), ^() {
id<MTLComputeCommandEncoder> encoder = stream->commandEncoder();
[encoder setComputePipelineState:pso];
[encoder setBuffer:getMTLBufferStorage(input)
offset:input.storage_offset() * input.element_size()
Expand All @@ -57,8 +64,8 @@ void __KERNEL_NAME_NORMALIZED__(torch::Tensor &out, torch::Tensor const &input)
MIN(pso.maxTotalThreadsPerThreadgroup, (NSUInteger)numThreads);
[encoder dispatchThreads:MTLSizeMake(numThreads, 1, 1)
threadsPerThreadgroup:MTLSizeMake(tgSize, 1, 1)];
[encoder endEncoding];
torch::mps::commit();
});

stream->synchronize(at::mps::SyncType::COMMIT_AND_CONTINUE);
}
}