Commit b8df14b

Update tokenizer path in README for parakeet model (#16495)
1 parent 8e8d97e commit b8df14b

6 files changed, +420 −5 lines changed

6 files changed

+420
-5
lines changed

backends/apple/metal/metal_backend.py

Lines changed: 1 addition & 0 deletions
@@ -32,6 +32,7 @@ def get_device_name(cls) -> str:
     def get_supported_fallback_kernels(cls) -> Dict[str, Any]:
         return {
             "aoti_torch_mps_addmm_out": None,
+            "aoti_torch_mps_bmm_out": None,
             "aoti_torch_mps_convolution": None,
             "aoti_torch_mps_mm_out": None,
             "at::_ops::_scaled_dot_product_attention_math_for_mps::call": None,

backends/apple/metal/runtime/shims/et_metal_ops.h

Lines changed: 10 additions & 0 deletions
@@ -27,6 +27,16 @@ AOTITorchError aoti_torch_mps_mm_out(
     AOTITensorHandle self,
     AOTITensorHandle mat2);
 
+/**
+ * ExecutorTorch implementation of aoti_torch_mps_bmm_out.
+ * Performs batched matrix multiplication: out = self @ mat2
+ * All tensors must be 3-D with matching batch dimensions.
+ */
+AOTITorchError aoti_torch_mps_bmm_out(
+    AOTITensorHandle out,
+    AOTITensorHandle self,
+    AOTITensorHandle mat2);
+
 /**
  * ExecutorTorch implementation of aoti_torch_mps_convolution.
  * Performs 2D convolution operation - matches PyTorch AOTI signature

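Note: the declaration above fixes the calling contract — self is [B, M, K], mat2 is [B, K, N], out is [B, M, N], and (as the .mm implementation below enforces) every operand must be dense row-major. A minimal standalone sketch of that stride rule, not part of the commit, with plain arrays standing in for the real tensor types:

// Sketch only: the dense row-major stride rule the bmm implementation checks
// for each 3-D operand. For a shape [d0, d1, d2], contiguous strides must be
// [d1*d2, d2, 1].
#include <array>
#include <cstdint>

static bool is_dense_row_major_3d(
    const std::array<int64_t, 3>& sizes,
    const std::array<int64_t, 3>& strides) {
  return strides[2] == 1 &&                  // innermost dim is packed
         strides[1] == sizes[2] &&           // row stride equals row length
         strides[0] == sizes[1] * sizes[2];  // batch stride equals matrix size
}

// e.g. self [B,M,K] = [2,3,4] requires strides [12,4,1]; a [B,N,K] tensor
// transposed into shape [B,K,N] carries strides [N*K,1,K] and fails the check.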
backends/apple/metal/runtime/shims/et_metal_ops.mm

Lines changed: 339 additions & 0 deletions
@@ -626,6 +626,345 @@ AOTITorchError aoti_torch_mps_mm_out(
   }
 }
 
+AOTITorchError aoti_torch_mps_bmm_out(
+    AOTITensorHandle out,
+    AOTITensorHandle self,
+    AOTITensorHandle mat2) {
+  ET_LOG(Debug, "aoti_torch_mps_bmm_out: Starting with out=%p, self=%p, mat2=%p",
+         out, self, mat2);
+
+  // Validate non-null handles
+  if (!out || !self || !mat2) {
+    ET_LOG(Error, "aoti_torch_mps_bmm_out: null tensor handles");
+    return Error::InvalidArgument;
+  }
+
+  @autoreleasepool {
+    try {
+      // Convert AOTITensorHandle to ExecutorTorch tensors
+      auto out_tensor = reinterpret_cast<Tensor*>(out);
+      auto self_tensor = reinterpret_cast<Tensor*>(self);
+      auto mat2_tensor = reinterpret_cast<Tensor*>(mat2);
+
+      ET_LOG(Debug, "aoti_torch_mps_bmm_out: Converted tensor handles to ET tensors");
+
+      // Validate tensor dimensions - bmm requires 3-D tensors
+      if (self_tensor->dim() != 3 || mat2_tensor->dim() != 3 || out_tensor->dim() != 3) {
+        ET_LOG(Error, "aoti_torch_mps_bmm_out: tensors must be 3-D. "
+               "Got self.dim=%zd (shape=[%d,%d,%d]), "
+               "mat2.dim=%zd (shape=[%d,%d,%d]), "
+               "out.dim=%zd (shape=[%d,%d,%d])",
+               self_tensor->dim(),
+               self_tensor->dim() > 0 ? (int)self_tensor->sizes()[0] : 0,
+               self_tensor->dim() > 1 ? (int)self_tensor->sizes()[1] : 0,
+               self_tensor->dim() > 2 ? (int)self_tensor->sizes()[2] : 0,
+               mat2_tensor->dim(),
+               mat2_tensor->dim() > 0 ? (int)mat2_tensor->sizes()[0] : 0,
+               mat2_tensor->dim() > 1 ? (int)mat2_tensor->sizes()[1] : 0,
+               mat2_tensor->dim() > 2 ? (int)mat2_tensor->sizes()[2] : 0,
+               out_tensor->dim(),
+               out_tensor->dim() > 0 ? (int)out_tensor->sizes()[0] : 0,
+               out_tensor->dim() > 1 ? (int)out_tensor->sizes()[1] : 0,
+               out_tensor->dim() > 2 ? (int)out_tensor->sizes()[2] : 0);
+        return Error::InvalidArgument;
+      }
+
+      int64_t B = self_tensor->sizes()[0]; // batch size
+      int64_t M = self_tensor->sizes()[1]; // rows of self
+      int64_t K = self_tensor->sizes()[2]; // cols of self / rows of mat2
+      int64_t N = mat2_tensor->sizes()[2]; // cols of mat2
+
+      // Validate shape constraints
+      // self: [B, M, K], mat2: [B, K, N], out: [B, M, N]
+      if (mat2_tensor->sizes()[0] != B) {
+        ET_LOG(Error, "aoti_torch_mps_bmm_out: batch size mismatch. "
+               "Expected mat2[0]=%d to match self[0]=%lld. "
+               "self.shape=[%lld,%lld,%lld], mat2.shape=[%d,%d,%d]",
+               (int)mat2_tensor->sizes()[0], (long long)B,
+               (long long)B, (long long)M, (long long)K,
+               (int)mat2_tensor->sizes()[0], (int)mat2_tensor->sizes()[1], (int)mat2_tensor->sizes()[2]);
+        return Error::InvalidArgument;
+      }
+
+      if (mat2_tensor->sizes()[1] != K) {
+        ET_LOG(Error, "aoti_torch_mps_bmm_out: incompatible matrix dimensions for bmm. "
+               "Expected mat2[1]=%d to match self[2]=%lld. "
+               "Cannot multiply [%lld,%lld,%lld] @ [%d,%d,%d]",
+               (int)mat2_tensor->sizes()[1], (long long)K,
+               (long long)B, (long long)M, (long long)K,
+               (int)mat2_tensor->sizes()[0], (int)mat2_tensor->sizes()[1], (int)mat2_tensor->sizes()[2]);
+        return Error::InvalidArgument;
+      }
+
+      if (out_tensor->sizes()[0] != B || out_tensor->sizes()[1] != M || out_tensor->sizes()[2] != N) {
+        ET_LOG(Error, "aoti_torch_mps_bmm_out: output shape mismatch. "
+               "Expected out.shape=[%lld,%lld,%lld], got [%d,%d,%d]",
+               (long long)B, (long long)M, (long long)N,
+               (int)out_tensor->sizes()[0], (int)out_tensor->sizes()[1], (int)out_tensor->sizes()[2]);
+        return Error::InvalidArgument;
+      }
+
+      // Validate dtype consistency
+      int32_t self_dtype = static_cast<int32_t>(self_tensor->scalar_type());
+      int32_t mat2_dtype = static_cast<int32_t>(mat2_tensor->scalar_type());
+      int32_t out_dtype = static_cast<int32_t>(out_tensor->scalar_type());
+
+      if (self_dtype != mat2_dtype || self_dtype != out_dtype) {
+        ET_LOG(Error, "aoti_torch_mps_bmm_out: dtype mismatch. "
+               "All tensors must have same dtype. Got self.dtype=%d, mat2.dtype=%d, out.dtype=%d",
+               self_dtype, mat2_dtype, out_dtype);
+        return Error::InvalidArgument;
+      }
+
+      int32_t dtype = self_dtype;
+
+      // Validate layout: BMM requires strictly contiguous 3D tensors
+      // For shape [B, M, K], contiguous strides MUST be [M*K, K, 1]
+      //
+      // Why strict contiguity is required:
+      // - MPSGraphTensorData initWithMTLBuffer:shape:dataType: interprets the MTLBuffer
+      //   as containing dense row-major data for the given shape
+      // - Non-contiguous layouts (transposed, views with strides, etc.) have different
+      //   memory layouts that don't match what MPS expects
+      // - This would result in SILENT WRONG RESULTS
+      // - This is an _out op: we must NOT create implicit copies
+      // - Policy: Reject non-contiguous inputs explicitly (transposed/view tensors unsupported)
+      //
+      // Limitation: This implementation does not explicitly check storage offset (no API available).
+      // Tensors with non-zero storage offsets are not explicitly rejected but may work if they
+      // happen to have contiguous strides. Users should ensure tensors are base tensors without offsets.
+      auto self_strides = self_tensor->strides();
+      auto mat2_strides = mat2_tensor->strides();
+      auto out_strides = out_tensor->strides();
+
+      // Check self tensor is contiguous [B, M, K] with strides [M*K, K, 1]
+      if (self_strides[2] != 1 || self_strides[1] != K || self_strides[0] != M * K) {
+        ET_LOG(Error, "aoti_torch_mps_bmm_out: self tensor must be contiguous. "
+               "Only dense row-major layout supported; transposed/view tensors are unsupported. "
+               "Expected strides=[%lld,%lld,1] for shape=[%lld,%lld,%lld], got strides=[%d,%d,%d].",
+               (long long)(M * K), (long long)K, (long long)B, (long long)M, (long long)K,
+               self_strides[0], self_strides[1], self_strides[2]);
+        return Error::InvalidArgument;
+      }
+
+      // Check mat2 tensor is contiguous [B, K, N] with strides [K*N, N, 1]
+      if (mat2_strides[2] != 1 || mat2_strides[1] != N || mat2_strides[0] != K * N) {
+        ET_LOG(Error, "aoti_torch_mps_bmm_out: mat2 tensor must be contiguous. "
+               "Only dense row-major layout supported; transposed/view tensors are unsupported. "
+               "Expected strides=[%lld,%lld,1] for shape=[%lld,%lld,%lld], got strides=[%d,%d,%d].",
+               (long long)(K * N), (long long)N, (long long)B, (long long)K, (long long)N,
+               mat2_strides[0], mat2_strides[1], mat2_strides[2]);
+        return Error::InvalidArgument;
+      }
+
+      // Check out tensor is contiguous [B, M, N] with strides [M*N, N, 1]
+      if (out_strides[2] != 1 || out_strides[1] != N || out_strides[0] != M * N) {
+        ET_LOG(Error, "aoti_torch_mps_bmm_out: out tensor must be contiguous. "
+               "Only dense row-major layout supported; transposed/view tensors are unsupported. "
+               "Expected strides=[%lld,%lld,1] for shape=[%lld,%lld,%lld], got strides=[%d,%d,%d].",
+               (long long)(M * N), (long long)N, (long long)B, (long long)M, (long long)N,
+               out_strides[0], out_strides[1], out_strides[2]);
+        return Error::InvalidArgument;
+      }
+
+      ET_LOG(Debug, "aoti_torch_mps_bmm_out: Validated shapes and strides. "
+             "batch=%lld, self=[%lld,%lld], mat2=[%lld,%lld], out=[%lld,%lld]",
+             B, M, K, K, N, M, N);
+
+      // Get Metal stream and device
+      ETMetalStream* stream = getCurrentMetalStream();
+      if (!stream) {
+        ET_LOG(Error, "aoti_torch_mps_bmm_out: Failed to get current Metal stream");
+        return Error::Internal;
+      }
+
+      id<MTLDevice> device = get_metal_device();
+      if (!device) {
+        ET_LOG(Error, "aoti_torch_mps_bmm_out: Failed to get Metal device");
+        return Error::Internal;
+      }
+      (void)device; // Used for validation, consistent with other ops
+
+      // Get Metal buffers for input and output tensors
+      id<MTLBuffer> self_buffer = get_mtl_buffer(self_tensor, "aoti_torch_mps_bmm_out", "self");
+      id<MTLBuffer> mat2_buffer = get_mtl_buffer(mat2_tensor, "aoti_torch_mps_bmm_out", "mat2");
+      id<MTLBuffer> out_buffer = get_mtl_buffer(out_tensor, "aoti_torch_mps_bmm_out", "out");
+
+      // Validate buffers are non-null
+      if (!self_buffer || !mat2_buffer || !out_buffer) {
+        ET_LOG(Error, "aoti_torch_mps_bmm_out: Failed to get Metal buffers. "
+               "self_buffer=%p, mat2_buffer=%p, out_buffer=%p",
+               self_buffer, mat2_buffer, out_buffer);
+        return Error::Internal;
+      }
+
+      ET_LOG(Debug, "aoti_torch_mps_bmm_out: Using Metal buffers - self=%p, mat2=%p, out=%p",
+             self_buffer, mat2_buffer, out_buffer);
+
+      // End any existing kernel coalescing to ensure clean state
+      // (consistent with mm_out and conv pattern)
+      stream->endKernelCoalescing();
+
+      // Map dtype to MPS type and validate support
+      // Note: Only FLOAT32 and BFLOAT16 are supported in Metal backend (see utils.h)
+      // FLOAT16 is not in SupportedDTypes enum and is not supported
+      MPSDataType mps_dtype;
+
+      if (dtype == static_cast<int32_t>(SupportedDTypes::FLOAT32)) {
+        mps_dtype = MPSDataTypeFloat32;
+      } else if (dtype == static_cast<int32_t>(SupportedDTypes::BFLOAT16)) {
+        mps_dtype = MPSDataTypeBFloat16;
+      } else {
+        ET_LOG(Error, "aoti_torch_mps_bmm_out: Unsupported data type: %d. "
+               "Supported types: FLOAT32 (%d), BFLOAT16 (%d)",
+               dtype,
+               static_cast<int32_t>(SupportedDTypes::FLOAT32),
+               static_cast<int32_t>(SupportedDTypes::BFLOAT16));
+        return Error::InvalidArgument;
+      }
+
+      ET_LOG(Debug, "aoti_torch_mps_bmm_out: dtype=%d, mps_dtype=%d",
+             dtype, (int)mps_dtype);
+
+      // Define shapes for graph placeholders and tensor data
+      NSArray<NSNumber*>* selfShape = @[@(B), @(M), @(K)];
+      NSArray<NSNumber*>* mat2Shape = @[@(B), @(K), @(N)];
+      NSArray<NSNumber*>* outShape = @[@(B), @(M), @(N)];
+
+      // Create cache key for this batched matrix multiplication
+      // Cache key includes: op_name, shape params {B, M, K, N}, dtype, transpose_flag
+      // This allows reuse when same BMM shape/dtype is called repeatedly
+      GraphCacheKey cache_key;
+      cache_key.op_name = "bmm";
+      cache_key.shape_params = {B, M, K, N};
+      cache_key.dtype = dtype;
+      cache_key.transpose_flag = false; // BMM has no transpose handling
+
+      // Check if we have a cached graph
+      MPSGraph* mpsGraph = nullptr;
+      MPSGraphTensor* outputTensor = nil;
+      MPSGraphTensor* selfPlaceholder = nil;
+      MPSGraphTensor* mat2Placeholder = nil;
+
+      auto cache_it = graph_cache.find(cache_key);
+      if (cache_it != graph_cache.end()) {
+        // Cache hit - reuse compiled graph and tensor references
+        CachedGraph& cached = cache_it->second;
+        mpsGraph = cached.graph;
+        selfPlaceholder = cached.input1;
+        mat2Placeholder = cached.input2;
+        outputTensor = cached.output;
+
+        cache_stats.hits++;
+        cache_stats.logStats();
+        ET_LOG(Debug, "aoti_torch_mps_bmm_out: Using cached MPSGraph (cache hit, %zu total hits)", cache_stats.hits);
+
+      } else {
+        // Cache miss - create and compile new graph
+        mpsGraph = [MPSGraph new];
+        cache_stats.misses++;
+        cache_stats.logStats();
+        ET_LOG(Debug, "aoti_torch_mps_bmm_out: Created new MPSGraph instance (cache miss, %zu total misses)", cache_stats.misses);
+
+        ET_LOG(Debug, "aoti_torch_mps_bmm_out: Creating placeholders with shapes self:[%lld,%lld,%lld] mat2:[%lld,%lld,%lld]",
+               B, M, K, B, K, N);
+
+        // Create 3D placeholders for batched matrices
+        // These represent the logical shapes for the batched matrix multiplication
+        selfPlaceholder = [mpsGraph placeholderWithShape:selfShape
+                                                dataType:mps_dtype
+                                                    name:@"self"];
+        mat2Placeholder = [mpsGraph placeholderWithShape:mat2Shape
+                                                dataType:mps_dtype
+                                                    name:@"mat2"];
+
+        ET_LOG(Debug, "aoti_torch_mps_bmm_out: Created input placeholders");
+
+        // MPSGraph matrixMultiplication handles batched case natively when given 3D tensors
+        // For 3D inputs [B,M,K] @ [B,K,N] -> [B,M,N]
+        outputTensor = [mpsGraph matrixMultiplicationWithPrimaryTensor:selfPlaceholder
+                                                       secondaryTensor:mat2Placeholder
+                                                                  name:@"bmm_result"];
+
+        ET_LOG(Debug, "aoti_torch_mps_bmm_out: Successfully created batched matrix multiplication tensor");
+
+        // Cache the compiled graph and tensor references for reuse
+        CachedGraph cached_graph;
+        cached_graph.graph = mpsGraph;
+        cached_graph.input1 = selfPlaceholder;
+        cached_graph.input2 = mat2Placeholder;
+        cached_graph.input3 = nil; // No third input for BMM
+        cached_graph.output = outputTensor;
+        graph_cache[cache_key] = cached_graph;
+
+        ET_LOG(Debug, "aoti_torch_mps_bmm_out: Cached compiled MPSGraph for future reuse");
+      } // End of cache miss/hit block
+
+      // Create feeds dictionary for graph execution
+      NSMutableDictionary* feeds = [NSMutableDictionary dictionary];
+
+      // Create MPSGraphTensorData objects for input tensors
+      // These wrap the MTLBuffers with the shape information
+      // Initialize to nil for safe cleanup in exception path
+      MPSGraphTensorData* selfData = nil;
+      MPSGraphTensorData* mat2Data = nil;
+      MPSGraphTensorData* outputData = nil;
+
+      selfData = [[MPSGraphTensorData alloc] initWithMTLBuffer:self_buffer
+                                                         shape:selfShape
+                                                      dataType:mps_dtype];
+      mat2Data = [[MPSGraphTensorData alloc] initWithMTLBuffer:mat2_buffer
+                                                         shape:mat2Shape
+                                                      dataType:mps_dtype];
+
+      feeds[selfPlaceholder] = selfData;
+      feeds[mat2Placeholder] = mat2Data;
+
+      // Create output tensor data
+      outputData = [[MPSGraphTensorData alloc] initWithMTLBuffer:out_buffer
+                                                           shape:outShape
+                                                        dataType:mps_dtype];
+
+      // Build results dictionary
+      NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = @{
+        outputTensor: outputData
+      };
+
+      ET_LOG(Debug, "aoti_torch_mps_bmm_out: Executing MPSGraph");
+
+      // Execute the batched matrix multiplication
+      @try {
+        stream->executeMPSGraph(mpsGraph, feeds, results, SyncType::COMMIT);
+      } @catch (NSException *exception) {
+        ET_LOG(Error, "aoti_torch_mps_bmm_out: NSException caught during executeMPSGraph: %s - %s",
+               [[exception name] UTF8String], [[exception reason] UTF8String]);
+        // Guard releases against nil
+        if (selfData) [selfData release];
+        if (mat2Data) [mat2Data release];
+        if (outputData) [outputData release];
+        return Error::Internal;
+      }
+
+      ET_LOG(Debug, "aoti_torch_mps_bmm_out: MPSGraph execution completed successfully");
+
+      // Release MPSGraphTensorData objects
+      [selfData release];
+      [mat2Data release];
+      [outputData release];
+
+      ET_LOG(Debug, "aoti_torch_mps_bmm_out: Executed successfully for batch size %lld", B);
+      return Error::Ok;
+
+    } catch (const std::exception& e) {
+      ET_LOG(Error, "aoti_torch_mps_bmm_out exception: %s", e.what());
+      return Error::Internal;
+    } catch (...) {
+      ET_LOG(Error, "aoti_torch_mps_bmm_out: unknown exception");
+      return Error::Internal;
+    }
+  }
+}
+
 AOTITorchError aoti_torch_mps_convolution(
     AOTITensorHandle input,
     AOTITensorHandle weight,

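Design note on the caching above: the compiled MPSGraph is memoized per {op_name, B, M, K, N, dtype}, so repeated bmm calls with identical shapes and dtype skip graph construction entirely and pay only for execution. A minimal sketch of such a shape-keyed cache, with hypothetical stand-in types (the commit's real GraphCacheKey, CachedGraph, and graph_cache are defined elsewhere in the backend):

// Sketch only, not part of the commit; SketchCacheKey/SketchCachedGraph are
// hypothetical stand-ins for the backend's GraphCacheKey / CachedGraph.
#include <cstdint>
#include <map>
#include <string>
#include <tuple>
#include <vector>

struct SketchCacheKey {
  std::string op_name;               // e.g. "bmm"
  std::vector<int64_t> shape_params; // {B, M, K, N}
  int32_t dtype;                     // FLOAT32 or BFLOAT16
  bool transpose_flag;               // always false for bmm

  bool operator<(const SketchCacheKey& o) const {
    return std::tie(op_name, shape_params, dtype, transpose_flag) <
           std::tie(o.op_name, o.shape_params, o.dtype, o.transpose_flag);
  }
};

struct SketchCachedGraph {
  // The real struct holds the MPSGraph* plus placeholder/output tensor refs.
};

std::map<SketchCacheKey, SketchCachedGraph> sketch_graph_cache;

// First call with a given key compiles the graph and inserts it; later calls
// with the same key reuse it, which is what the hit/miss counters above track.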
examples/models/parakeet/CMakePresets.json

Lines changed: 4 additions & 1 deletion
@@ -34,7 +34,10 @@
       "displayName": "Parakeet runner (Metal)",
       "inherits": ["parakeet-base"],
       "cacheVariables": {
-        "EXECUTORCH_BUILD_METAL": "ON"
+        "CMAKE_BUILD_TYPE": "Debug",
+        "EXECUTORCH_BUILD_METAL": "ON",
+        "EXECUTORCH_ENABLE_LOGGING": "ON",
+        "ET_MIN_LOG_LEVEL": "Info"
       },
       "condition": {
         "lhs": "${hostSystemName}",

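A note on the preset change: EXECUTORCH_ENABLE_LOGGING=ON compiles the ET_LOG machinery in, and ET_MIN_LOG_LEVEL=Info sets the minimum emitted severity. Assuming ExecuTorch's usual level ordering (Debug < Info < Error < Fatal), the Debug traces added in et_metal_ops.mm stay silent under this preset while Info and Error messages are reported. A tiny sketch of the effect:

// Sketch only (assumes executorch/runtime/platform/log.h and the level
// ordering Debug < Info < Error < Fatal).
#include <executorch/runtime/platform/log.h>

void log_level_sketch() {
  ET_LOG(Debug, "suppressed: below the Info minimum set by this preset");
  ET_LOG(Info, "emitted");
  ET_LOG(Error, "emitted");
}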