Skip to content
This repository was archived by the owner on Jan 27, 2026. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .github/workflows/build_kernel_macos.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -26,3 +26,6 @@ jobs:
# kernels. Also run tests once we have a macOS runner.
- name: Build relu kernel
run: ( cd examples/relu && nix build .\#redistributable.torch29-metal-aarch64-darwin -L )

- name: Build relu metal cpp kernel
run: ( cd examples/relu-metal-cpp && nix build .\#redistributable.torch29-metal-aarch64-darwin -L )
2 changes: 2 additions & 0 deletions build2cmake/src/config/v2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,8 @@ pub enum Dependencies {
Cutlass4_0,
#[serde(rename = "cutlass_sycl")]
CutlassSycl,
#[serde(rename = "metal-cpp")]
MetalCpp,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need a version like cutlass? My first guess is not, since on Mac we always have everything the latest, but I thought I'd check.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

good question! I agree we probably do not need a version here and we'll always prefer latest

Torch,
}

Expand Down
20 changes: 20 additions & 0 deletions examples/relu-metal-cpp/build.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
[general]
name = "relu"
universal = false

[torch]
src = [
"torch-ext/torch_binding.cpp",
"torch-ext/torch_binding.h",
]


[kernel.relu_metal]
backend = "metal"
src = [
"relu.cpp",
"metallib_loader.mm",
"relu_cpp.metal",
"common.h",
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would be nicer to have these in a directory, since it's an example, we want best practices :).

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

totally agree, i've updated the folder struct and build.toml in latest

]
depends = [ "torch", "metal-cpp" ]
10 changes: 10 additions & 0 deletions examples/relu-metal-cpp/common.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
#ifndef COMMON_H
#define COMMON_H

#include <metal_stdlib>
using namespace metal;

// Common constants and utilities for Metal kernels
constant float RELU_THRESHOLD = 0.0f;

#endif // COMMON_H
164 changes: 164 additions & 0 deletions examples/relu-metal-cpp/flake.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

17 changes: 17 additions & 0 deletions examples/relu-metal-cpp/flake.nix
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
{
description = "Flake for ReLU metal cpp kernel";

inputs = {
kernel-builder.url = "path:../..";
};

outputs =
{
self,
kernel-builder,
}:
kernel-builder.lib.genFlakeOutputs {
inherit self;
path = ./.;
};
}
41 changes: 41 additions & 0 deletions examples/relu-metal-cpp/metallib_loader.mm
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
#import <Metal/Metal.h>
#import <Foundation/Foundation.h>
#include <ATen/mps/MPSDevice.h>
#include <ATen/mps/MPSStream.h>

#ifdef EMBEDDED_METALLIB_HEADER
#include EMBEDDED_METALLIB_HEADER
#else
#error "EMBEDDED_METALLIB_HEADER not defined"
#endif

// C++ interface to load the embedded metallib without exposing ObjC types
extern "C" {
void* loadEmbeddedMetalLibrary(void* device, const char** errorMsg) {
id<MTLDevice> mtlDevice = (__bridge id<MTLDevice>)device;
NSError* error = nil;

id<MTLLibrary> library = EMBEDDED_METALLIB_NAMESPACE::createLibrary(mtlDevice, &error);

if (!library && errorMsg && error) {
*errorMsg = strdup([error.localizedDescription UTF8String]);
}

// Manually retain since we're not using ARC
// The caller will wrap in NS::TransferPtr which assumes ownership
if (library) {
[library retain];
}
return (__bridge void*)library;
}

// Get PyTorch's MPS device (returns id<MTLDevice> as void*)
void* getMPSDevice() {
return (__bridge void*)at::mps::MPSDevice::getInstance()->device();
}

// Get PyTorch's current MPS command queue (returns id<MTLCommandQueue> as void*)
void* getMPSCommandQueue() {
return (__bridge void*)at::mps::getCurrentMPSStream()->commandQueue();
}
}
119 changes: 119 additions & 0 deletions examples/relu-metal-cpp/relu.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
#define NS_PRIVATE_IMPLEMENTATION
#define MTL_PRIVATE_IMPLEMENTATION

// Include metal-cpp headers from system
#include <Metal/Metal.hpp>
#include <Foundation/Foundation.hpp>
#include <Foundation/NSSharedPtr.hpp>

#include <torch/torch.h>

// C interface from metallib_loader.mm
extern "C" void* loadEmbeddedMetalLibrary(void* device, const char** errorMsg);
extern "C" void* getMPSDevice();
extern "C" void* getMPSCommandQueue();

namespace {

MTL::Buffer* getMTLBuffer(const torch::Tensor& tensor) {
return reinterpret_cast<MTL::Buffer*>(const_cast<void*>(tensor.storage().data()));
}

NS::String* makeNSString(const std::string& value) {
return NS::String::string(value.c_str(), NS::StringEncoding::UTF8StringEncoding);
}

MTL::Library* loadLibrary(MTL::Device* device) {
const char* errorMsg = nullptr;
void* library = loadEmbeddedMetalLibrary(reinterpret_cast<void*>(device), &errorMsg);

TORCH_CHECK(library != nullptr, "Failed to create Metal library from embedded data: ",
errorMsg ? errorMsg : "Unknown error");

if (errorMsg) {
free(const_cast<char*>(errorMsg));
}

return reinterpret_cast<MTL::Library*>(library);
}

} // namespace

void dispatchReluKernel(const torch::Tensor& input, torch::Tensor& output) {
// Use PyTorch's MPS device and command queue (these are borrowed references, not owned)
MTL::Device* device = reinterpret_cast<MTL::Device*>(getMPSDevice());
TORCH_CHECK(device != nullptr, "Failed to get MPS device");

MTL::CommandQueue* commandQueue = reinterpret_cast<MTL::CommandQueue*>(getMPSCommandQueue());
TORCH_CHECK(commandQueue != nullptr, "Failed to get MPS command queue");

MTL::Library* libraryPtr = reinterpret_cast<MTL::Library*>(loadLibrary(device));
NS::SharedPtr<MTL::Library> library = NS::TransferPtr(libraryPtr);

const std::string kernelName =
std::string("relu_forward_kernel_") + (input.scalar_type() == torch::kFloat ? "float" : "half");
NS::SharedPtr<NS::String> kernelNameString = NS::TransferPtr(makeNSString(kernelName));

NS::SharedPtr<MTL::Function> computeFunction =
NS::TransferPtr(library->newFunction(kernelNameString.get()));
TORCH_CHECK(computeFunction.get() != nullptr, "Failed to create Metal function for ", kernelName);

NS::Error* pipelineError = nullptr;
NS::SharedPtr<MTL::ComputePipelineState> pipelineState =
NS::TransferPtr(device->newComputePipelineState(computeFunction.get(), &pipelineError));
TORCH_CHECK(pipelineState.get() != nullptr,
"Failed to create compute pipeline state: ",
pipelineError ? pipelineError->localizedDescription()->utf8String() : "Unknown error");

// Don't use SharedPtr for command buffer/encoder - they're managed by PyTorch's command queue
MTL::CommandBuffer* commandBuffer = commandQueue->commandBuffer();
TORCH_CHECK(commandBuffer != nullptr, "Failed to create Metal command buffer");

MTL::ComputeCommandEncoder* encoder = commandBuffer->computeCommandEncoder();
TORCH_CHECK(encoder != nullptr, "Failed to create compute command encoder");

encoder->setComputePipelineState(pipelineState.get());

auto* inputBuffer = getMTLBuffer(input);
auto* outputBuffer = getMTLBuffer(output);
TORCH_CHECK(inputBuffer != nullptr, "Input buffer is null");
TORCH_CHECK(outputBuffer != nullptr, "Output buffer is null");

encoder->setBuffer(inputBuffer, input.storage_offset() * input.element_size(), 0);
encoder->setBuffer(outputBuffer, output.storage_offset() * output.element_size(), 1);

const NS::UInteger totalThreads = input.numel();
NS::UInteger threadGroupSize = pipelineState->maxTotalThreadsPerThreadgroup();
if (threadGroupSize > totalThreads) {
threadGroupSize = totalThreads;
}

const MTL::Size gridSize = MTL::Size::Make(totalThreads, 1, 1);
const MTL::Size threadsPerThreadgroup = MTL::Size::Make(threadGroupSize, 1, 1);

encoder->dispatchThreads(gridSize, threadsPerThreadgroup);
encoder->endEncoding();

commandBuffer->commit();
}

void relu(torch::Tensor& out, const torch::Tensor& input) {
TORCH_CHECK(input.device().is_mps(), "input must be a MPS tensor");
TORCH_CHECK(input.is_contiguous(), "input must be contiguous");
TORCH_CHECK(input.scalar_type() == torch::kFloat || input.scalar_type() == torch::kHalf,
"Unsupported data type: ", input.scalar_type());

TORCH_CHECK(input.sizes() == out.sizes(),
"Tensors must have the same shape. Got input shape: ",
input.sizes(), " and output shape: ", out.sizes());

TORCH_CHECK(input.scalar_type() == out.scalar_type(),
"Tensors must have the same data type. Got input dtype: ",
input.scalar_type(), " and output dtype: ", out.scalar_type());

TORCH_CHECK(input.device() == out.device(),
"Tensors must be on the same device. Got input device: ",
input.device(), " and output device: ", out.device());

dispatchReluKernel(input, out);
}
Loading
Loading