@@ -626,6 +626,345 @@ AOTITorchError aoti_torch_mps_mm_out(
626626 }
627627}
628628
// Performs a batched matrix multiplication on the Metal/MPS backend:
//   out[b] = self[b] @ mat2[b]   for every batch index b in [0, B)
// with self: [B, M, K], mat2: [B, K, N], out: [B, M, N].
//
// This is an _out op: the result is written directly into the MTLBuffer
// backing `out`; no implicit copies of tensor data are made. Compiled
// MPSGraph instances are cached by (op_name, {B, M, K, N}, dtype) so
// repeated calls with the same configuration reuse the compiled graph.
//
// @param out  Destination tensor handle, shape [B, M, N], contiguous.
// @param self Left operand handle, shape [B, M, K], contiguous.
// @param mat2 Right operand handle, shape [B, K, N], contiguous.
// @return Error::Ok on success; Error::InvalidArgument for null handles or
//         shape/dtype/layout violations; Error::Internal for Metal/MPS
//         failures (missing stream/device/buffers, execution exceptions).
AOTITorchError aoti_torch_mps_bmm_out(
    AOTITensorHandle out,
    AOTITensorHandle self,
    AOTITensorHandle mat2) {
  ET_LOG(
      Debug,
      "aoti_torch_mps_bmm_out: Starting with out=%p, self=%p, mat2=%p",
      out,
      self,
      mat2);

  // Validate non-null handles before dereferencing anything.
  if (!out || !self || !mat2) {
    ET_LOG(Error, "aoti_torch_mps_bmm_out: null tensor handles");
    return Error::InvalidArgument;
  }

  @autoreleasepool {
    try {
      // Convert opaque AOTI handles to ExecuTorch tensors.
      auto out_tensor = reinterpret_cast<Tensor*>(out);
      auto self_tensor = reinterpret_cast<Tensor*>(self);
      auto mat2_tensor = reinterpret_cast<Tensor*>(mat2);

      ET_LOG(Debug, "aoti_torch_mps_bmm_out: Converted tensor handles to ET tensors");

      // bmm is defined only for 3-D tensors.
      if (self_tensor->dim() != 3 || mat2_tensor->dim() != 3 ||
          out_tensor->dim() != 3) {
        ET_LOG(
            Error,
            "aoti_torch_mps_bmm_out: tensors must be 3-D. "
            "Got self.dim=%lld (shape=[%lld,%lld,%lld]), "
            "mat2.dim=%lld (shape=[%lld,%lld,%lld]), "
            "out.dim=%lld (shape=[%lld,%lld,%lld])",
            (long long)self_tensor->dim(),
            self_tensor->dim() > 0 ? (long long)self_tensor->sizes()[0] : 0LL,
            self_tensor->dim() > 1 ? (long long)self_tensor->sizes()[1] : 0LL,
            self_tensor->dim() > 2 ? (long long)self_tensor->sizes()[2] : 0LL,
            (long long)mat2_tensor->dim(),
            mat2_tensor->dim() > 0 ? (long long)mat2_tensor->sizes()[0] : 0LL,
            mat2_tensor->dim() > 1 ? (long long)mat2_tensor->sizes()[1] : 0LL,
            mat2_tensor->dim() > 2 ? (long long)mat2_tensor->sizes()[2] : 0LL,
            (long long)out_tensor->dim(),
            out_tensor->dim() > 0 ? (long long)out_tensor->sizes()[0] : 0LL,
            out_tensor->dim() > 1 ? (long long)out_tensor->sizes()[1] : 0LL,
            out_tensor->dim() > 2 ? (long long)out_tensor->sizes()[2] : 0LL);
        return Error::InvalidArgument;
      }

      // Logical GEMM dimensions:
      //   self: [B, M, K], mat2: [B, K, N], out: [B, M, N]
      int64_t B = self_tensor->sizes()[0]; // batch size
      int64_t M = self_tensor->sizes()[1]; // rows of each self matrix
      int64_t K = self_tensor->sizes()[2]; // contraction dimension
      int64_t N = mat2_tensor->sizes()[2]; // cols of each mat2 matrix

      // Batch dimensions must agree.
      if (mat2_tensor->sizes()[0] != B) {
        ET_LOG(
            Error,
            "aoti_torch_mps_bmm_out: batch size mismatch. "
            "Expected mat2[0]=%lld to match self[0]=%lld. "
            "self.shape=[%lld,%lld,%lld], mat2.shape=[%lld,%lld,%lld]",
            (long long)mat2_tensor->sizes()[0],
            (long long)B,
            (long long)B,
            (long long)M,
            (long long)K,
            (long long)mat2_tensor->sizes()[0],
            (long long)mat2_tensor->sizes()[1],
            (long long)mat2_tensor->sizes()[2]);
        return Error::InvalidArgument;
      }

      // Inner dimensions must agree: [B,M,K] @ [B,K,N].
      if (mat2_tensor->sizes()[1] != K) {
        ET_LOG(
            Error,
            "aoti_torch_mps_bmm_out: incompatible matrix dimensions for bmm. "
            "Expected mat2[1]=%lld to match self[2]=%lld. "
            "Cannot multiply [%lld,%lld,%lld] @ [%lld,%lld,%lld]",
            (long long)mat2_tensor->sizes()[1],
            (long long)K,
            (long long)B,
            (long long)M,
            (long long)K,
            (long long)mat2_tensor->sizes()[0],
            (long long)mat2_tensor->sizes()[1],
            (long long)mat2_tensor->sizes()[2]);
        return Error::InvalidArgument;
      }

      // Output must already have the exact result shape (this is an _out op).
      if (out_tensor->sizes()[0] != B || out_tensor->sizes()[1] != M ||
          out_tensor->sizes()[2] != N) {
        ET_LOG(
            Error,
            "aoti_torch_mps_bmm_out: output shape mismatch. "
            "Expected out.shape=[%lld,%lld,%lld], got [%lld,%lld,%lld]",
            (long long)B,
            (long long)M,
            (long long)N,
            (long long)out_tensor->sizes()[0],
            (long long)out_tensor->sizes()[1],
            (long long)out_tensor->sizes()[2]);
        return Error::InvalidArgument;
      }

      // All three tensors must share one dtype; mixed-precision bmm is not
      // supported by this path.
      int32_t self_dtype = static_cast<int32_t>(self_tensor->scalar_type());
      int32_t mat2_dtype = static_cast<int32_t>(mat2_tensor->scalar_type());
      int32_t out_dtype = static_cast<int32_t>(out_tensor->scalar_type());

      if (self_dtype != mat2_dtype || self_dtype != out_dtype) {
        ET_LOG(
            Error,
            "aoti_torch_mps_bmm_out: dtype mismatch. "
            "All tensors must have same dtype. Got self.dtype=%d, mat2.dtype=%d, out.dtype=%d",
            self_dtype,
            mat2_dtype,
            out_dtype);
        return Error::InvalidArgument;
      }

      int32_t dtype = self_dtype;

      // Validate layout: BMM requires strictly contiguous 3D tensors.
      // For shape [B, M, K], contiguous strides MUST be [M*K, K, 1].
      //
      // Why strict contiguity is required:
      // - MPSGraphTensorData initWithMTLBuffer:shape:dataType: interprets the
      //   MTLBuffer as dense row-major data for the given shape
      // - Non-contiguous layouts (transposed tensors, strided views) have
      //   memory layouts that do not match, which would produce SILENT WRONG
      //   RESULTS rather than an error
      // - This is an _out op: we must NOT create implicit copies
      // - Policy: reject non-contiguous inputs explicitly
      //
      // NOTE(review): this check also rejects tensors whose size-1 dimensions
      // carry non-canonical strides even though such tensors are dense;
      // kept strict to match the documented policy of the sibling ops.
      //
      // Limitation: storage offset is not checked (no API available here).
      // Tensors with non-zero storage offsets are not explicitly rejected;
      // callers should pass base tensors without offsets.
      auto self_strides = self_tensor->strides();
      auto mat2_strides = mat2_tensor->strides();
      auto out_strides = out_tensor->strides();

      // self must be contiguous [B, M, K] with strides [M*K, K, 1].
      if (self_strides[2] != 1 || self_strides[1] != K ||
          self_strides[0] != M * K) {
        ET_LOG(
            Error,
            "aoti_torch_mps_bmm_out: self tensor must be contiguous. "
            "Only dense row-major layout supported; transposed/view tensors are unsupported. "
            "Expected strides=[%lld,%lld,1] for shape=[%lld,%lld,%lld], got strides=[%lld,%lld,%lld].",
            (long long)(M * K),
            (long long)K,
            (long long)B,
            (long long)M,
            (long long)K,
            (long long)self_strides[0],
            (long long)self_strides[1],
            (long long)self_strides[2]);
        return Error::InvalidArgument;
      }

      // mat2 must be contiguous [B, K, N] with strides [K*N, N, 1].
      if (mat2_strides[2] != 1 || mat2_strides[1] != N ||
          mat2_strides[0] != K * N) {
        ET_LOG(
            Error,
            "aoti_torch_mps_bmm_out: mat2 tensor must be contiguous. "
            "Only dense row-major layout supported; transposed/view tensors are unsupported. "
            "Expected strides=[%lld,%lld,1] for shape=[%lld,%lld,%lld], got strides=[%lld,%lld,%lld].",
            (long long)(K * N),
            (long long)N,
            (long long)B,
            (long long)K,
            (long long)N,
            (long long)mat2_strides[0],
            (long long)mat2_strides[1],
            (long long)mat2_strides[2]);
        return Error::InvalidArgument;
      }

      // out must be contiguous [B, M, N] with strides [M*N, N, 1].
      if (out_strides[2] != 1 || out_strides[1] != N ||
          out_strides[0] != M * N) {
        ET_LOG(
            Error,
            "aoti_torch_mps_bmm_out: out tensor must be contiguous. "
            "Only dense row-major layout supported; transposed/view tensors are unsupported. "
            "Expected strides=[%lld,%lld,1] for shape=[%lld,%lld,%lld], got strides=[%lld,%lld,%lld].",
            (long long)(M * N),
            (long long)N,
            (long long)B,
            (long long)M,
            (long long)N,
            (long long)out_strides[0],
            (long long)out_strides[1],
            (long long)out_strides[2]);
        return Error::InvalidArgument;
      }

      ET_LOG(
          Debug,
          "aoti_torch_mps_bmm_out: Validated shapes and strides. "
          "batch=%lld, self=[%lld,%lld], mat2=[%lld,%lld], out=[%lld,%lld]",
          B, M, K, K, N, M, N);

      // Get Metal stream and device.
      ETMetalStream* stream = getCurrentMetalStream();
      if (!stream) {
        ET_LOG(Error, "aoti_torch_mps_bmm_out: Failed to get current Metal stream");
        return Error::Internal;
      }

      id<MTLDevice> device = get_metal_device();
      if (!device) {
        ET_LOG(Error, "aoti_torch_mps_bmm_out: Failed to get Metal device");
        return Error::Internal;
      }
      (void)device; // Used for validation, consistent with other ops

      // Get Metal buffers backing the input and output tensors.
      id<MTLBuffer> self_buffer = get_mtl_buffer(self_tensor, "aoti_torch_mps_bmm_out", "self");
      id<MTLBuffer> mat2_buffer = get_mtl_buffer(mat2_tensor, "aoti_torch_mps_bmm_out", "mat2");
      id<MTLBuffer> out_buffer = get_mtl_buffer(out_tensor, "aoti_torch_mps_bmm_out", "out");

      if (!self_buffer || !mat2_buffer || !out_buffer) {
        ET_LOG(
            Error,
            "aoti_torch_mps_bmm_out: Failed to get Metal buffers. "
            "self_buffer=%p, mat2_buffer=%p, out_buffer=%p",
            self_buffer,
            mat2_buffer,
            out_buffer);
        return Error::Internal;
      }

      ET_LOG(
          Debug,
          "aoti_torch_mps_bmm_out: Using Metal buffers - self=%p, mat2=%p, out=%p",
          self_buffer,
          mat2_buffer,
          out_buffer);

      // End any existing kernel coalescing to ensure clean state
      // (consistent with mm_out and conv pattern).
      stream->endKernelCoalescing();

      // Map dtype to the MPS data type and validate support.
      // Only FLOAT32 and BFLOAT16 are supported by this Metal backend
      // (see utils.h); FLOAT16 is not in SupportedDTypes.
      MPSDataType mps_dtype;
      if (dtype == static_cast<int32_t>(SupportedDTypes::FLOAT32)) {
        mps_dtype = MPSDataTypeFloat32;
      } else if (dtype == static_cast<int32_t>(SupportedDTypes::BFLOAT16)) {
        mps_dtype = MPSDataTypeBFloat16;
      } else {
        ET_LOG(
            Error,
            "aoti_torch_mps_bmm_out: Unsupported data type: %d. "
            "Supported types: FLOAT32 (%d), BFLOAT16 (%d)",
            dtype,
            static_cast<int32_t>(SupportedDTypes::FLOAT32),
            static_cast<int32_t>(SupportedDTypes::BFLOAT16));
        return Error::InvalidArgument;
      }

      ET_LOG(Debug, "aoti_torch_mps_bmm_out: dtype=%d, mps_dtype=%d", dtype, (int)mps_dtype);

      // Shapes used both for graph placeholders and MPSGraphTensorData.
      NSArray<NSNumber*>* selfShape = @[ @(B), @(M), @(K) ];
      NSArray<NSNumber*>* mat2Shape = @[ @(B), @(K), @(N) ];
      NSArray<NSNumber*>* outShape = @[ @(B), @(M), @(N) ];

      // Cache key: op name + {B, M, K, N} + dtype. Lets repeated calls with
      // the same configuration reuse the compiled graph.
      GraphCacheKey cache_key;
      cache_key.op_name = "bmm";
      cache_key.shape_params = {B, M, K, N};
      cache_key.dtype = dtype;
      cache_key.transpose_flag = false; // BMM has no transpose handling

      MPSGraph* mpsGraph = nullptr;
      MPSGraphTensor* outputTensor = nil;
      MPSGraphTensor* selfPlaceholder = nil;
      MPSGraphTensor* mat2Placeholder = nil;

      auto cache_it = graph_cache.find(cache_key);
      if (cache_it != graph_cache.end()) {
        // Cache hit - reuse compiled graph and tensor references.
        CachedGraph& cached = cache_it->second;
        mpsGraph = cached.graph;
        selfPlaceholder = cached.input1;
        mat2Placeholder = cached.input2;
        outputTensor = cached.output;

        cache_stats.hits++;
        cache_stats.logStats();
        ET_LOG(
            Debug,
            "aoti_torch_mps_bmm_out: Using cached MPSGraph (cache hit, %zu total hits)",
            cache_stats.hits);
      } else {
        // Cache miss - create and compile a new graph.
        mpsGraph = [MPSGraph new];
        cache_stats.misses++;
        cache_stats.logStats();
        ET_LOG(
            Debug,
            "aoti_torch_mps_bmm_out: Created new MPSGraph instance (cache miss, %zu total misses)",
            cache_stats.misses);

        ET_LOG(
            Debug,
            "aoti_torch_mps_bmm_out: Creating placeholders with shapes self:[%lld,%lld,%lld] mat2:[%lld,%lld,%lld]",
            B, M, K, B, K, N);

        // 3-D placeholders for the batched matrices.
        selfPlaceholder = [mpsGraph placeholderWithShape:selfShape
                                                dataType:mps_dtype
                                                    name:@"self"];
        mat2Placeholder = [mpsGraph placeholderWithShape:mat2Shape
                                                dataType:mps_dtype
                                                    name:@"mat2"];

        ET_LOG(Debug, "aoti_torch_mps_bmm_out: Created input placeholders");

        // matrixMultiplication handles the batched case natively for 3-D
        // inputs: [B,M,K] @ [B,K,N] -> [B,M,N].
        outputTensor = [mpsGraph matrixMultiplicationWithPrimaryTensor:selfPlaceholder
                                                       secondaryTensor:mat2Placeholder
                                                                  name:@"bmm_result"];

        ET_LOG(Debug, "aoti_torch_mps_bmm_out: Successfully created batched matrix multiplication tensor");

        // Cache the compiled graph and tensor references for reuse.
        CachedGraph cached_graph;
        cached_graph.graph = mpsGraph;
        cached_graph.input1 = selfPlaceholder;
        cached_graph.input2 = mat2Placeholder;
        cached_graph.input3 = nil; // No third input for BMM
        cached_graph.output = outputTensor;
        graph_cache[cache_key] = cached_graph;

        ET_LOG(Debug, "aoti_torch_mps_bmm_out: Cached compiled MPSGraph for future reuse");
      }

      // Wrap the MTLBuffers with shape/dtype metadata. These are created with
      // alloc/init (owned references under MRC) and must be released on every
      // exit path below.
      NSMutableDictionary* feeds = [NSMutableDictionary dictionary];

      MPSGraphTensorData* selfData =
          [[MPSGraphTensorData alloc] initWithMTLBuffer:self_buffer
                                                  shape:selfShape
                                               dataType:mps_dtype];
      MPSGraphTensorData* mat2Data =
          [[MPSGraphTensorData alloc] initWithMTLBuffer:mat2_buffer
                                                  shape:mat2Shape
                                               dataType:mps_dtype];

      feeds[selfPlaceholder] = selfData;
      feeds[mat2Placeholder] = mat2Data;

      MPSGraphTensorData* outputData =
          [[MPSGraphTensorData alloc] initWithMTLBuffer:out_buffer
                                                  shape:outShape
                                               dataType:mps_dtype];

      NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = @{
        outputTensor : outputData,
      };

      ET_LOG(Debug, "aoti_torch_mps_bmm_out: Executing MPSGraph");

      // Execute the batched matrix multiplication. Record the outcome so the
      // tensor-data wrappers are released exactly once on both the success
      // and the failure path (messaging nil is a safe no-op, so no guards
      // are needed).
      AOTITorchError exec_result = Error::Ok;
      @try {
        stream->executeMPSGraph(mpsGraph, feeds, results, SyncType::COMMIT);
      } @catch (NSException* exception) {
        ET_LOG(
            Error,
            "aoti_torch_mps_bmm_out: NSException caught during executeMPSGraph: %s - %s",
            [[exception name] UTF8String],
            [[exception reason] UTF8String]);
        exec_result = Error::Internal;
      }

      [selfData release];
      [mat2Data release];
      [outputData release];

      if (exec_result != Error::Ok) {
        return exec_result;
      }

      ET_LOG(Debug, "aoti_torch_mps_bmm_out: MPSGraph execution completed successfully");
      ET_LOG(Debug, "aoti_torch_mps_bmm_out: Executed successfully for batch size %lld", B);
      return Error::Ok;

    } catch (const std::exception& e) {
      ET_LOG(Error, "aoti_torch_mps_bmm_out exception: %s", e.what());
      return Error::Internal;
    } catch (...) {
      ET_LOG(Error, "aoti_torch_mps_bmm_out: unknown exception");
      return Error::Internal;
    }
  }
}
967+
629968AOTITorchError aoti_torch_mps_convolution (
630969 AOTITensorHandle input,
631970 AOTITensorHandle weight,
0 commit comments