diff --git a/builder/examples/relu-metal-cpp/relu/metallib_loader.mm b/builder/examples/relu-metal-cpp/relu/metallib_loader.mm
index 9e63d909..b7ad4a2c 100644
--- a/builder/examples/relu-metal-cpp/relu/metallib_loader.mm
+++ b/builder/examples/relu-metal-cpp/relu/metallib_loader.mm
@@ -37,4 +37,23 @@
 void* getMPSCommandQueue() {
   return (__bridge void*)at::mps::getCurrentMPSStream()->commandQueue();
 }
+
+// Get the MPS stream's command encoder (returns id<MTLComputeCommandEncoder> as void*).
+// Uses PyTorch's encoder lifecycle management (kernel coalescing).
+void* getMPSCommandEncoder() {
+  return (__bridge void*)at::mps::getCurrentMPSStream()->commandEncoder();
+}
+
+// Commit the current command buffer and continue with a new one.
+void mpsSynchronize() {
+  at::mps::getCurrentMPSStream()->synchronize(at::mps::SyncType::COMMIT_AND_CONTINUE);
+}
+
+// Dispatch a block on the MPS stream's serial queue.
+void mpsDispatchSync(void (*block)(void* ctx), void* ctx) {
+  at::mps::MPSStream* stream = at::mps::getCurrentMPSStream();
+  dispatch_sync(stream->queue(), ^{
+    block(ctx);
+  });
+}
 }
diff --git a/builder/examples/relu-metal-cpp/relu/relu.cpp b/builder/examples/relu-metal-cpp/relu/relu.cpp
index 85b9ac57..49cdb0f9 100644
--- a/builder/examples/relu-metal-cpp/relu/relu.cpp
+++ b/builder/examples/relu-metal-cpp/relu/relu.cpp
@@ -11,6 +11,9 @@
 extern "C" void* loadEmbeddedMetalLibrary(void* device, const char** errorMsg);
 extern "C" void* getMPSDevice();
 extern "C" void* getMPSCommandQueue();
+extern "C" void* getMPSCommandEncoder();
+extern "C" void mpsSynchronize();
+extern "C" void mpsDispatchSync(void (*block)(void* ctx), void* ctx);
 
 namespace {
 
@@ -38,14 +41,42 @@ MTL::Library* loadLibrary(MTL::Device* device) {
 }
 
 } // namespace
 
+// Context passed through mpsDispatchSync to the dispatch block
+struct ReluDispatchContext {
+  MTL::ComputePipelineState* pipelineState;
+  MTL::Buffer* inputBuffer;
+  MTL::Buffer* outputBuffer;
+  NS::UInteger inputOffset;
+  NS::UInteger outputOffset;
+  NS::UInteger totalThreads;
+};
+
+static void reluDispatchBlock(void* ctx) {
+  auto* c = reinterpret_cast<ReluDispatchContext*>(ctx);
+
+  // Use PyTorch's MPS stream encoder (kernel coalescing)
+  MTL::ComputeCommandEncoder* encoder =
+      reinterpret_cast<MTL::ComputeCommandEncoder*>(getMPSCommandEncoder());
+  TORCH_CHECK(encoder != nullptr, "Failed to get MPS compute encoder");
+
+  encoder->setComputePipelineState(c->pipelineState);
+  encoder->setBuffer(c->inputBuffer, c->inputOffset, 0);
+  encoder->setBuffer(c->outputBuffer, c->outputOffset, 1);
+
+  NS::UInteger threadGroupSize = c->pipelineState->maxTotalThreadsPerThreadgroup();
+  if (threadGroupSize > c->totalThreads) {
+    threadGroupSize = c->totalThreads;
+  }
+
+  encoder->dispatchThreads(
+      MTL::Size::Make(c->totalThreads, 1, 1),
+      MTL::Size::Make(threadGroupSize, 1, 1));
+}
+
 void dispatchReluKernel(const torch::Tensor& input, torch::Tensor& output) {
-  // Use PyTorch's MPS device and command queue (these are borrowed references, not owned)
   MTL::Device* device = reinterpret_cast<MTL::Device*>(getMPSDevice());
   TORCH_CHECK(device != nullptr, "Failed to get MPS device");
 
-  MTL::CommandQueue* commandQueue = reinterpret_cast<MTL::CommandQueue*>(getMPSCommandQueue());
-  TORCH_CHECK(commandQueue != nullptr, "Failed to get MPS command queue");
-
   MTL::Library* libraryPtr = reinterpret_cast<MTL::Library*>(loadLibrary(device));
   NS::SharedPtr<MTL::Library> library = NS::TransferPtr(libraryPtr);
@@ -64,36 +95,21 @@ void dispatchReluKernel(const torch::Tensor& input, torch::Tensor& output) {
       "Failed to create compute pipeline state: ",
       pipelineError ? pipelineError->localizedDescription()->utf8String() : "Unknown error");
 
-  // Don't use SharedPtr for command buffer/encoder - they're managed by PyTorch's command queue
-  MTL::CommandBuffer* commandBuffer = commandQueue->commandBuffer();
-  TORCH_CHECK(commandBuffer != nullptr, "Failed to create Metal command buffer");
-
-  MTL::ComputeCommandEncoder* encoder = commandBuffer->computeCommandEncoder();
-  TORCH_CHECK(encoder != nullptr, "Failed to create compute command encoder");
-
-  encoder->setComputePipelineState(pipelineState.get());
-
   auto* inputBuffer = getMTLBuffer(input);
   auto* outputBuffer = getMTLBuffer(output);
   TORCH_CHECK(inputBuffer != nullptr, "Input buffer is null");
   TORCH_CHECK(outputBuffer != nullptr, "Output buffer is null");
 
-  encoder->setBuffer(inputBuffer, input.storage_offset() * input.element_size(), 0);
-  encoder->setBuffer(outputBuffer, output.storage_offset() * output.element_size(), 1);
-
-  const NS::UInteger totalThreads = input.numel();
-  NS::UInteger threadGroupSize = pipelineState->maxTotalThreadsPerThreadgroup();
-  if (threadGroupSize > totalThreads) {
-    threadGroupSize = totalThreads;
-  }
-
-  const MTL::Size gridSize = MTL::Size::Make(totalThreads, 1, 1);
-  const MTL::Size threadsPerThreadgroup = MTL::Size::Make(threadGroupSize, 1, 1);
-
-  encoder->dispatchThreads(gridSize, threadsPerThreadgroup);
-  encoder->endEncoding();
+  ReluDispatchContext ctx{
+      pipelineState.get(),
+      inputBuffer,
+      outputBuffer,
+      static_cast<NS::UInteger>(input.storage_offset() * input.element_size()),
+      static_cast<NS::UInteger>(output.storage_offset() * output.element_size()),
+      static_cast<NS::UInteger>(input.numel())};
 
-  commandBuffer->commit();
+  mpsDispatchSync(reluDispatchBlock, &ctx);
+  mpsSynchronize();
 }
 
 void relu(torch::Tensor& out, const torch::Tensor& input) {