Skip to content

Commit b92aa5e

Browse files
committed
Fix Metal MPS encoder lifecycle and broaden macOS compatibility
Use stream->commandEncoder() instead of creating encoders directly via [cmdBuf computeCommandEncoder] to properly integrate with PyTorch's MPS stream encoder lifecycle management (kernel coalescing). Direct encoder creation bypasses the stream's internal _commandEncoder state and crashes on sequential kernel dispatches.

Lower the default Metal standard from metal3.2 (macOS 15+) to metal3.1 (macOS 14+) since all current kernel features (bfloat16_t, simd_sum, simd_shuffle, threadgroup_barrier) are available in Metal 3.1.

Add multi-strategy Metal toolchain detection for macOS 14+:

- Separate Metal toolchain component (macOS 26+ cryptex mount)
- xcrun/xcode-select based detection
- Direct /Applications/Xcode*.app filesystem scan fallback

Also clear SDKROOT in xcrunHost to prevent Nix-set SDK paths from interfering with system xcrun.

Fixes: #307

Co-developed-by: Claude Code v2.1.50 (claude-opus-4-6)
1 parent 20d9f69 commit b92aa5e

File tree

5 files changed

+103
-49
lines changed

5 files changed

+103
-49
lines changed

build2cmake/src/templates/metal/compile-metal.cmake

Lines changed: 22 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,37 @@
11
# Metal shader compilation function
22
function(compile_metal_shaders TARGET_NAME METAL_SOURCES EXTRA_INCLUDE_DIRS)
33
if(NOT DEFINED METAL_TOOLCHAIN)
4+
# Try the separate Metal toolchain first (macOS 26+ with downloadable component)
45
execute_process(
56
COMMAND "xcodebuild" "-showComponent" "MetalToolchain"
67
OUTPUT_VARIABLE FIND_METAL_OUT
78
RESULT_VARIABLE FIND_METAL_ERROR_CODE
8-
ERROR_VARIABLE FIND_METAL_STDERR
99
OUTPUT_STRIP_TRAILING_WHITESPACE)
1010

11-
if(NOT FIND_METAL_ERROR_CODE EQUAL 0)
12-
message(FATAL_ERROR "${ERR_MSG}: ${FIND_METAL_STDERR}")
11+
if(FIND_METAL_ERROR_CODE EQUAL 0)
12+
string(REGEX MATCH "Toolchain Search Path: ([^\n]+)" MATCH_RESULT "${FIND_METAL_OUT}")
13+
set(METAL_TOOLCHAIN "${CMAKE_MATCH_1}/Metal.xctoolchain")
14+
else()
15+
# Fall back to the default Xcode toolchain (macOS 14/15 bundle metal in Xcode)
16+
execute_process(
17+
COMMAND "xcode-select" "-p"
18+
OUTPUT_VARIABLE XCODE_DEV_DIR
19+
RESULT_VARIABLE XCODE_SELECT_ERROR
20+
OUTPUT_STRIP_TRAILING_WHITESPACE)
21+
22+
if(XCODE_SELECT_ERROR EQUAL 0)
23+
set(METAL_TOOLCHAIN "${XCODE_DEV_DIR}/Toolchains/XcodeDefault.xctoolchain")
24+
else()
25+
message(FATAL_ERROR "Cannot find Metal toolchain. On macOS 26+, use: xcodebuild -downloadComponent metalToolchain")
26+
endif()
1327
endif()
14-
15-
# Extract the Toolchain Search Path value and append Metal.xctoolchain
16-
string(REGEX MATCH "Toolchain Search Path: ([^\n]+)" MATCH_RESULT "${FIND_METAL_OUT}")
17-
set(METAL_TOOLCHAIN "${CMAKE_MATCH_1}/Metal.xctoolchain")
1828
endif()
1929

20-
# Set Metal compiler flags
21-
set(METAL_FLAGS "-std=metal4.0" "-O2")
30+
# Set Metal compiler flags.
31+
# metal3.1 → air64_v26, macOS 14+
32+
# metal3.2 → air64_v27, macOS 15+
33+
# metal4.0 → air64_v28, macOS 26+
34+
set(METAL_FLAGS "-std=metal3.1" "-O2")
2235

2336
# Output directory for compiled metallib
2437
set(METALLIB_OUTPUT_DIR "${CMAKE_BINARY_DIR}/metallib")

builder/examples/extra-data/relu_metal/relu.mm

Lines changed: 12 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
#include <ATen/mps/MPSStream.h>
12
#include <torch/torch.h>
23

34
#import <Foundation/Foundation.h>
@@ -18,8 +19,10 @@
1819
torch::Tensor &dispatchReluKernel(torch::Tensor const &input,
1920
torch::Tensor &output) {
2021
@autoreleasepool {
21-
id<MTLDevice> device = MTLCreateSystemDefaultDevice();
22+
at::mps::MPSStream *stream = at::mps::getCurrentMPSStream();
23+
TORCH_CHECK(stream, "Failed to get MPS stream");
2224

25+
id<MTLDevice> device = stream->device();
2326
int numThreads = input.numel();
2427

2528
// Load the embedded Metal library from memory
@@ -44,14 +47,12 @@
4447
error:&error];
4548
TORCH_CHECK(reluPSO, error.localizedDescription.UTF8String);
4649

47-
id<MTLCommandBuffer> commandBuffer = torch::mps::get_command_buffer();
48-
TORCH_CHECK(commandBuffer, "Failed to retrieve command buffer reference");
49-
50-
dispatch_queue_t serialQueue = torch::mps::get_dispatch_queue();
51-
52-
dispatch_sync(serialQueue, ^() {
53-
id<MTLComputeCommandEncoder> computeEncoder =
54-
[commandBuffer computeCommandEncoder];
50+
// Use stream->commandEncoder() to properly integrate with PyTorch's
51+
// MPS encoder lifecycle (kernel coalescing). Creating encoders directly
52+
// via [commandBuffer computeCommandEncoder] bypasses this and crashes
53+
// when the kernel is called twice in sequence.
54+
dispatch_sync(stream->queue(), ^() {
55+
id<MTLComputeCommandEncoder> computeEncoder = stream->commandEncoder();
5556
TORCH_CHECK(computeEncoder, "Failed to create compute command encoder");
5657

5758
[computeEncoder setComputePipelineState:reluPSO];
@@ -72,11 +73,9 @@
7273

7374
[computeEncoder dispatchThreads:gridSize
7475
threadsPerThreadgroup:threadgroupSize];
75-
76-
[computeEncoder endEncoding];
77-
78-
torch::mps::commit();
7976
});
77+
78+
stream->synchronize(at::mps::SyncType::COMMIT_AND_CONTINUE);
8079
}
8180

8281
return output;

builder/examples/relu/relu_metal/relu.mm

Lines changed: 12 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
#include <ATen/mps/MPSStream.h>
12
#include <torch/torch.h>
23

34
#import <Foundation/Foundation.h>
@@ -18,8 +19,10 @@
1819
torch::Tensor &dispatchReluKernel(torch::Tensor const &input,
1920
torch::Tensor &output) {
2021
@autoreleasepool {
21-
id<MTLDevice> device = MTLCreateSystemDefaultDevice();
22+
at::mps::MPSStream *stream = at::mps::getCurrentMPSStream();
23+
TORCH_CHECK(stream, "Failed to get MPS stream");
2224

25+
id<MTLDevice> device = stream->device();
2326
int numThreads = input.numel();
2427

2528
// Load the embedded Metal library from memory
@@ -44,14 +47,12 @@
4447
error:&error];
4548
TORCH_CHECK(reluPSO, error.localizedDescription.UTF8String);
4649

47-
id<MTLCommandBuffer> commandBuffer = torch::mps::get_command_buffer();
48-
TORCH_CHECK(commandBuffer, "Failed to retrieve command buffer reference");
49-
50-
dispatch_queue_t serialQueue = torch::mps::get_dispatch_queue();
51-
52-
dispatch_sync(serialQueue, ^() {
53-
id<MTLComputeCommandEncoder> computeEncoder =
54-
[commandBuffer computeCommandEncoder];
50+
// Use stream->commandEncoder() to properly integrate with PyTorch's
51+
// MPS encoder lifecycle (kernel coalescing). Creating encoders directly
52+
// via [commandBuffer computeCommandEncoder] bypasses this and crashes
53+
// when the kernel is called twice in sequence.
54+
dispatch_sync(stream->queue(), ^() {
55+
id<MTLComputeCommandEncoder> computeEncoder = stream->commandEncoder();
5556
TORCH_CHECK(computeEncoder, "Failed to create compute command encoder");
5657

5758
[computeEncoder setComputePipelineState:reluPSO];
@@ -72,11 +73,9 @@
7273

7374
[computeEncoder dispatchThreads:gridSize
7475
threadsPerThreadgroup:threadgroupSize];
75-
76-
[computeEncoder endEncoding];
77-
78-
torch::mps::commit();
7976
});
77+
78+
stream->synchronize(at::mps::SyncType::COMMIT_AND_CONTINUE);
8079
}
8180

8281
return output;

builder/lib/torch-extension/arch.nix

Lines changed: 44 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -86,9 +86,11 @@ let
8686
# On Darwin, we need the host's xcrun for `xcrun metal` to compile Metal shaders.
8787
# It's not supported by the nixpkgs shim.
8888
xcrunHost = writeScriptBin "xcrunHost" ''
89-
# Use system SDK for Metal files.
89+
# Use system SDK for Metal files. Clear Nix-set variables that
90+
# interfere with xcrun/xcodebuild's SDK and toolchain resolution.
9091
unset DEVELOPER_DIR
91-
/usr/bin/xcrun $@
92+
unset SDKROOT
93+
/usr/bin/xcrun "$@"
9294
'';
9395

9496
metalSupport = buildConfig.metal or false;
@@ -123,13 +125,47 @@ stdenv.mkDerivation (prevAttrs: {
123125
# instead, we'll use showComponent (which will emit a lot of warnings due
124126
# to the above) to grab the path of the Metal toolchain.
125127
lib.optionalString metalSupport ''
126-
METAL_PATH=$(${xcrunHost}/bin/xcrunHost xcodebuild -showComponent MetalToolchain 2> /dev/null | sed -rn "s/Toolchain Search Path: (.*)/\1/p")
127-
if [ ! -d "$METAL_PATH" ]; then
128-
>&2 echo "Cannot find Metal toolchain, use: xcodebuild -downloadComponent MetalToolchain"
129-
exit 1
128+
# Try the separate Metal toolchain first (macOS 26+ with xcodebuild -downloadComponent).
129+
# Use || true to prevent set -o pipefail from aborting on older macOS where
130+
# -showComponent is unsupported.
131+
METAL_PATH=$(${xcrunHost}/bin/xcrunHost xcodebuild -showComponent MetalToolchain 2> /dev/null | sed -rn "s/Toolchain Search Path: (.*)/\1/p" || true)
132+
133+
if [ -d "$METAL_PATH/Metal.xctoolchain" ]; then
134+
cmakeFlagsArray+=("-DMETAL_TOOLCHAIN=$METAL_PATH/Metal.xctoolchain")
135+
else
136+
# On macOS 14/15, xcrun and xcode-select may not work inside the Nix
137+
# build environment (sandbox restrictions). Try them, then fall back
138+
# to scanning /Applications for Xcode installations.
139+
XCODE_DEV=$(${xcrunHost}/bin/xcrunHost xcode-select -p 2>/dev/null || true)
140+
XCODE_TOOLCHAIN="$XCODE_DEV/Toolchains/XcodeDefault.xctoolchain"
141+
142+
XCRUN_METAL=$(${xcrunHost}/bin/xcrunHost xcrun -find metal 2>/dev/null || true)
143+
144+
if [ -d "$XCODE_TOOLCHAIN/usr/bin" ] && [ -f "$XCODE_TOOLCHAIN/usr/bin/metal" ]; then
145+
cmakeFlagsArray+=("-DMETAL_TOOLCHAIN=$XCODE_TOOLCHAIN")
146+
elif [ -n "$XCRUN_METAL" ] && [ -f "$XCRUN_METAL" ]; then
147+
# Derive toolchain path from xcrun result
148+
METAL_BIN_DIR=$(dirname "$XCRUN_METAL")
149+
METAL_TC_DIR=$(dirname $(dirname "$METAL_BIN_DIR"))
150+
cmakeFlagsArray+=("-DMETAL_TOOLCHAIN=$METAL_TC_DIR")
151+
else
152+
# Last resort: scan /Applications/Xcode*.app for metal compiler
153+
FOUND_TC=""
154+
for xcode_app in /Applications/Xcode*.app; do
155+
TC="$xcode_app/Contents/Developer/Toolchains/XcodeDefault.xctoolchain"
156+
if [ -f "$TC/usr/bin/metal" ]; then
157+
FOUND_TC="$TC"
158+
break
159+
fi
160+
done
161+
if [ -n "$FOUND_TC" ]; then
162+
cmakeFlagsArray+=("-DMETAL_TOOLCHAIN=$FOUND_TC")
163+
else
164+
>&2 echo "Cannot find Metal toolchain. On macOS 26+, use: xcodebuild -downloadComponent metalToolchain"
165+
exit 1
166+
fi
167+
fi
130168
fi
131-
132-
cmakeFlagsArray+=("-DMETAL_TOOLCHAIN=$METAL_PATH/Metal.xctoolchain")
133169
'';
134170

135171
# hipify copies files, but its target is run in the CMake build and install

template/__KERNEL_NAME_NORMALIZED___metal/__KERNEL_NAME_NORMALIZED__.mm

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
#include <ATen/mps/MPSStream.h>
12
#include <torch/torch.h>
23

34
#import <Foundation/Foundation.h>
@@ -25,7 +26,10 @@ void __KERNEL_NAME_NORMALIZED__(torch::Tensor &out, torch::Tensor const &input)
2526
"Tensors must be on same device");
2627

2728
@autoreleasepool {
28-
id<MTLDevice> device = MTLCreateSystemDefaultDevice();
29+
at::mps::MPSStream *stream = at::mps::getCurrentMPSStream();
30+
TORCH_CHECK(stream, "Failed to get MPS stream");
31+
32+
id<MTLDevice> device = stream->device();
2933
int numThreads = input.numel();
3034

3135
NSError *error = nil;
@@ -42,9 +46,12 @@ void __KERNEL_NAME_NORMALIZED__(torch::Tensor &out, torch::Tensor const &input)
4246
[device newComputePipelineStateWithFunction:func error:&error];
4347
TORCH_CHECK(pso, error.localizedDescription.UTF8String);
4448

45-
id<MTLCommandBuffer> cmdBuf = torch::mps::get_command_buffer();
46-
dispatch_sync(torch::mps::get_dispatch_queue(), ^() {
47-
id<MTLComputeCommandEncoder> encoder = [cmdBuf computeCommandEncoder];
49+
// Use stream->commandEncoder() to properly integrate with PyTorch's
50+
// MPS encoder lifecycle (kernel coalescing). Creating encoders directly
51+
// via [commandBuffer computeCommandEncoder] bypasses this and crashes
52+
// when the kernel is called twice in sequence.
53+
dispatch_sync(stream->queue(), ^() {
54+
id<MTLComputeCommandEncoder> encoder = stream->commandEncoder();
4855
[encoder setComputePipelineState:pso];
4956
[encoder setBuffer:getMTLBufferStorage(input)
5057
offset:input.storage_offset() * input.element_size()
@@ -57,8 +64,8 @@ void __KERNEL_NAME_NORMALIZED__(torch::Tensor &out, torch::Tensor const &input)
5764
MIN(pso.maxTotalThreadsPerThreadgroup, (NSUInteger)numThreads);
5865
[encoder dispatchThreads:MTLSizeMake(numThreads, 1, 1)
5966
threadsPerThreadgroup:MTLSizeMake(tgSize, 1, 1)];
60-
[encoder endEncoding];
61-
torch::mps::commit();
6267
});
68+
69+
stream->synchronize(at::mps::SyncType::COMMIT_AND_CONTINUE);
6370
}
6471
}

0 commit comments

Comments (0)