Skip to content

Commit b727140

Browse files
committed
Fix MPS encoder lifecycle in relu-metal-cpp example
Add C bridge functions (getMPSCommandEncoder, mpsSynchronize, mpsDispatchSync) to metallib_loader.mm so the C++ metal-cpp example can properly integrate with PyTorch's MPS stream encoder lifecycle without needing Objective-C++ code in the main kernel file.

Co-developed-by: Claude Code v2.1.50 (claude-opus-4-6)
1 parent 3ae5790 commit b727140

File tree

2 files changed

+63
-28
lines changed

2 files changed

+63
-28
lines changed

builder/examples/relu-metal-cpp/relu/metallib_loader.mm

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,4 +37,23 @@
3737
// Return PyTorch's current MPS command queue (id<MTLCommandQueue> as void*).
// __bridge cast: no ownership transfer — the caller must not release it.
void* getMPSCommandQueue() {
  at::mps::MPSStream* stream = at::mps::getCurrentMPSStream();
  return (__bridge void*)stream->commandQueue();
}
40+
41+
// Return the MPS stream's compute command encoder
// (id<MTLComputeCommandEncoder> as void*). Going through the stream's
// encoder hooks into PyTorch's encoder lifecycle management (kernel
// coalescing). __bridge cast: no ownership transfer.
void* getMPSCommandEncoder() {
  at::mps::MPSStream* stream = at::mps::getCurrentMPSStream();
  return (__bridge void*)stream->commandEncoder();
}
46+
47+
// Commit the stream's current command buffer and continue with a new one
// (SyncType::COMMIT_AND_CONTINUE).
void mpsSynchronize() {
  at::mps::MPSStream* stream = at::mps::getCurrentMPSStream();
  stream->synchronize(at::mps::SyncType::COMMIT_AND_CONTINUE);
}
51+
52+
// Run a C callback synchronously on the MPS stream's serial dispatch queue;
// ctx is forwarded to the callback unchanged.
void mpsDispatchSync(void (*block)(void* ctx), void* ctx) {
  dispatch_queue_t serialQueue = at::mps::getCurrentMPSStream()->queue();
  dispatch_sync(serialQueue, ^{ block(ctx); });
}
4059
}

builder/examples/relu-metal-cpp/relu/relu.cpp

Lines changed: 44 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,9 @@
1111
extern "C" void* loadEmbeddedMetalLibrary(void* device, const char** errorMsg);
1212
extern "C" void* getMPSDevice();
1313
extern "C" void* getMPSCommandQueue();
14+
extern "C" void* getMPSCommandEncoder();
15+
extern "C" void mpsSynchronize();
16+
extern "C" void mpsDispatchSync(void (*block)(void* ctx), void* ctx);
1417

1518
namespace {
1619

@@ -38,14 +41,42 @@ MTL::Library* loadLibrary(MTL::Device* device) {
3841

3942
} // namespace
4043

44+
// Context passed through mpsDispatchSync to the dispatch block
45+
struct ReluDispatchContext {
46+
MTL::ComputePipelineState* pipelineState;
47+
MTL::Buffer* inputBuffer;
48+
MTL::Buffer* outputBuffer;
49+
NS::UInteger inputOffset;
50+
NS::UInteger outputOffset;
51+
NS::UInteger totalThreads;
52+
};
53+
54+
static void reluDispatchBlock(void* ctx) {
55+
auto* c = reinterpret_cast<ReluDispatchContext*>(ctx);
56+
57+
// Use PyTorch's MPS stream encoder (kernel coalescing)
58+
MTL::ComputeCommandEncoder* encoder =
59+
reinterpret_cast<MTL::ComputeCommandEncoder*>(getMPSCommandEncoder());
60+
TORCH_CHECK(encoder != nullptr, "Failed to get MPS compute encoder");
61+
62+
encoder->setComputePipelineState(c->pipelineState);
63+
encoder->setBuffer(c->inputBuffer, c->inputOffset, 0);
64+
encoder->setBuffer(c->outputBuffer, c->outputOffset, 1);
65+
66+
NS::UInteger threadGroupSize = c->pipelineState->maxTotalThreadsPerThreadgroup();
67+
if (threadGroupSize > c->totalThreads) {
68+
threadGroupSize = c->totalThreads;
69+
}
70+
71+
encoder->dispatchThreads(
72+
MTL::Size::Make(c->totalThreads, 1, 1),
73+
MTL::Size::Make(threadGroupSize, 1, 1));
74+
}
75+
4176
void dispatchReluKernel(const torch::Tensor& input, torch::Tensor& output) {
42-
// Use PyTorch's MPS device and command queue (these are borrowed references, not owned)
4377
MTL::Device* device = reinterpret_cast<MTL::Device*>(getMPSDevice());
4478
TORCH_CHECK(device != nullptr, "Failed to get MPS device");
4579

46-
MTL::CommandQueue* commandQueue = reinterpret_cast<MTL::CommandQueue*>(getMPSCommandQueue());
47-
TORCH_CHECK(commandQueue != nullptr, "Failed to get MPS command queue");
48-
4980
MTL::Library* libraryPtr = reinterpret_cast<MTL::Library*>(loadLibrary(device));
5081
NS::SharedPtr<MTL::Library> library = NS::TransferPtr(libraryPtr);
5182

@@ -64,36 +95,21 @@ void dispatchReluKernel(const torch::Tensor& input, torch::Tensor& output) {
6495
"Failed to create compute pipeline state: ",
6596
pipelineError ? pipelineError->localizedDescription()->utf8String() : "Unknown error");
6697

67-
// Don't use SharedPtr for command buffer/encoder - they're managed by PyTorch's command queue
68-
MTL::CommandBuffer* commandBuffer = commandQueue->commandBuffer();
69-
TORCH_CHECK(commandBuffer != nullptr, "Failed to create Metal command buffer");
70-
71-
MTL::ComputeCommandEncoder* encoder = commandBuffer->computeCommandEncoder();
72-
TORCH_CHECK(encoder != nullptr, "Failed to create compute command encoder");
73-
74-
encoder->setComputePipelineState(pipelineState.get());
75-
7698
auto* inputBuffer = getMTLBuffer(input);
7799
auto* outputBuffer = getMTLBuffer(output);
78100
TORCH_CHECK(inputBuffer != nullptr, "Input buffer is null");
79101
TORCH_CHECK(outputBuffer != nullptr, "Output buffer is null");
80102

81-
encoder->setBuffer(inputBuffer, input.storage_offset() * input.element_size(), 0);
82-
encoder->setBuffer(outputBuffer, output.storage_offset() * output.element_size(), 1);
83-
84-
const NS::UInteger totalThreads = input.numel();
85-
NS::UInteger threadGroupSize = pipelineState->maxTotalThreadsPerThreadgroup();
86-
if (threadGroupSize > totalThreads) {
87-
threadGroupSize = totalThreads;
88-
}
89-
90-
const MTL::Size gridSize = MTL::Size::Make(totalThreads, 1, 1);
91-
const MTL::Size threadsPerThreadgroup = MTL::Size::Make(threadGroupSize, 1, 1);
92-
93-
encoder->dispatchThreads(gridSize, threadsPerThreadgroup);
94-
encoder->endEncoding();
103+
ReluDispatchContext ctx{
104+
pipelineState.get(),
105+
inputBuffer,
106+
outputBuffer,
107+
static_cast<NS::UInteger>(input.storage_offset() * input.element_size()),
108+
static_cast<NS::UInteger>(output.storage_offset() * output.element_size()),
109+
static_cast<NS::UInteger>(input.numel())};
95110

96-
commandBuffer->commit();
111+
mpsDispatchSync(reluDispatchBlock, &ctx);
112+
mpsSynchronize();
97113
}
98114

99115
void relu(torch::Tensor& out, const torch::Tensor& input) {

0 commit comments

Comments
 (0)