Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions builder/examples/relu-metal-cpp/relu/metallib_loader.mm
Original file line number Diff line number Diff line change
Expand Up @@ -37,4 +37,23 @@
// Fetch the Metal command queue backing PyTorch's current MPS stream.
// Returned as an opaque pointer (bridged id<MTLCommandQueue>); the queue
// is owned by PyTorch, so callers must treat it as a borrowed reference.
void* getMPSCommandQueue() {
  at::mps::MPSStream* stream = at::mps::getCurrentMPSStream();
  return (__bridge void*)stream->commandQueue();
}

// Expose the active compute command encoder of PyTorch's current MPS
// stream as an opaque pointer (bridged id<MTLComputeCommandEncoder>).
// Going through the stream keeps PyTorch's encoder lifecycle management
// (kernel coalescing) intact.
void* getMPSCommandEncoder() {
  at::mps::MPSStream* stream = at::mps::getCurrentMPSStream();
  return (__bridge void*)stream->commandEncoder();
}

// Commit the in-flight command buffer on PyTorch's current MPS stream and
// let the stream continue with a fresh one (COMMIT_AND_CONTINUE).
// NOTE(review): this commits without waiting for GPU completion — confirm
// that callers do not rely on a full blocking sync here.
void mpsSynchronize() {
  at::mps::MPSStream* stream = at::mps::getCurrentMPSStream();
  stream->synchronize(at::mps::SyncType::COMMIT_AND_CONTINUE);
}

// Run `block(ctx)` synchronously on the serial dispatch queue of PyTorch's
// current MPS stream, so custom kernel encoding is serialized with
// PyTorch's own GPU work. A plain C function pointer plus context is used
// so C++ callers need no Objective-C block syntax.
void mpsDispatchSync(void (*block)(void* ctx), void* ctx) {
  dispatch_sync(at::mps::getCurrentMPSStream()->queue(), ^{
    block(ctx);
  });
}
}
72 changes: 44 additions & 28 deletions builder/examples/relu-metal-cpp/relu/relu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@
extern "C" void* loadEmbeddedMetalLibrary(void* device, const char** errorMsg);
extern "C" void* getMPSDevice();
extern "C" void* getMPSCommandQueue();
extern "C" void* getMPSCommandEncoder();
extern "C" void mpsSynchronize();
extern "C" void mpsDispatchSync(void (*block)(void* ctx), void* ctx);

namespace {

Expand Down Expand Up @@ -38,14 +41,42 @@ MTL::Library* loadLibrary(MTL::Device* device) {

} // namespace

// Context passed through mpsDispatchSync to the dispatch block.
// All pointers are borrowed; the caller keeps them alive for the duration
// of the synchronous dispatch.
struct ReluDispatchContext {
MTL::ComputePipelineState* pipelineState;  // compiled relu compute pipeline
MTL::Buffer* inputBuffer;   // Metal buffer backing the input tensor storage
MTL::Buffer* outputBuffer;  // Metal buffer backing the output tensor storage
NS::UInteger inputOffset;   // byte offset into inputBuffer (storage_offset * element_size)
NS::UInteger outputOffset;  // byte offset into outputBuffer (storage_offset * element_size)
NS::UInteger totalThreads;  // one GPU thread per element (tensor numel)
};

static void reluDispatchBlock(void* ctx) {
auto* c = reinterpret_cast<ReluDispatchContext*>(ctx);

// Use PyTorch's MPS stream encoder (kernel coalescing)
MTL::ComputeCommandEncoder* encoder =
reinterpret_cast<MTL::ComputeCommandEncoder*>(getMPSCommandEncoder());
TORCH_CHECK(encoder != nullptr, "Failed to get MPS compute encoder");

encoder->setComputePipelineState(c->pipelineState);
encoder->setBuffer(c->inputBuffer, c->inputOffset, 0);
encoder->setBuffer(c->outputBuffer, c->outputOffset, 1);

NS::UInteger threadGroupSize = c->pipelineState->maxTotalThreadsPerThreadgroup();
if (threadGroupSize > c->totalThreads) {
threadGroupSize = c->totalThreads;
}

encoder->dispatchThreads(
MTL::Size::Make(c->totalThreads, 1, 1),
MTL::Size::Make(threadGroupSize, 1, 1));
}

void dispatchReluKernel(const torch::Tensor& input, torch::Tensor& output) {
// Use PyTorch's MPS device and command queue (these are borrowed references, not owned)
MTL::Device* device = reinterpret_cast<MTL::Device*>(getMPSDevice());
TORCH_CHECK(device != nullptr, "Failed to get MPS device");

MTL::CommandQueue* commandQueue = reinterpret_cast<MTL::CommandQueue*>(getMPSCommandQueue());
TORCH_CHECK(commandQueue != nullptr, "Failed to get MPS command queue");

MTL::Library* libraryPtr = reinterpret_cast<MTL::Library*>(loadLibrary(device));
NS::SharedPtr<MTL::Library> library = NS::TransferPtr(libraryPtr);

Expand All @@ -64,36 +95,21 @@ void dispatchReluKernel(const torch::Tensor& input, torch::Tensor& output) {
"Failed to create compute pipeline state: ",
pipelineError ? pipelineError->localizedDescription()->utf8String() : "Unknown error");

// Don't use SharedPtr for command buffer/encoder - they're managed by PyTorch's command queue
MTL::CommandBuffer* commandBuffer = commandQueue->commandBuffer();
TORCH_CHECK(commandBuffer != nullptr, "Failed to create Metal command buffer");

MTL::ComputeCommandEncoder* encoder = commandBuffer->computeCommandEncoder();
TORCH_CHECK(encoder != nullptr, "Failed to create compute command encoder");

encoder->setComputePipelineState(pipelineState.get());

auto* inputBuffer = getMTLBuffer(input);
auto* outputBuffer = getMTLBuffer(output);
TORCH_CHECK(inputBuffer != nullptr, "Input buffer is null");
TORCH_CHECK(outputBuffer != nullptr, "Output buffer is null");

encoder->setBuffer(inputBuffer, input.storage_offset() * input.element_size(), 0);
encoder->setBuffer(outputBuffer, output.storage_offset() * output.element_size(), 1);

const NS::UInteger totalThreads = input.numel();
NS::UInteger threadGroupSize = pipelineState->maxTotalThreadsPerThreadgroup();
if (threadGroupSize > totalThreads) {
threadGroupSize = totalThreads;
}

const MTL::Size gridSize = MTL::Size::Make(totalThreads, 1, 1);
const MTL::Size threadsPerThreadgroup = MTL::Size::Make(threadGroupSize, 1, 1);

encoder->dispatchThreads(gridSize, threadsPerThreadgroup);
encoder->endEncoding();
ReluDispatchContext ctx{
pipelineState.get(),
inputBuffer,
outputBuffer,
static_cast<NS::UInteger>(input.storage_offset() * input.element_size()),
static_cast<NS::UInteger>(output.storage_offset() * output.element_size()),
static_cast<NS::UInteger>(input.numel())};

commandBuffer->commit();
mpsDispatchSync(reluDispatchBlock, &ctx);
mpsSynchronize();
}

void relu(torch::Tensor& out, const torch::Tensor& input) {
Expand Down