|
| 1 | +/* |
| 2 | + * Copyright (c) Meta Platforms, Inc. and affiliates. |
| 3 | + * All rights reserved. |
| 4 | + * |
| 5 | + * This source code is licensed under the BSD-style license found in the |
| 6 | + * LICENSE file in the root directory of this source tree. |
| 7 | + */ |
| 8 | + |
| 9 | +#pragma once |
| 10 | + |
| 11 | +#ifdef __OBJC__ |
| 12 | +#import <Foundation/Foundation.h> |
| 13 | +#import <Metal/Metal.h> |
| 14 | +#include <dispatch/dispatch.h> |
| 15 | +// Forward declarations for MetalPerformanceShadersGraph types |
| 16 | +@class MPSGraph; |
| 17 | +@class MPSCommandBuffer; |
| 18 | +// Metal type definitions for Objective-C compilation |
| 19 | +typedef id<MTLDevice> MTLDevice_t; |
| 20 | +typedef id<MTLCommandQueue> MTLCommandQueue_t; |
| 21 | +typedef id<MTLCommandBuffer> MTLCommandBuffer_t; |
| 22 | +typedef id<MTLComputeCommandEncoder> MTLComputeCommandEncoder_t; |
| 23 | +typedef id<MTLComputePipelineState> MTLComputePipelineState_t; |
| 24 | +typedef id<MTLFunction> MTLFunction_t; |
| 25 | +typedef id<MTLLibrary> MTLLibrary_t; |
| 26 | +typedef id<MTLBuffer> MTLBuffer_t; |
| 27 | +typedef dispatch_queue_t dispatch_queue_t; |
| 28 | +typedef MPSGraph* MPSGraph_t; |
| 29 | +typedef MPSCommandBuffer* MPSCommandBuffer_t; |
| 30 | +typedef NSDictionary* NSDictionary_t; |
| 31 | +#else |
| 32 | +// Forward declarations for C++ compilation |
| 33 | +typedef void* MTLDevice_t; |
| 34 | +typedef void* MTLCommandQueue_t; |
| 35 | +typedef void* MTLCommandBuffer_t; |
| 36 | +typedef void* MTLComputeCommandEncoder_t; |
| 37 | +typedef void* MTLComputePipelineState_t; |
| 38 | +typedef void* MTLFunction_t; |
| 39 | +typedef void* MTLLibrary_t; |
| 40 | +typedef void* MTLBuffer_t; |
| 41 | +typedef void* dispatch_queue_t; |
| 42 | +typedef void* MPSGraph_t; |
| 43 | +typedef void* MPSCommandBuffer_t; |
| 44 | +typedef void* NSDictionary_t; |
| 45 | +#endif |
| 46 | + |
| 47 | +#include <functional> |
| 48 | +#include <memory> |
| 49 | +#include <string> |
| 50 | +#include <unordered_map> |
| 51 | +#include <vector> |
| 52 | + |
| 53 | +namespace executorch::runtime::etensor { |
| 54 | +class Tensor; |
| 55 | +} |
| 56 | + |
| 57 | +namespace executorch { |
| 58 | +namespace backends { |
| 59 | +namespace metal { |
| 60 | + |
| 61 | +// Forward declarations |
| 62 | +class ETMetalKernelFunction; |
| 63 | +class ETMetalStream; |
| 64 | + |
| 65 | +// ======================= |
| 66 | +// SyncType - Metal synchronization options |
| 67 | +// ======================= |
| 68 | +enum class SyncType { |
| 69 | + NONE, // no commit to command buffer |
| 70 | + COMMIT, // commit and flush the command buffer |
| 71 | + COMMIT_AND_WAIT, // flush and wait for command buffer execution to finish |
| 72 | + COMMIT_AND_CONTINUE, // commit and continue with a new underlying command |
| 73 | + // buffer |
| 74 | + COMMIT_ADAPTIVE, // commit adaptively based on available memory |
| 75 | +}; |
| 76 | + |
| 77 | +// ======================= |
| 78 | +// ETMetalShaderLibrary - ExecuTorch Metal shader library management |
| 79 | +// ======================= |
| 80 | +class ETMetalShaderLibrary { |
| 81 | + public: |
| 82 | + ETMetalShaderLibrary(const std::string& source); |
| 83 | + ~ETMetalShaderLibrary(); |
| 84 | + |
| 85 | + std::shared_ptr<ETMetalKernelFunction> getKernelFunction( |
| 86 | + const std::string& name); |
| 87 | + |
| 88 | + private: |
| 89 | + void compileLibrary(); |
| 90 | + std::pair<MTLComputePipelineState_t, MTLFunction_t> getLibraryPipelineState( |
| 91 | + const std::string& functionName); |
| 92 | + |
| 93 | + friend class ETMetalKernelFunction; |
| 94 | + |
| 95 | + std::string shaderSource_; |
| 96 | + MTLLibrary_t library_; |
| 97 | + std::unordered_map< |
| 98 | + std::string, |
| 99 | + std::pair<MTLComputePipelineState_t, MTLFunction_t>> |
| 100 | + pipelineStates_; |
| 101 | +}; |
| 102 | + |
| 103 | +// ======================= |
| 104 | +// ETMetalKernelFunction - ExecuTorch Metal kernel function execution |
| 105 | +// ======================= |
| 106 | +class ETMetalKernelFunction { |
| 107 | + public: |
| 108 | + ETMetalKernelFunction(MTLComputePipelineState_t cps, MTLFunction_t func); |
| 109 | + ~ETMetalKernelFunction(); |
| 110 | + |
| 111 | + void startEncoding(); |
| 112 | + void setArg(unsigned idx, const executorch::runtime::etensor::Tensor& tensor); |
| 113 | + void setArg(unsigned idx, int64_t val); |
| 114 | + |
| 115 | + void dispatchSingle(uint64_t length); |
| 116 | + void dispatchSingleWithGroupSize(uint64_t length, uint64_t group_size); |
| 117 | + void dispatchArray(const uint64_t* length, size_t length_size); |
| 118 | + void dispatchArrayWithGroupSize( |
| 119 | + const uint64_t* length, |
| 120 | + size_t length_size, |
| 121 | + const uint64_t* group_size, |
| 122 | + size_t group_size_size); |
| 123 | + |
| 124 | + void runCommandBlock(std::function<void(void)> f); |
| 125 | + |
| 126 | + private: |
| 127 | + MTLComputePipelineState_t cps_; |
| 128 | + MTLFunction_t func_; |
| 129 | + MTLComputeCommandEncoder_t encoder_; |
| 130 | +}; |
| 131 | + |
| 132 | +// ======================= |
| 133 | +// ETMetalStream - Metal command buffer and synchronization management |
| 134 | +// ======================= |
| 135 | +class ETMetalStream { |
| 136 | + public: |
| 137 | + ETMetalStream(); |
| 138 | + ~ETMetalStream(); |
| 139 | + |
| 140 | + // Get the default stream (singleton) |
| 141 | + static ETMetalStream* getDefaultStream(); |
| 142 | + |
| 143 | + // Device and queue access |
| 144 | + MTLDevice_t device() const { |
| 145 | + return device_; |
| 146 | + } |
| 147 | + MTLCommandQueue_t commandQueue() const { |
| 148 | + return commandQueue_; |
| 149 | + } |
| 150 | + dispatch_queue_t queue() const { |
| 151 | + return serialQueue_; |
| 152 | + } |
| 153 | + |
| 154 | + // Synchronization methods |
| 155 | + void synchronize(SyncType syncType = SyncType::COMMIT_AND_WAIT); |
| 156 | + void synchronize(); // Overload for backward compatibility |
| 157 | + bool isEmpty() const; |
| 158 | + |
| 159 | + // Command buffer management with lazy creation |
| 160 | + MPSCommandBuffer_t commandBuffer(); |
| 161 | + MTLComputeCommandEncoder_t commandEncoder(); |
| 162 | + |
| 163 | + void endKernelCoalescing(); |
| 164 | + |
| 165 | + // MPSGraph execution |
| 166 | + void executeMPSGraph( |
| 167 | + MPSGraph_t mpsGraph, |
| 168 | + NSDictionary_t feeds, |
| 169 | + NSDictionary_t results, |
| 170 | + SyncType syncType = SyncType::COMMIT_ADAPTIVE); |
| 171 | + |
| 172 | + // Command buffer lifecycle management |
| 173 | + void commitCommandBuffer(MTLCommandBuffer_t commandBuffer); |
| 174 | + void flush(); |
| 175 | + |
| 176 | + // Memory operations |
| 177 | + void fill( |
| 178 | + MTLBuffer_t buffer, |
| 179 | + uint8_t value, |
| 180 | + size_t length, |
| 181 | + size_t offset, |
| 182 | + SyncType syncType = SyncType::NONE); |
| 183 | + void copy( |
| 184 | + MTLBuffer_t srcBuffer, |
| 185 | + MTLBuffer_t dstBuffer, |
| 186 | + size_t length, |
| 187 | + size_t srcOffset, |
| 188 | + size_t dstOffset, |
| 189 | + SyncType syncType = SyncType::NONE); |
| 190 | + |
| 191 | + private: |
| 192 | + // Private synchronization methods |
| 193 | + void commit(); |
| 194 | + void commitAndWait(); |
| 195 | + void commitAndContinue(); |
| 196 | + |
| 197 | + private: |
| 198 | + // Private members |
| 199 | + MTLDevice_t device_; |
| 200 | + MTLCommandQueue_t commandQueue_; |
| 201 | + MPSCommandBuffer_t commandBuffer_; |
| 202 | + MPSCommandBuffer_t prevCommandBuffer_; // For commit-and-continue pattern |
| 203 | + MTLComputeCommandEncoder_t commandEncoder_; |
| 204 | + dispatch_queue_t serialQueue_; // For thread safety |
| 205 | + |
| 206 | + // Configuration |
| 207 | + bool enableCommitAndContinue_; |
| 208 | + |
| 209 | + // Singleton instance |
| 210 | + static ETMetalStream* defaultStream_; |
| 211 | +}; |
| 212 | + |
| 213 | +// ======================= |
| 214 | +// Global storage management functions |
| 215 | +// ======================= |
| 216 | +void storeFunctionHandle( |
| 217 | + ETMetalKernelFunction* raw_function, |
| 218 | + std::shared_ptr<ETMetalKernelFunction> function_shared_ptr); |
| 219 | +void storeLibraryHandle( |
| 220 | + ETMetalShaderLibrary* raw_library, |
| 221 | + std::unique_ptr<ETMetalShaderLibrary> library); |
| 222 | +bool removeFunctionHandle(ETMetalKernelFunction* raw_function); |
| 223 | +bool removeLibraryHandle(ETMetalShaderLibrary* raw_library); |
| 224 | + |
| 225 | +// ======================= |
| 226 | +// Global stream access functions |
| 227 | +// ======================= |
| 228 | +ETMetalStream* getCurrentMetalStream(); |
| 229 | +void setCurrentMetalStream(ETMetalStream* stream); |
| 230 | + |
| 231 | +// ======================= |
| 232 | +// Metal stream synchronization functions (C++ interface with exceptions) |
| 233 | +// ======================= |
| 234 | +void synchronize_metal_stream(); |
| 235 | +void synchronize_metal_stream_with_type(int sync_type); |
| 236 | + |
| 237 | +// ======================= |
| 238 | +// Metal helper functions (C interface) |
| 239 | +// ======================= |
| 240 | +#ifdef __cplusplus |
| 241 | +extern "C" { |
| 242 | +#endif |
| 243 | + |
| 244 | +// Memory management functions for Metal |
| 245 | +void* metal_allocate_buffer(long bytes); |
| 246 | +bool metal_is_device_pointer(void* ptr); |
| 247 | +int metal_copy_memory( |
| 248 | + void* dst, |
| 249 | + const void* src, |
| 250 | + size_t nbytes, |
| 251 | + bool src_is_device, |
| 252 | + bool dst_is_device); |
| 253 | +void metal_cleanup_resources(); |
| 254 | + |
| 255 | +// Helper functions to access Metal objects |
| 256 | +MTLDevice_t get_metal_device(); |
| 257 | +MTLCommandQueue_t get_metal_command_queue(); |
| 258 | + |
| 259 | +#ifdef __cplusplus |
| 260 | +} |
| 261 | + |
| 262 | +// C++ only - expose the Metal buffer mapping |
| 263 | +#ifdef __OBJC__ |
| 264 | +extern std::unordered_map<void*, MTLBuffer_t> ptr_to_mtl_buffer; |
| 265 | +#endif |
| 266 | + |
| 267 | +#endif |
| 268 | + |
| 269 | +} // namespace metal |
| 270 | +} // namespace backends |
| 271 | +} // namespace executorch |
0 commit comments