Skip to content

Commit 930f6b9

Browse files
Update
[ghstack-poisoned]
2 parents cf93ffd + 71f87b6 commit 930f6b9

File tree

2 files changed

+127
-2
lines changed

2 files changed

+127
-2
lines changed

backends/apple/metal/runtime/shims/et_metal.h

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,35 @@ enum class SyncType {
7777
// =======================
7878
// ETMetalShaderLibrary - ExecuTorch Metal shader library management
7979
// =======================
80+
81+
/**
82+
* @class ETMetalShaderLibrary
83+
* @brief Manages Metal shader library compilation and kernel function retrieval.
84+
*
85+
* This class provides a high-level interface for compiling Metal Shading Language (MSL)
86+
* source code into a Metal library and creating compute pipeline states for
87+
* kernel functions. It handles the creation and caching of Metal compute pipeline
88+
* states and functions, which should be reused across multiple kernel dispatches.
89+
*
90+
* The class automatically compiles the provided shader source code upon construction
91+
* and maintains an internal cache of compute pipeline states for different kernel
92+
* functions to avoid redundant compilation.
93+
*
94+
* Example usage:
95+
* @code
96+
* std::string shaderSource = R"(
97+
* #include <metal_stdlib>
98+
* using namespace metal;
99+
* kernel void my_kernel(device float* data [[buffer(0)]],
100+
* uint tid [[thread_position_in_grid]]) {
101+
* data[tid] = data[tid] * 2.0;
102+
* }
103+
* )";
104+
*
105+
* ETMetalShaderLibrary library(shaderSource);
106+
* auto kernelFunction = library.getKernelFunction("my_kernel");
107+
* @endcode
108+
*/
80109
class ETMetalShaderLibrary {
81110
public:
82111
ETMetalShaderLibrary(const std::string& source);
@@ -103,6 +132,45 @@ class ETMetalShaderLibrary {
103132
// =======================
104133
// ETMetalKernelFunction - ExecuTorch Metal kernel function execution
105134
// =======================
135+
136+
/**
137+
* @class ETMetalKernelFunction
138+
* @brief Represents a Metal compute kernel function ready for execution.
139+
*
140+
* This class encapsulates a Metal compute pipeline state and function, providing
141+
* a high-level interface for setting kernel arguments and dispatching compute
142+
* work to the GPU. It handles the encoding of compute commands and manages the
143+
* interaction with Metal's compute command encoder.
144+
*
145+
* The class supports different dispatch patterns:
146+
* - Single-dimension dispatch for linear workloads
147+
* - Multi-dimensional dispatch for grid-based workloads
148+
* - Custom thread group sizes for performance optimization
149+
*
150+
* Kernel arguments can be set using tensors (which will be mapped to Metal buffers)
151+
* or scalar values. The class handles the encoding of these arguments
152+
* into the compute command encoder.
153+
*
154+
* Example usage:
155+
* @code
156+
* // Get kernel function from library
157+
* auto kernelFunction = library.getKernelFunction("vector_add");
158+
*
159+
* // Start encoding commands
160+
* kernelFunction->startEncoding();
161+
*
162+
* // Set tensor arguments
163+
* kernelFunction->setArg(0, inputTensorA);
164+
* kernelFunction->setArg(1, inputTensorB);
165+
* kernelFunction->setArg(2, outputTensor);
166+
*
167+
* // Set scalar argument
168+
* kernelFunction->setArg(3, static_cast<int64_t>(numElements));
169+
*
170+
* // Dispatch for linear workload
171+
* kernelFunction->dispatchSingle(numElements);
172+
* @endcode
173+
*/
106174
class ETMetalKernelFunction {
107175
public:
108176
ETMetalKernelFunction(MTLComputePipelineState_t cps, MTLFunction_t func);
@@ -132,6 +200,45 @@ class ETMetalKernelFunction {
132200
// =======================
133201
// ETMetalStream - Metal command buffer and synchronization management
134202
// =======================
203+
204+
/**
205+
* @class ETMetalStream
206+
* @brief Manages Metal compute command streams and provides GPU synchronization.
207+
*
208+
* This class serves as the central management hub for Metal GPU operations, providing
209+
* a stream-based abstraction similar to CUDA streams. It handles command buffer lifecycle,
210+
* compute command encoder management, and various synchronization patterns required for
211+
* efficient GPU computation.
212+
*
213+
* Key features:
214+
* - Lazy command buffer and encoder creation for optimal resource usage
215+
* - Thread-safe operations using serial dispatch queues
216+
* - Multiple synchronization modes (COMMIT, COMMIT_AND_WAIT, COMMIT_AND_CONTINUE)
217+
* - Kernel coalescing to batch multiple operations efficiently
218+
* - MPSGraph integration for high-level neural network operations
219+
* - Memory operations (copy, fill) with GPU acceleration via blit encoders
220+
*
221+
* The stream follows PyTorch's MPS stream design patterns, providing similar semantics
222+
* for command buffer management and synchronization.
223+
*
224+
* Example usage:
225+
* @code
226+
* // Get current stream (typically the default stream)
227+
* ETMetalStream* stream = getCurrentMetalStream();
228+
*
229+
* // Execute kernel operations (handled automatically)
230+
* auto kernelFunction = library.getKernelFunction("my_kernel");
231+
* kernelFunction->startEncoding();
232+
* kernelFunction->setArg(0, inputTensor);
233+
* kernelFunction->dispatchSingle(numElements);
234+
*
235+
* // Synchronize to ensure completion
236+
* stream->synchronize(SyncType::COMMIT_AND_WAIT);
237+
*
238+
* // Copy between GPU buffers using blit encoder
239+
* stream->copy(srcBuffer, dstBuffer, numBytes, 0, 0, SyncType::COMMIT);
240+
* @endcode
241+
*/
135242
class ETMetalStream {
136243
public:
137244
ETMetalStream();

backends/apple/metal/runtime/shims/et_metal.mm

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -743,6 +743,26 @@ int metal_copy_memory(void* dst, const void* src, size_t nbytes, bool src_is_dev
743743

744744
void ETMetalStream::copy(id<MTLBuffer> srcBuffer, id<MTLBuffer> dstBuffer, size_t length,
745745
size_t srcOffset, size_t dstOffset, SyncType syncType) {
746+
747+
if (length == 0) {
748+
return;
749+
}
750+
// Validate that both buffers are non-nil and that each offset + length stays within its buffer before copying
751+
if (!srcBuffer || !dstBuffer) {
752+
ET_LOG(Error, "ETMetalStream::copy: Source or destination buffer is nil");
753+
return;
754+
}
755+
NSUInteger srcBufferLength = [srcBuffer length];
756+
NSUInteger dstBufferLength = [dstBuffer length];
757+
if (srcOffset + length > srcBufferLength) {
758+
ET_LOG(Error, "ETMetalStream::copy: Source offset (%zu) + length (%zu) exceeds source buffer size (%zu)", srcOffset, length, srcBufferLength);
759+
return;
760+
}
761+
if (dstOffset + length > dstBufferLength) {
762+
ET_LOG(Error, "ETMetalStream::copy: Destination offset (%zu) + length (%zu) exceeds destination buffer size (%zu)", dstOffset, length, dstBufferLength);
763+
return;
764+
}
765+
746766
dispatch_sync(serialQueue_, ^{
747767
@autoreleasepool {
748768
endKernelCoalescing();
@@ -792,8 +812,6 @@ int metal_copy_memory(void* dst, const void* src, size_t nbytes, bool src_is_dev
792812
targetOperations:nil
793813
resultsDictionary:results
794814
executionDescriptor:nil];
795-
796-
//synchronize(syncType);
797815
}
798816
});
799817
}

0 commit comments

Comments
 (0)