Skip to content

Commit e793135

Browse files
Metal backend: SDPA metal implementation (pytorch#16086)
Replaces SDPA MPSGraph's implementation with Metal implementation (adapted from MLX implementation, with several modifications, to support transposed middle dimensions, and floating point attention masks). Speeds up voxtral/whisper by 2-3x Fixes BFloat16 issue on macOS 26.1
1 parent 62dccf1 commit e793135

File tree

3 files changed

+669
-536
lines changed

3 files changed

+669
-536
lines changed

backends/apple/metal/runtime/shims/et_metal.h

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,13 @@ class ETMetalKernelFunction {
181181
void startEncoding();
182182
void setArg(unsigned idx, const executorch::runtime::etensor::Tensor& tensor);
183183
void setArg(unsigned idx, int64_t val);
184+
void setArg(unsigned idx, uint32_t val);
185+
void setArg(unsigned idx, float val);
186+
void setArg(unsigned idx, bool val);
187+
void setArg(unsigned idx, const void* data, size_t size);
188+
189+
// Helper for Metal uint3 struct
190+
void setArgUint3(unsigned idx, uint32_t x, uint32_t y, uint32_t z);
184191

185192
void dispatchSingle(uint64_t length);
186193
void dispatchSingleWithGroupSize(uint64_t length, uint64_t group_size);
@@ -191,6 +198,15 @@ class ETMetalKernelFunction {
191198
const uint64_t* group_size,
192199
size_t group_size_size);
193200

201+
// Dispatch with explicit threadgroup count (not thread count)
202+
void dispatchThreadgroups(
203+
uint64_t gridX,
204+
uint64_t gridY,
205+
uint64_t gridZ,
206+
uint64_t threadsX,
207+
uint64_t threadsY,
208+
uint64_t threadsZ);
209+
194210
void runCommandBlock(std::function<void(void)> f);
195211

196212
private:

backends/apple/metal/runtime/shims/et_metal.mm

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#import <MetalPerformanceShaders/MetalPerformanceShaders.h>
1111
#import <MetalPerformanceShadersGraph/MetalPerformanceShadersGraph.h>
1212
#import <Foundation/Foundation.h>
13+
#include <simd/simd.h>
1314
#include <executorch/runtime/platform/log.h>
1415
#include <executorch/runtime/core/exec_aten/exec_aten.h>
1516
#include <executorch/backends/apple/metal/runtime/shims/et_metal.h>
@@ -377,6 +378,58 @@ int metal_copy_memory(void* dst, const void* src, size_t nbytes, bool src_is_dev
377378
ET_LOG(Debug, "ETMetalKernelFunction::setArg: Set int64_t value %lld at index %u", val, idx);
378379
}
379380

381+
void ETMetalKernelFunction::setArg(unsigned idx, uint32_t val) {
382+
if (!encoder_) {
383+
ET_LOG(Error, "ETMetalKernelFunction::setArg: No active encoder");
384+
return;
385+
}
386+
387+
[encoder_ setBytes:&val length:sizeof(uint32_t) atIndex:idx];
388+
ET_LOG(Debug, "ETMetalKernelFunction::setArg: Set uint32_t value %u at index %u", val, idx);
389+
}
390+
391+
void ETMetalKernelFunction::setArg(unsigned idx, float val) {
392+
if (!encoder_) {
393+
ET_LOG(Error, "ETMetalKernelFunction::setArg: No active encoder");
394+
return;
395+
}
396+
397+
[encoder_ setBytes:&val length:sizeof(float) atIndex:idx];
398+
ET_LOG(Debug, "ETMetalKernelFunction::setArg: Set float value %f at index %u", val, idx);
399+
}
400+
401+
void ETMetalKernelFunction::setArg(unsigned idx, bool val) {
402+
if (!encoder_) {
403+
ET_LOG(Error, "ETMetalKernelFunction::setArg: No active encoder");
404+
return;
405+
}
406+
407+
[encoder_ setBytes:&val length:sizeof(bool) atIndex:idx];
408+
ET_LOG(Debug, "ETMetalKernelFunction::setArg: Set bool value %s at index %u", val ? "true" : "false", idx);
409+
}
410+
411+
void ETMetalKernelFunction::setArg(unsigned idx, const void* data, size_t size) {
412+
if (!encoder_) {
413+
ET_LOG(Error, "ETMetalKernelFunction::setArg: No active encoder");
414+
return;
415+
}
416+
417+
[encoder_ setBytes:data length:size atIndex:idx];
418+
ET_LOG(Debug, "ETMetalKernelFunction::setArg: Set bytes at index %u (size: %zu)", idx, size);
419+
}
420+
421+
void ETMetalKernelFunction::setArgUint3(unsigned idx, uint32_t x, uint32_t y, uint32_t z) {
422+
if (!encoder_) {
423+
ET_LOG(Error, "ETMetalKernelFunction::setArgUint3: No active encoder");
424+
return;
425+
}
426+
427+
// Use SIMD library's uint3 type which matches Metal shader's uint3 layout
428+
simd_uint3 val = {x, y, z};
429+
[encoder_ setBytes:&val length:sizeof(simd_uint3) atIndex:idx];
430+
ET_LOG(Debug, "ETMetalKernelFunction::setArgUint3: Set uint3{%u, %u, %u} at index %u", x, y, z, idx);
431+
}
432+
380433
void ETMetalKernelFunction::dispatchSingle(uint64_t length) {
381434
if (!encoder_) {
382435
ET_LOG(Error, "ETMetalKernelFunction::dispatchSingle: No active encoder");
@@ -502,6 +555,40 @@ int metal_copy_memory(void* dst, const void* src, size_t nbytes, bool src_is_dev
502555

503556
}
504557

558+
void ETMetalKernelFunction::dispatchThreadgroups(uint64_t gridX, uint64_t gridY, uint64_t gridZ,
559+
uint64_t threadsX, uint64_t threadsY, uint64_t threadsZ) {
560+
if (!encoder_) {
561+
ET_LOG(Error, "ETMetalKernelFunction::dispatchThreadgroups: No active encoder");
562+
return;
563+
}
564+
565+
if (!cps_) {
566+
ET_LOG(Error, "ETMetalKernelFunction::dispatchThreadgroups: No compute pipeline state");
567+
return;
568+
}
569+
570+
// Calculate total threads per threadgroup
571+
uint64_t totalThreads = threadsX * threadsY * threadsZ;
572+
573+
const auto maxThreadsPerGroup = static_cast<uint64_t>([cps_ maxTotalThreadsPerThreadgroup]);
574+
575+
// Validate total thread count
576+
if (totalThreads > maxThreadsPerGroup) {
577+
ET_LOG(Error, "ETMetalKernelFunction::dispatchThreadgroups: Requested %llu total threads per threadgroup exceeds device maximum of %llu",
578+
(unsigned long long)totalThreads, (unsigned long long)maxThreadsPerGroup);
579+
return;
580+
}
581+
582+
MTLSize threadgroupsPerGrid = MTLSizeMake(gridX, gridY, gridZ);
583+
MTLSize threadsPerThreadgroup = MTLSizeMake(threadsX, threadsY, threadsZ);
584+
585+
[encoder_ dispatchThreadgroups:threadgroupsPerGrid threadsPerThreadgroup:threadsPerThreadgroup];
586+
587+
ET_LOG(Debug, "ETMetalKernelFunction::dispatchThreadgroups: Dispatched grid [%llu, %llu, %llu] with threadgroup [%llu, %llu, %llu]",
588+
(unsigned long long)gridX, (unsigned long long)gridY, (unsigned long long)gridZ,
589+
(unsigned long long)threadsX, (unsigned long long)threadsY, (unsigned long long)threadsZ);
590+
}
591+
505592
void ETMetalKernelFunction::runCommandBlock(std::function<void(void)> f) {
506593
// Use dispatch_sync with the stream's serial queue for thread safety and synchronization
507594
// This matches PyTorch's approach: dispatch_sync_with_rethrow(getCurrentMPSStream()->queue(), ...)

0 commit comments

Comments
 (0)