Skip to content

Commit 7e11615

Browse files
Update
[ghstack-poisoned]
1 parent d6f0bc9 commit 7e11615

File tree

2 files changed

+1143
-0
lines changed

2 files changed

+1143
-0
lines changed
Lines changed: 271 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,271 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#pragma once
10+
11+
#ifdef __OBJC__
12+
#import <Foundation/Foundation.h>
13+
#import <Metal/Metal.h>
14+
#include <dispatch/dispatch.h>
15+
// Forward declarations for MetalPerformanceShadersGraph types
16+
@class MPSGraph;
17+
@class MPSCommandBuffer;
18+
// Metal type definitions for Objective-C compilation
19+
typedef id<MTLDevice> MTLDevice_t;
20+
typedef id<MTLCommandQueue> MTLCommandQueue_t;
21+
typedef id<MTLCommandBuffer> MTLCommandBuffer_t;
22+
typedef id<MTLComputeCommandEncoder> MTLComputeCommandEncoder_t;
23+
typedef id<MTLComputePipelineState> MTLComputePipelineState_t;
24+
typedef id<MTLFunction> MTLFunction_t;
25+
typedef id<MTLLibrary> MTLLibrary_t;
26+
typedef id<MTLBuffer> MTLBuffer_t;
27+
typedef dispatch_queue_t dispatch_queue_t;
28+
typedef MPSGraph* MPSGraph_t;
29+
typedef MPSCommandBuffer* MPSCommandBuffer_t;
30+
typedef NSDictionary* NSDictionary_t;
31+
#else
32+
// Forward declarations for C++ compilation
33+
typedef void* MTLDevice_t;
34+
typedef void* MTLCommandQueue_t;
35+
typedef void* MTLCommandBuffer_t;
36+
typedef void* MTLComputeCommandEncoder_t;
37+
typedef void* MTLComputePipelineState_t;
38+
typedef void* MTLFunction_t;
39+
typedef void* MTLLibrary_t;
40+
typedef void* MTLBuffer_t;
41+
typedef void* dispatch_queue_t;
42+
typedef void* MPSGraph_t;
43+
typedef void* MPSCommandBuffer_t;
44+
typedef void* NSDictionary_t;
45+
#endif
46+
47+
#include <functional>
48+
#include <memory>
49+
#include <string>
50+
#include <unordered_map>
51+
#include <vector>
52+
53+
namespace executorch::runtime::etensor {
54+
class Tensor;
55+
}
56+
57+
namespace executorch {
58+
namespace backends {
59+
namespace metal {
60+
61+
// Forward declarations
62+
class ETMetalKernelFunction;
63+
class ETMetalStream;
64+
65+
// =======================
66+
// SyncType - Metal synchronization options
67+
// =======================
68+
enum class SyncType {
69+
NONE, // no commit to command buffer
70+
COMMIT, // commit and flush the command buffer
71+
COMMIT_AND_WAIT, // flush and wait for command buffer execution to finish
72+
COMMIT_AND_CONTINUE, // commit and continue with a new underlying command
73+
// buffer
74+
COMMIT_ADAPTIVE, // commit adaptively based on available memory
75+
};
76+
77+
// =======================
78+
// ETMetalShaderLibrary - ExecuTorch Metal shader library management
79+
// =======================
80+
class ETMetalShaderLibrary {
81+
public:
82+
ETMetalShaderLibrary(const std::string& source);
83+
~ETMetalShaderLibrary();
84+
85+
std::shared_ptr<ETMetalKernelFunction> getKernelFunction(
86+
const std::string& name);
87+
88+
private:
89+
void compileLibrary();
90+
std::pair<MTLComputePipelineState_t, MTLFunction_t> getLibraryPipelineState(
91+
const std::string& functionName);
92+
93+
friend class ETMetalKernelFunction;
94+
95+
std::string shaderSource_;
96+
MTLLibrary_t library_;
97+
std::unordered_map<
98+
std::string,
99+
std::pair<MTLComputePipelineState_t, MTLFunction_t>>
100+
pipelineStates_;
101+
};
102+
103+
// =======================
104+
// ETMetalKernelFunction - ExecuTorch Metal kernel function execution
105+
// =======================
106+
class ETMetalKernelFunction {
107+
public:
108+
ETMetalKernelFunction(MTLComputePipelineState_t cps, MTLFunction_t func);
109+
~ETMetalKernelFunction();
110+
111+
void startEncoding();
112+
void setArg(unsigned idx, const executorch::runtime::etensor::Tensor& tensor);
113+
void setArg(unsigned idx, int64_t val);
114+
115+
void dispatchSingle(uint64_t length);
116+
void dispatchSingleWithGroupSize(uint64_t length, uint64_t group_size);
117+
void dispatchArray(const uint64_t* length, size_t length_size);
118+
void dispatchArrayWithGroupSize(
119+
const uint64_t* length,
120+
size_t length_size,
121+
const uint64_t* group_size,
122+
size_t group_size_size);
123+
124+
void runCommandBlock(std::function<void(void)> f);
125+
126+
private:
127+
MTLComputePipelineState_t cps_;
128+
MTLFunction_t func_;
129+
MTLComputeCommandEncoder_t encoder_;
130+
};
131+
132+
// =======================
133+
// ETMetalStream - Metal command buffer and synchronization management
134+
// =======================
135+
class ETMetalStream {
136+
public:
137+
ETMetalStream();
138+
~ETMetalStream();
139+
140+
// Get the default stream (singleton)
141+
static ETMetalStream* getDefaultStream();
142+
143+
// Device and queue access
144+
MTLDevice_t device() const {
145+
return device_;
146+
}
147+
MTLCommandQueue_t commandQueue() const {
148+
return commandQueue_;
149+
}
150+
dispatch_queue_t queue() const {
151+
return serialQueue_;
152+
}
153+
154+
// Synchronization methods
155+
void synchronize(SyncType syncType = SyncType::COMMIT_AND_WAIT);
156+
void synchronize(); // Overload for backward compatibility
157+
bool isEmpty() const;
158+
159+
// Command buffer management with lazy creation
160+
MPSCommandBuffer_t commandBuffer();
161+
MTLComputeCommandEncoder_t commandEncoder();
162+
163+
void endKernelCoalescing();
164+
165+
// MPSGraph execution
166+
void executeMPSGraph(
167+
MPSGraph_t mpsGraph,
168+
NSDictionary_t feeds,
169+
NSDictionary_t results,
170+
SyncType syncType = SyncType::COMMIT_ADAPTIVE);
171+
172+
// Command buffer lifecycle management
173+
void commitCommandBuffer(MTLCommandBuffer_t commandBuffer);
174+
void flush();
175+
176+
// Memory operations
177+
void fill(
178+
MTLBuffer_t buffer,
179+
uint8_t value,
180+
size_t length,
181+
size_t offset,
182+
SyncType syncType = SyncType::NONE);
183+
void copy(
184+
MTLBuffer_t srcBuffer,
185+
MTLBuffer_t dstBuffer,
186+
size_t length,
187+
size_t srcOffset,
188+
size_t dstOffset,
189+
SyncType syncType = SyncType::NONE);
190+
191+
private:
192+
// Private synchronization methods
193+
void commit();
194+
void commitAndWait();
195+
void commitAndContinue();
196+
197+
private:
198+
// Private members
199+
MTLDevice_t device_;
200+
MTLCommandQueue_t commandQueue_;
201+
MPSCommandBuffer_t commandBuffer_;
202+
MPSCommandBuffer_t prevCommandBuffer_; // For commit-and-continue pattern
203+
MTLComputeCommandEncoder_t commandEncoder_;
204+
dispatch_queue_t serialQueue_; // For thread safety
205+
206+
// Configuration
207+
bool enableCommitAndContinue_;
208+
209+
// Singleton instance
210+
static ETMetalStream* defaultStream_;
211+
};
212+
213+
// =======================
214+
// Global storage management functions
215+
// =======================
216+
void storeFunctionHandle(
217+
ETMetalKernelFunction* raw_function,
218+
std::shared_ptr<ETMetalKernelFunction> function_shared_ptr);
219+
void storeLibraryHandle(
220+
ETMetalShaderLibrary* raw_library,
221+
std::unique_ptr<ETMetalShaderLibrary> library);
222+
bool removeFunctionHandle(ETMetalKernelFunction* raw_function);
223+
bool removeLibraryHandle(ETMetalShaderLibrary* raw_library);
224+
225+
// =======================
226+
// Global stream access functions
227+
// =======================
228+
ETMetalStream* getCurrentMetalStream();
229+
void setCurrentMetalStream(ETMetalStream* stream);
230+
231+
// =======================
232+
// Metal stream synchronization functions (C++ interface with exceptions)
233+
// =======================
234+
void synchronize_metal_stream();
235+
void synchronize_metal_stream_with_type(int sync_type);
236+
237+
// =======================
238+
// Metal helper functions (C interface)
239+
// =======================
240+
#ifdef __cplusplus
241+
extern "C" {
242+
#endif
243+
244+
// Memory management functions for Metal
245+
void* metal_allocate_buffer(long bytes);
246+
bool metal_is_device_pointer(void* ptr);
247+
int metal_copy_memory(
248+
void* dst,
249+
const void* src,
250+
size_t nbytes,
251+
bool src_is_device,
252+
bool dst_is_device);
253+
void metal_cleanup_resources();
254+
255+
// Helper functions to access Metal objects
256+
MTLDevice_t get_metal_device();
257+
MTLCommandQueue_t get_metal_command_queue();
258+
259+
#ifdef __cplusplus
260+
}
261+
262+
// C++ only - expose the Metal buffer mapping
263+
#ifdef __OBJC__
264+
extern std::unordered_map<void*, MTLBuffer_t> ptr_to_mtl_buffer;
265+
#endif
266+
267+
#endif
268+
269+
} // namespace metal
270+
} // namespace backends
271+
} // namespace executorch

0 commit comments

Comments
 (0)