Skip to content

Commit 1f02571

Browse files
committed
feature: add mps kernel
1 parent 7a99b3b commit 1f02571

File tree

3 files changed

+157
-0
lines changed

3 files changed

+157
-0
lines changed
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
#include <metal_stdlib>
using namespace metal;

// Backward recursion computing VTrace-style clipped-importance advantages
// (the PuffeRL variant of GAE). Each thread owns one trajectory row of
// length `horizon`; rows are independent, so the grid is 1-D over rows.
kernel void puff_advantage_kernel(
    device const float* values [[buffer(0)]],
    device const float* rewards [[buffer(1)]],
    device const float* dones [[buffer(2)]],
    device const float* importance [[buffer(3)]],
    device float* advantages [[buffer(4)]],
    constant float& gamma [[buffer(5)]],
    constant float& lambda [[buffer(6)]],
    constant float& rho_clip [[buffer(7)]],
    constant float& c_clip [[buffer(8)]],
    constant int& horizon [[buffer(9)]],
    uint row [[thread_position_in_grid]])
{
    // Base offset of this thread's row inside the flat (rows, horizon) buffers.
    int base = row * horizon;
    device const float* v = values + base;
    device const float* r = rewards + base;
    device const float* d = dones + base;
    device const float* w = importance + base;
    device float* adv = advantages + base;

    float gamma_lambda = gamma * lambda;

    // Running accumulator of the lambda-discounted TD residuals.
    float carry = 0.0f;

    // Sweep backward from the second-to-last step.
    // NOTE(review): adv[horizon - 1] is never written here -- presumably the
    // caller pre-initializes the last column; confirm against the host code.
    for (int t = horizon - 2; t >= 0; t--) {
        int t_next = t + 1;

        // Clipped importance weights (VTrace rho / c truncation).
        float rho_t = fmin(w[t], rho_clip);
        float c_t = fmin(w[t], c_clip);

        // Zero out bootstrapping and the carried sum across episode ends.
        float live = 1.0f - d[t_next];

        float delta = rho_t * (r[t_next] + gamma * v[t_next] * live - v[t]);
        carry = delta + gamma_lambda * c_t * carry * live;
        adv[t] = carry;
    }
}
Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
#import <Metal/Metal.h>
2+
#import <Foundation/Foundation.h>
3+
#include <torch/extension.h>
4+
5+
namespace pufferlib {
6+
7+
// Reinterprets a PyTorch MPS tensor's raw storage pointer as the id<MTLBuffer>
// that backs it, so it can be bound to a compute encoder. bit_cast avoids an
// ARC ownership transfer that a plain (__bridge) C-style cast would require.
// NOTE(review): assumes the tensor lives on the MPS device (storage data is an
// MTLBuffer there) -- callers must TORCH_CHECK that before calling.
static inline id<MTLBuffer> getMTLBufferStorage(const torch::Tensor& tensor) {
    return __builtin_bit_cast(id<MTLBuffer>, tensor.storage().data());
}
10+
11+
// Launches the Metal `puff_advantage_kernel` over every trajectory row.
//
// All five tensors must be (num_steps, horizon) float32, contiguous, and
// resident on the MPS device; `advantages` is written in place.
// gamma/lambda are the discount and GAE smoothing factors; rho_clip/c_clip
// are the VTrace-style importance-weight clips.
void compute_puff_advantage_mps(torch::Tensor values, torch::Tensor rewards,
        torch::Tensor dones, torch::Tensor importance, torch::Tensor advantages,
        double gamma, double lambda, double rho_clip, double c_clip) {

    @autoreleasepool {
        TORCH_CHECK(values.device().is_mps(), "All tensors must be on MPS device");
        TORCH_CHECK(values.is_contiguous(), "values must be contiguous");
        TORCH_CHECK(rewards.is_contiguous(), "rewards must be contiguous");
        TORCH_CHECK(dones.is_contiguous(), "dones must be contiguous");
        TORCH_CHECK(importance.is_contiguous(), "importance must be contiguous");
        TORCH_CHECK(advantages.is_contiguous(), "advantages must be contiguous");
        TORCH_CHECK(values.scalar_type() == torch::kFloat32, "All tensors must be float32");
        // The kernel indexes every buffer with the same (row, t) layout, so a
        // shape mismatch would silently read or write out of bounds.
        TORCH_CHECK(rewards.sizes() == values.sizes()
            && dones.sizes() == values.sizes()
            && importance.sizes() == values.sizes()
            && advantages.sizes() == values.sizes(),
            "All tensors must have the same shape");

        int num_steps = values.size(0);
        int horizon = values.size(1);
        if (num_steps == 0) {
            return;  // nothing to do; zero-sized dispatches are invalid in Metal
        }

        // Compile the shader once and cache the pipeline; saves ~0.1ms per call.
        static id<MTLComputePipelineState> pipelineState = nil;

        if (pipelineState == nil) {
            // The device is only needed to compile the kernel, so create it
            // inside the one-time branch instead of on every invocation.
            id<MTLDevice> device = MTLCreateSystemDefaultDevice();
            TORCH_CHECK(device, "Failed to create Metal device");
            NSError* error = nil;

            // read the file & compile the shader
            NSString* sourcePath = [[@(__FILE__) stringByDeletingLastPathComponent]
                stringByAppendingPathComponent:@"pufferlib.metal"];
            NSString* source = [NSString stringWithContentsOfFile:sourcePath
                encoding:NSUTF8StringEncoding error:&error];
            TORCH_CHECK(source, "Failed to read Metal source file: ",
                error ? [[error localizedDescription] UTF8String] : "unknown error");

            id<MTLLibrary> library = [device newLibraryWithSource:source options:nil error:&error];
            TORCH_CHECK(library, "Failed to compile Metal library: ",
                [[error localizedDescription] UTF8String]);

            id<MTLFunction> function = [library newFunctionWithName:@"puff_advantage_kernel"];
            TORCH_CHECK(function, "Failed to find puff_advantage_kernel function");

            pipelineState = [device newComputePipelineStateWithFunction:function error:&error];
            TORCH_CHECK(pipelineState, "Failed to create compute pipeline: ",
                [[error localizedDescription] UTF8String]);
        }

        id<MTLCommandBuffer> commandBuffer = torch::mps::get_command_buffer();
        TORCH_CHECK(commandBuffer, "Failed to retrieve command buffer reference");

        // Encode on PyTorch's serial MPS queue so our work interleaves safely
        // with the ops PyTorch has already queued on this command buffer.
        dispatch_queue_t serialQueue = torch::mps::get_dispatch_queue();

        dispatch_sync(serialQueue, ^{
            id<MTLComputeCommandEncoder> encoder = [commandBuffer computeCommandEncoder];
            TORCH_CHECK(encoder, "Failed to create compute command encoder");

            [encoder setComputePipelineState:pipelineState];
            // storage_offset handles tensors viewing into a larger allocation.
            [encoder setBuffer:getMTLBufferStorage(values)
                offset:values.storage_offset() * values.element_size() atIndex:0];
            [encoder setBuffer:getMTLBufferStorage(rewards)
                offset:rewards.storage_offset() * rewards.element_size() atIndex:1];
            [encoder setBuffer:getMTLBufferStorage(dones)
                offset:dones.storage_offset() * dones.element_size() atIndex:2];
            [encoder setBuffer:getMTLBufferStorage(importance)
                offset:importance.storage_offset() * importance.element_size() atIndex:3];
            [encoder setBuffer:getMTLBufferStorage(advantages)
                offset:advantages.storage_offset() * advantages.element_size() atIndex:4];

            // The dispatcher hands us doubles; the shader takes float/int,
            // so narrow before binding as inline constants.
            float gamma_f = gamma, lambda_f = lambda, rho_clip_f = rho_clip, c_clip_f = c_clip;
            int horizon_i = horizon;

            [encoder setBytes:&gamma_f length:sizeof(float) atIndex:5];
            [encoder setBytes:&lambda_f length:sizeof(float) atIndex:6];
            [encoder setBytes:&rho_clip_f length:sizeof(float) atIndex:7];
            [encoder setBytes:&c_clip_f length:sizeof(float) atIndex:8];
            [encoder setBytes:&horizon_i length:sizeof(int) atIndex:9];

            // One thread per trajectory row.
            MTLSize gridSize = MTLSizeMake(num_steps, 1, 1);

            NSUInteger threadGroupSize = pipelineState.maxTotalThreadsPerThreadgroup;
            // Explicit cast avoids a signed/unsigned comparison between
            // NSUInteger and int.
            if (threadGroupSize > (NSUInteger)num_steps) {
                threadGroupSize = (NSUInteger)num_steps;
            }
            MTLSize threadgroupSize = MTLSizeMake(threadGroupSize, 1, 1);

            [encoder dispatchThreads:gridSize threadsPerThreadgroup:threadgroupSize];
            [encoder endEncoding];

            torch::mps::commit();
        });
    }
}
101+
102+
// Register the MPS backend implementation for the pufferlib namespace:
// calls to pufferlib::compute_puff_advantage on MPS tensors dispatch here.
// NOTE(review): the op schema itself is presumably declared (TORCH_LIBRARY)
// in another translation unit -- not visible in this file.
TORCH_LIBRARY_IMPL(pufferlib, MPS, m) {
    m.impl("compute_puff_advantage", &compute_puff_advantage_mps);
}
105+
106+
}

setup.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,11 +21,15 @@
2121
CUDA_HOME,
2222
ROCM_HOME
2323
)
24+
from torch.backends import mps
2425

2526
# build cuda extension if torch can find CUDA or HIP/ROCM in the system
2627
# may require `uv pip install --no-build-isolation` or `python setup.py build_ext --inplace`
2728
BUID_CUDA_EXT = bool(CUDA_HOME or ROCM_HOME)
2829

30+
# build mps extension if torch can find MPS in the system
31+
BUILD_MPS_EXT = bool(mps.is_available())
32+
2933
# Build with DEBUG=1 to enable debug symbols
3034
DEBUG = os.getenv("DEBUG", "0") == "1"
3135
NO_OCEAN = os.getenv("NO_OCEAN", "0") == "1"
@@ -243,6 +247,9 @@ def run(self):
243247
if BUID_CUDA_EXT:
244248
extension = CUDAExtension
245249
torch_sources.append("pufferlib/extensions/cuda/pufferlib.cu")
250+
elif BUILD_MPS_EXT:
251+
extension = CppExtension
252+
torch_sources.append("pufferlib/extensions/mps/pufferlib.mm")
246253
else:
247254
extension = CppExtension
248255

0 commit comments

Comments
 (0)