// C-linkage bridge functions (presumably implemented in a companion
// Objective-C++ translation unit — confirm) that expose PyTorch's MPS
// device/stream internals to this metal-cpp code.
// NOTE: the linkage string must be exactly "C"; `extern " C"` is ill-formed.
extern "C" void* loadEmbeddedMetalLibrary(void* device, const char** errorMsg);
extern "C" void* getMPSDevice();
extern "C" void* getMPSCommandQueue();
extern "C" void* getMPSCommandEncoder();
extern "C" void mpsSynchronize();
// Runs `block(ctx)` synchronously on the MPS stream's dispatch queue.
extern "C" void mpsDispatchSync(void (*block)(void* ctx), void* ctx);
1417
1518namespace {
1619
@@ -38,14 +41,42 @@ MTL::Library* loadLibrary(MTL::Device* device) {
3841
3942} // namespace
4043
44+ // Context passed through mpsDispatchSync to the dispatch block
45+ struct ReluDispatchContext {
46+ MTL::ComputePipelineState* pipelineState;
47+ MTL::Buffer* inputBuffer;
48+ MTL::Buffer* outputBuffer;
49+ NS::UInteger inputOffset;
50+ NS::UInteger outputOffset;
51+ NS::UInteger totalThreads;
52+ };
53+
54+ static void reluDispatchBlock (void * ctx) {
55+ auto * c = reinterpret_cast <ReluDispatchContext*>(ctx);
56+
57+ // Use PyTorch's MPS stream encoder (kernel coalescing)
58+ MTL::ComputeCommandEncoder* encoder =
59+ reinterpret_cast <MTL::ComputeCommandEncoder*>(getMPSCommandEncoder ());
60+ TORCH_CHECK (encoder != nullptr , " Failed to get MPS compute encoder" );
61+
62+ encoder->setComputePipelineState (c->pipelineState );
63+ encoder->setBuffer (c->inputBuffer , c->inputOffset , 0 );
64+ encoder->setBuffer (c->outputBuffer , c->outputOffset , 1 );
65+
66+ NS::UInteger threadGroupSize = c->pipelineState ->maxTotalThreadsPerThreadgroup ();
67+ if (threadGroupSize > c->totalThreads ) {
68+ threadGroupSize = c->totalThreads ;
69+ }
70+
71+ encoder->dispatchThreads (
72+ MTL::Size::Make (c->totalThreads , 1 , 1 ),
73+ MTL::Size::Make (threadGroupSize, 1 , 1 ));
74+ }
75+
4176void dispatchReluKernel (const torch::Tensor& input, torch::Tensor& output) {
42- // Use PyTorch's MPS device and command queue (these are borrowed references, not owned)
4377 MTL::Device* device = reinterpret_cast <MTL::Device*>(getMPSDevice ());
4478 TORCH_CHECK (device != nullptr , " Failed to get MPS device" );
4579
46- MTL::CommandQueue* commandQueue = reinterpret_cast <MTL::CommandQueue*>(getMPSCommandQueue ());
47- TORCH_CHECK (commandQueue != nullptr , " Failed to get MPS command queue" );
48-
4980 MTL::Library* libraryPtr = reinterpret_cast <MTL::Library*>(loadLibrary (device));
5081 NS::SharedPtr<MTL::Library> library = NS::TransferPtr (libraryPtr);
5182
@@ -64,36 +95,21 @@ void dispatchReluKernel(const torch::Tensor& input, torch::Tensor& output) {
6495 "Failed to create compute pipeline state: ",
6596 pipelineError ? pipelineError->localizedDescription()->utf8String() : "Unknown error");
6697
67- // Don't use SharedPtr for command buffer/encoder - they're managed by PyTorch's command queue
68- MTL::CommandBuffer* commandBuffer = commandQueue->commandBuffer ();
69- TORCH_CHECK (commandBuffer != nullptr , " Failed to create Metal command buffer" );
70-
71- MTL::ComputeCommandEncoder* encoder = commandBuffer->computeCommandEncoder ();
72- TORCH_CHECK (encoder != nullptr , " Failed to create compute command encoder" );
73-
74- encoder->setComputePipelineState (pipelineState.get());
75-
7698 auto * inputBuffer = getMTLBuffer(input);
7799 auto * outputBuffer = getMTLBuffer(output);
78100 TORCH_CHECK (inputBuffer != nullptr , " Input buffer is null" );
79101 TORCH_CHECK (outputBuffer != nullptr , " Output buffer is null" );
80102
81- encoder->setBuffer (inputBuffer, input.storage_offset () * input.element_size (), 0 );
82- encoder->setBuffer (outputBuffer, output.storage_offset () * output.element_size (), 1 );
83-
84- const NS::UInteger totalThreads = input.numel ();
85- NS::UInteger threadGroupSize = pipelineState->maxTotalThreadsPerThreadgroup ();
86- if (threadGroupSize > totalThreads) {
87- threadGroupSize = totalThreads;
88- }
89-
90- const MTL::Size gridSize = MTL::Size::Make (totalThreads, 1 , 1 );
91- const MTL::Size threadsPerThreadgroup = MTL::Size::Make (threadGroupSize, 1 , 1 );
92-
93- encoder->dispatchThreads (gridSize, threadsPerThreadgroup);
94- encoder->endEncoding ();
103+ ReluDispatchContext ctx{
104+ pipelineState.get (),
105+ inputBuffer,
106+ outputBuffer,
107+ static_cast <NS::UInteger>(input.storage_offset () * input.element_size ()),
108+ static_cast <NS::UInteger>(output.storage_offset () * output.element_size ()),
109+ static_cast <NS::UInteger>(input.numel ())};
95110
96- commandBuffer->commit ();
111+ mpsDispatchSync (reluDispatchBlock, &ctx);
112+ mpsSynchronize ();
97113}
98114
99115void relu (torch::Tensor& out, const torch::Tensor& input) {
0 commit comments