Skip to content

Commit bd48479

Browse files
committed
Update tokenizer path in README for parakeet model (#16495)
1 parent 8e8d97e commit bd48479

File tree

6 files changed

+408
-5
lines changed

6 files changed

+408
-5
lines changed

backends/apple/metal/metal_backend.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ def get_device_name(cls) -> str:
3232
def get_supported_fallback_kernels(cls) -> Dict[str, Any]:
3333
return {
3434
"aoti_torch_mps_addmm_out": None,
35+
"aoti_torch_mps_bmm_out": None,
3536
"aoti_torch_mps_convolution": None,
3637
"aoti_torch_mps_mm_out": None,
3738
"at::_ops::_scaled_dot_product_attention_math_for_mps::call": None,

backends/apple/metal/runtime/shims/et_metal_ops.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,16 @@ AOTITorchError aoti_torch_mps_mm_out(
2727
AOTITensorHandle self,
2828
AOTITensorHandle mat2);
2929

30+
/**
31+
* ExecutorTorch implementation of aoti_torch_mps_bmm_out.
32+
* Performs batched matrix multiplication: out = self @ mat2
33+
* All tensors must be 3-D with matching batch dimensions.
34+
*/
35+
AOTITorchError aoti_torch_mps_bmm_out(
36+
AOTITensorHandle out,
37+
AOTITensorHandle self,
38+
AOTITensorHandle mat2);
39+
3040
/**
3141
* ExecutorTorch implementation of aoti_torch_mps_convolution.
3242
* Performs 2D convolution operation - matches PyTorch AOTI signature

backends/apple/metal/runtime/shims/et_metal_ops.mm

Lines changed: 336 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -626,6 +626,342 @@ AOTITorchError aoti_torch_mps_mm_out(
626626
}
627627
}
628628

629+
AOTITorchError aoti_torch_mps_bmm_out(
    AOTITensorHandle out,
    AOTITensorHandle self,
    AOTITensorHandle mat2) {
  ET_LOG(Debug, "aoti_torch_mps_bmm_out: Starting with out=%p, self=%p, mat2=%p",
      out, self, mat2);

  // Null handles are a caller error; fail fast before dereferencing anything.
  if (!out || !self || !mat2) {
    ET_LOG(Error, "aoti_torch_mps_bmm_out: null tensor handles");
    return Error::InvalidArgument;
  }

  @autoreleasepool {
    try {
      // The opaque AOTI handles are ExecutorTorch tensors underneath.
      auto* out_t = reinterpret_cast<Tensor*>(out);
      auto* self_t = reinterpret_cast<Tensor*>(self);
      auto* mat2_t = reinterpret_cast<Tensor*>(mat2);

      ET_LOG(Debug, "aoti_torch_mps_bmm_out: Converted tensor handles to ET tensors");

      // bmm is defined only for rank-3 operands: [B,M,K] @ [B,K,N] -> [B,M,N].
      const bool all_rank3 =
          self_t->dim() == 3 && mat2_t->dim() == 3 && out_t->dim() == 3;
      if (!all_rank3) {
        ET_LOG(Error, "aoti_torch_mps_bmm_out: tensors must be 3-D. "
            "Got self.dim=%d (shape=[%lld,%lld,%lld]), "
            "mat2.dim=%d (shape=[%lld,%lld,%lld]), "
            "out.dim=%d (shape=[%lld,%lld,%lld])",
            self_t->dim(),
            self_t->dim() > 0 ? self_t->sizes()[0] : 0,
            self_t->dim() > 1 ? self_t->sizes()[1] : 0,
            self_t->dim() > 2 ? self_t->sizes()[2] : 0,
            mat2_t->dim(),
            mat2_t->dim() > 0 ? mat2_t->sizes()[0] : 0,
            mat2_t->dim() > 1 ? mat2_t->sizes()[1] : 0,
            mat2_t->dim() > 2 ? mat2_t->sizes()[2] : 0,
            out_t->dim(),
            out_t->dim() > 0 ? out_t->sizes()[0] : 0,
            out_t->dim() > 1 ? out_t->sizes()[1] : 0,
            out_t->dim() > 2 ? out_t->sizes()[2] : 0);
        return Error::InvalidArgument;
      }

      const int64_t batch = self_t->sizes()[0]; // B
      const int64_t rows = self_t->sizes()[1];  // M: rows of each self matrix
      const int64_t inner = self_t->sizes()[2]; // K: shared contraction dim
      const int64_t cols = mat2_t->sizes()[2];  // N: cols of each mat2 matrix

      // Shape contract: self [B,M,K], mat2 [B,K,N], out [B,M,N].
      if (mat2_t->sizes()[0] != batch) {
        ET_LOG(Error, "aoti_torch_mps_bmm_out: batch size mismatch. "
            "Expected mat2[0]=%lld to match self[0]=%lld. "
            "self.shape=[%lld,%lld,%lld], mat2.shape=[%lld,%lld,%lld]",
            mat2_t->sizes()[0], batch,
            batch, rows, inner,
            mat2_t->sizes()[0], mat2_t->sizes()[1], mat2_t->sizes()[2]);
        return Error::InvalidArgument;
      }

      if (mat2_t->sizes()[1] != inner) {
        ET_LOG(Error, "aoti_torch_mps_bmm_out: incompatible matrix dimensions for bmm. "
            "Expected mat2[1]=%lld to match self[2]=%lld. "
            "Cannot multiply [%lld,%lld,%lld] @ [%lld,%lld,%lld]",
            mat2_t->sizes()[1], inner,
            batch, rows, inner,
            mat2_t->sizes()[0], mat2_t->sizes()[1], mat2_t->sizes()[2]);
        return Error::InvalidArgument;
      }

      if (out_t->sizes()[0] != batch || out_t->sizes()[1] != rows ||
          out_t->sizes()[2] != cols) {
        ET_LOG(Error, "aoti_torch_mps_bmm_out: output shape mismatch. "
            "Expected out.shape=[%lld,%lld,%lld], got [%lld,%lld,%lld]",
            batch, rows, cols,
            out_t->sizes()[0], out_t->sizes()[1], out_t->sizes()[2]);
        return Error::InvalidArgument;
      }

      // All three operands must share one dtype.
      const int32_t self_dtype = static_cast<int32_t>(self_t->scalar_type());
      const int32_t mat2_dtype = static_cast<int32_t>(mat2_t->scalar_type());
      const int32_t out_dtype = static_cast<int32_t>(out_t->scalar_type());

      if (self_dtype != mat2_dtype || self_dtype != out_dtype) {
        ET_LOG(Error, "aoti_torch_mps_bmm_out: dtype mismatch. "
            "All tensors must have same dtype. Got self.dtype=%d, mat2.dtype=%d, out.dtype=%d",
            self_dtype, mat2_dtype, out_dtype);
        return Error::InvalidArgument;
      }

      const int32_t dtype = self_dtype;

      // Strict contiguity is required. MPSGraphTensorData's
      // initWithMTLBuffer:shape:dataType: treats the MTLBuffer as dense
      // row-major data for the given shape, so a transposed or strided view
      // would be read with the wrong layout and produce SILENT WRONG RESULTS.
      // This is an _out op, so we must not make implicit copies either;
      // non-contiguous operands are therefore rejected outright.
      //
      // NOTE(review): storage offset is not checked (no API available here).
      // Tensors with a non-zero offset are not explicitly rejected and are
      // assumed to be base tensors; callers should ensure this.
      int64_t* self_strides = self_t->strides();
      int64_t* mat2_strides = mat2_t->strides();
      int64_t* out_strides = out_t->strides();

      // self must be [B,M,K] with strides [M*K, K, 1].
      if (self_strides[2] != 1 || self_strides[1] != inner ||
          self_strides[0] != rows * inner) {
        ET_LOG(Error, "aoti_torch_mps_bmm_out: self tensor must be contiguous. "
            "Only dense row-major layout supported; transposed/view tensors are unsupported. "
            "Expected strides=[%lld,%lld,1] for shape=[%lld,%lld,%lld], got strides=[%lld,%lld,%lld].",
            rows * inner, inner, batch, rows, inner,
            self_strides[0], self_strides[1], self_strides[2]);
        return Error::InvalidArgument;
      }

      // mat2 must be [B,K,N] with strides [K*N, N, 1].
      if (mat2_strides[2] != 1 || mat2_strides[1] != cols ||
          mat2_strides[0] != inner * cols) {
        ET_LOG(Error, "aoti_torch_mps_bmm_out: mat2 tensor must be contiguous. "
            "Only dense row-major layout supported; transposed/view tensors are unsupported. "
            "Expected strides=[%lld,%lld,1] for shape=[%lld,%lld,%lld], got strides=[%lld,%lld,%lld].",
            inner * cols, cols, batch, inner, cols,
            mat2_strides[0], mat2_strides[1], mat2_strides[2]);
        return Error::InvalidArgument;
      }

      // out must be [B,M,N] with strides [M*N, N, 1].
      if (out_strides[2] != 1 || out_strides[1] != cols ||
          out_strides[0] != rows * cols) {
        ET_LOG(Error, "aoti_torch_mps_bmm_out: out tensor must be contiguous. "
            "Only dense row-major layout supported; transposed/view tensors are unsupported. "
            "Expected strides=[%lld,%lld,1] for shape=[%lld,%lld,%lld], got strides=[%lld,%lld,%lld].",
            rows * cols, cols, batch, rows, cols,
            out_strides[0], out_strides[1], out_strides[2]);
        return Error::InvalidArgument;
      }

      ET_LOG(Debug, "aoti_torch_mps_bmm_out: Validated shapes and strides. "
          "batch=%lld, self=[%lld,%lld], mat2=[%lld,%lld], out=[%lld,%lld]",
          batch, rows, inner, inner, cols, rows, cols);

      // Acquire the backend's Metal stream and device.
      ETMetalStream* stream = getCurrentMetalStream();
      if (!stream) {
        ET_LOG(Error, "aoti_torch_mps_bmm_out: Failed to get current Metal stream");
        return Error::Internal;
      }

      id<MTLDevice> device = get_metal_device();
      if (!device) {
        ET_LOG(Error, "aoti_torch_mps_bmm_out: Failed to get Metal device");
        return Error::Internal;
      }
      (void)device; // Fetched only to validate availability, like the other ops.

      // Resolve the MTLBuffers backing each tensor.
      id<MTLBuffer> self_buffer = get_mtl_buffer(self_t, "aoti_torch_mps_bmm_out", "self");
      id<MTLBuffer> mat2_buffer = get_mtl_buffer(mat2_t, "aoti_torch_mps_bmm_out", "mat2");
      id<MTLBuffer> out_buffer = get_mtl_buffer(out_t, "aoti_torch_mps_bmm_out", "out");

      if (!self_buffer || !mat2_buffer || !out_buffer) {
        ET_LOG(Error, "aoti_torch_mps_bmm_out: Failed to get Metal buffers. "
            "self_buffer=%p, mat2_buffer=%p, out_buffer=%p",
            self_buffer, mat2_buffer, out_buffer);
        return Error::Internal;
      }

      ET_LOG(Debug, "aoti_torch_mps_bmm_out: Using Metal buffers - self=%p, mat2=%p, out=%p",
          self_buffer, mat2_buffer, out_buffer);

      // Flush any pending coalesced kernels so the graph runs from a clean
      // state (same pattern as mm_out and convolution).
      stream->endKernelCoalescing();

      // Map the ET dtype onto MPS. Only FLOAT32 and BFLOAT16 are supported by
      // this backend (see utils.h); FLOAT16 is intentionally absent.
      MPSDataType mps_dtype;
      if (dtype == static_cast<int32_t>(SupportedDTypes::FLOAT32)) {
        mps_dtype = MPSDataTypeFloat32;
      } else if (dtype == static_cast<int32_t>(SupportedDTypes::BFLOAT16)) {
        mps_dtype = MPSDataTypeBFloat16;
      } else {
        ET_LOG(Error, "aoti_torch_mps_bmm_out: Unsupported data type: %d. "
            "Supported types: FLOAT32 (%d), BFLOAT16 (%d)",
            dtype,
            static_cast<int32_t>(SupportedDTypes::FLOAT32),
            static_cast<int32_t>(SupportedDTypes::BFLOAT16));
        return Error::InvalidArgument;
      }

      ET_LOG(Debug, "aoti_torch_mps_bmm_out: dtype=%d, mps_dtype=%d",
          dtype, (int)mps_dtype);

      // Logical shapes shared by the graph placeholders and the
      // MPSGraphTensorData wrappers below.
      NSArray<NSNumber*>* selfDims = @[ @(batch), @(rows), @(inner) ];
      NSArray<NSNumber*>* mat2Dims = @[ @(batch), @(inner), @(cols) ];
      NSArray<NSNumber*>* outDims = @[ @(batch), @(rows), @(cols) ];

      // Graphs are cached per {op, B, M, K, N, dtype} so repeated BMMs of the
      // same shape/dtype reuse the compiled graph.
      GraphCacheKey key;
      key.op_name = "bmm";
      key.shape_params = {batch, rows, inner, cols};
      key.dtype = dtype;
      key.transpose_flag = false; // bmm never transposes its operands

      MPSGraph* graph = nullptr;
      MPSGraphTensor* resultTensor = nil;
      MPSGraphTensor* selfPH = nil;
      MPSGraphTensor* mat2PH = nil;

      auto found = graph_cache.find(key);
      if (found != graph_cache.end()) {
        // Cache hit: reuse the compiled graph and its tensor handles.
        CachedGraph& entry = found->second;
        graph = entry.graph;
        selfPH = entry.input1;
        mat2PH = entry.input2;
        resultTensor = entry.output;

        cache_stats.hits++;
        cache_stats.logStats();
        ET_LOG(Debug, "aoti_torch_mps_bmm_out: Using cached MPSGraph (cache hit, %zu total hits)", cache_stats.hits);
      } else {
        // Cache miss: build, record, and cache a new graph.
        graph = [MPSGraph new];
        cache_stats.misses++;
        cache_stats.logStats();
        ET_LOG(Debug, "aoti_torch_mps_bmm_out: Created new MPSGraph instance (cache miss, %zu total misses)", cache_stats.misses);

        ET_LOG(Debug, "aoti_torch_mps_bmm_out: Creating placeholders with shapes self:[%lld,%lld,%lld] mat2:[%lld,%lld,%lld]",
            batch, rows, inner, batch, inner, cols);

        // Rank-3 placeholders for the two batched operands.
        selfPH = [graph placeholderWithShape:selfDims
                                    dataType:mps_dtype
                                        name:@"self"];
        mat2PH = [graph placeholderWithShape:mat2Dims
                                    dataType:mps_dtype
                                        name:@"mat2"];

        ET_LOG(Debug, "aoti_torch_mps_bmm_out: Created input placeholders");

        // matrixMultiplication handles the batched case natively for rank-3
        // inputs: [B,M,K] @ [B,K,N] -> [B,M,N].
        resultTensor = [graph matrixMultiplicationWithPrimaryTensor:selfPH
                                                    secondaryTensor:mat2PH
                                                               name:@"bmm_result"];

        ET_LOG(Debug, "aoti_torch_mps_bmm_out: Successfully created batched matrix multiplication tensor");

        // Stash the compiled graph and tensor references for reuse.
        CachedGraph entry;
        entry.graph = graph;
        entry.input1 = selfPH;
        entry.input2 = mat2PH;
        entry.input3 = nil; // bmm has only two inputs
        entry.output = resultTensor;
        graph_cache[key] = entry;

        ET_LOG(Debug, "aoti_torch_mps_bmm_out: Cached compiled MPSGraph for future reuse");
      }

      // Wrap the raw MTLBuffers as MPSGraphTensorData. The locals start out
      // nil so the exception path below releases only what was created.
      NSMutableDictionary* feeds = [NSMutableDictionary dictionary];
      MPSGraphTensorData* selfTD = nil;
      MPSGraphTensorData* mat2TD = nil;
      MPSGraphTensorData* outTD = nil;

      selfTD = [[MPSGraphTensorData alloc] initWithMTLBuffer:self_buffer
                                                       shape:selfDims
                                                    dataType:mps_dtype];
      mat2TD = [[MPSGraphTensorData alloc] initWithMTLBuffer:mat2_buffer
                                                       shape:mat2Dims
                                                    dataType:mps_dtype];

      feeds[selfPH] = selfTD;
      feeds[mat2PH] = mat2TD;

      outTD = [[MPSGraphTensorData alloc] initWithMTLBuffer:out_buffer
                                                      shape:outDims
                                                   dataType:mps_dtype];

      // The graph writes its single result straight into the out buffer.
      NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = @{
        resultTensor : outTD
      };

      ET_LOG(Debug, "aoti_torch_mps_bmm_out: Executing MPSGraph");

      // Run the batched matmul on the backend stream (COMMIT, as elsewhere).
      @try {
        stream->executeMPSGraph(graph, feeds, results, SyncType::COMMIT);
      } @catch (NSException* exception) {
        ET_LOG(Error, "aoti_torch_mps_bmm_out: NSException caught during executeMPSGraph: %s - %s",
            [[exception name] UTF8String], [[exception reason] UTF8String]);
        // Guard each release against nil.
        if (selfTD) [selfTD release];
        if (mat2TD) [mat2TD release];
        if (outTD) [outTD release];
        return Error::Internal;
      }

      ET_LOG(Debug, "aoti_torch_mps_bmm_out: MPSGraph execution completed successfully");

      // Balance the explicit allocs above (tensor-data objects are managed
      // MRC-style in this file).
      [selfTD release];
      [mat2TD release];
      [outTD release];

      ET_LOG(Debug, "aoti_torch_mps_bmm_out: Executed successfully for batch size %lld", batch);
      return Error::Ok;

    } catch (const std::exception& e) {
      ET_LOG(Error, "aoti_torch_mps_bmm_out exception: %s", e.what());
      return Error::Internal;
    } catch (...) {
      ET_LOG(Error, "aoti_torch_mps_bmm_out: unknown exception");
      return Error::Internal;
    }
  }
}
964+
629965
AOTITorchError aoti_torch_mps_convolution(
630966
AOTITensorHandle input,
631967
AOTITensorHandle weight,

examples/models/parakeet/CMakePresets.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
"hidden": true,
77
"binaryDir": "${sourceDir}/../../../cmake-out/examples/models/parakeet",
88
"cacheVariables": {
9-
"CMAKE_BUILD_TYPE": "Release",
9+
"CMAKE_BUILD_TYPE": "Debug",
1010
"CMAKE_FIND_ROOT_PATH": "${sourceDir}/../../../cmake-out",
1111
"CMAKE_PREFIX_PATH": "${sourceDir}/../../../cmake-out"
1212
}

0 commit comments

Comments
 (0)