|
| 1 | +#import <Metal/Metal.h> |
| 2 | +#import <Foundation/Foundation.h> |
| 3 | +#include <torch/extension.h> |
| 4 | + |
| 5 | +namespace pufferlib { |
| 6 | + |
// Reinterpret the raw MPS storage pointer of `tensor` as the id<MTLBuffer>
// that backs it. __builtin_bit_cast performs a plain bit reinterpretation,
// so no ARC retain/release is emitted for the returned reference — the
// buffer remains owned by the tensor's storage, and the caller must keep
// the tensor alive while using the buffer.
static inline id<MTLBuffer> getMTLBufferStorage(const torch::Tensor& tensor) {
    auto storagePtr = tensor.storage().data();
    return __builtin_bit_cast(id<MTLBuffer>, storagePtr);
}
| 10 | + |
// MPS backend for compute_puff_advantage: dispatches the
// `puff_advantage_kernel` Metal shader to fill `advantages` in place.
//
// Args:
//   values, rewards, dones, importance: contiguous float32
//       [num_steps, horizon] tensors on the MPS device (inputs).
//   advantages: contiguous float32 [num_steps, horizon] output tensor.
//   gamma / lambda: discount and GAE-lambda factors.
//   rho_clip / c_clip: importance-weight clips (VTrace-style).
//
// One GPU thread is launched per row (num_steps threads total); the kernel
// presumably walks each horizon serially — confirm against pufferlib.metal.
void compute_puff_advantage_mps(torch::Tensor values, torch::Tensor rewards,
        torch::Tensor dones, torch::Tensor importance, torch::Tensor advantages,
        double gamma, double lambda, double rho_clip, double c_clip) {

    @autoreleasepool {
        // Validate every tensor, not just `values`: the original checked only
        // values' device and dtype while its error text claimed "All tensors".
        auto checkTensor = [](const torch::Tensor& t, const char* name) {
            TORCH_CHECK(t.device().is_mps(), name, " must be on MPS device");
            TORCH_CHECK(t.is_contiguous(), name, " must be contiguous");
            TORCH_CHECK(t.scalar_type() == torch::kFloat32, name, " must be float32");
        };
        checkTensor(values, "values");
        checkTensor(rewards, "rewards");
        checkTensor(dones, "dones");
        checkTensor(importance, "importance");
        checkTensor(advantages, "advantages");
        TORCH_CHECK(values.dim() == 2, "values must be 2D [num_steps, horizon]");

        int num_steps = values.size(0);
        int horizon = values.size(1);

        NSError* error = nil;

        // Compile the shader once per process and cache the pipeline.
        // The device is cached alongside it: MTLCreateSystemDefaultDevice()
        // returns a +1 reference and is only needed for compilation, so
        // creating a fresh one on every call was wasted work (and a per-call
        // leak if this file is built without ARC — verify build flags).
        static id<MTLDevice> device = nil;
        static id<MTLFunction> function = nil;
        static id<MTLComputePipelineState> pipelineState = nil;

        if (pipelineState == nil) {
            device = MTLCreateSystemDefaultDevice();
            TORCH_CHECK(device, "Failed to create default Metal device");

            // The .metal source lives next to this file; read and compile it.
            NSString* sourcePath = [[@(__FILE__) stringByDeletingLastPathComponent]
                stringByAppendingPathComponent:@"pufferlib.metal"];
            NSString* source = [NSString stringWithContentsOfFile:sourcePath
                encoding:NSUTF8StringEncoding error:&error];
            TORCH_CHECK(source, "Failed to read Metal source file: ",
                error ? [[error localizedDescription] UTF8String] : "unknown error");

            id<MTLLibrary> library = [device newLibraryWithSource:source options:nil error:&error];
            TORCH_CHECK(library, "Failed to compile Metal library: ",
                [[error localizedDescription] UTF8String]);

            function = [library newFunctionWithName:@"puff_advantage_kernel"];
            TORCH_CHECK(function, "Failed to find puff_advantage_kernel function");

            pipelineState = [device newComputePipelineStateWithFunction:function error:&error];
            TORCH_CHECK(pipelineState, "Failed to create compute pipeline: ",
                [[error localizedDescription] UTF8String]);
        }

        // Encode onto PyTorch's shared MPS command buffer, serialized on the
        // MPS dispatch queue so we don't race other kernels PyTorch encodes.
        id<MTLCommandBuffer> commandBuffer = torch::mps::get_command_buffer();
        TORCH_CHECK(commandBuffer, "Failed to retrieve command buffer reference");

        dispatch_queue_t serialQueue = torch::mps::get_dispatch_queue();

        dispatch_sync(serialQueue, ^{
            id<MTLComputeCommandEncoder> encoder = [commandBuffer computeCommandEncoder];
            TORCH_CHECK(encoder, "Failed to create compute command encoder");

            [encoder setComputePipelineState:pipelineState];
            // Offsets are in bytes: tensors may be views into a larger storage.
            [encoder setBuffer:getMTLBufferStorage(values)
                offset:values.storage_offset() * values.element_size() atIndex:0];
            [encoder setBuffer:getMTLBufferStorage(rewards)
                offset:rewards.storage_offset() * rewards.element_size() atIndex:1];
            [encoder setBuffer:getMTLBufferStorage(dones)
                offset:dones.storage_offset() * dones.element_size() atIndex:2];
            [encoder setBuffer:getMTLBufferStorage(importance)
                offset:importance.storage_offset() * importance.element_size() atIndex:3];
            [encoder setBuffer:getMTLBufferStorage(advantages)
                offset:advantages.storage_offset() * advantages.element_size() atIndex:4];

            // Scalars are narrowed to the types the shader declares.
            float gamma_f = gamma, lambda_f = lambda, rho_clip_f = rho_clip, c_clip_f = c_clip;
            int horizon_i = horizon;

            [encoder setBytes:&gamma_f length:sizeof(float) atIndex:5];
            [encoder setBytes:&lambda_f length:sizeof(float) atIndex:6];
            [encoder setBytes:&rho_clip_f length:sizeof(float) atIndex:7];
            [encoder setBytes:&c_clip_f length:sizeof(float) atIndex:8];
            [encoder setBytes:&horizon_i length:sizeof(int) atIndex:9];

            // One thread per time-step row; dispatchThreads handles non-uniform
            // threadgroups, so the grid need not be a multiple of the group size.
            MTLSize gridSize = MTLSizeMake(num_steps, 1, 1);

            NSUInteger threadGroupSize = pipelineState.maxTotalThreadsPerThreadgroup;
            if (threadGroupSize > (NSUInteger)num_steps) {  // avoid signed/unsigned mismatch
                threadGroupSize = num_steps;
            }
            MTLSize threadgroupSize = MTLSizeMake(threadGroupSize, 1, 1);

            [encoder dispatchThreads:gridSize threadsPerThreadgroup:threadgroupSize];
            [encoder endEncoding];

            torch::mps::commit();
        });
    }
}
| 101 | + |
// Register this function as the MPS-device implementation of the
// pufferlib::compute_puff_advantage operator. The operator schema itself
// (and any CPU/CUDA impls) must be declared in another translation unit.
TORCH_LIBRARY_IMPL(pufferlib, MPS, m) {
    m.impl("compute_puff_advantage", &compute_puff_advantage_mps);
}
| 105 | + |
| 106 | +} |
0 commit comments