@@ -626,6 +626,342 @@ AOTITorchError aoti_torch_mps_mm_out(
626626 }
627627}
628628
// aoti_torch_mps_bmm_out: batched matrix multiplication (bmm) on MPS.
//
// Computes out = self @ mat2 where
//   self: [B, M, K], mat2: [B, K, N], out: [B, M, N].
//
// Validated preconditions (Error::InvalidArgument on violation):
//   * all three handles are non-null and all tensors are 3-D
//   * shapes are compatible as above
//   * all tensors share one dtype: FLOAT32 or BFLOAT16 (the only dtypes
//     supported by this Metal backend; FLOAT16 is not in SupportedDTypes)
//   * all tensors are strictly contiguous (dense row-major). MPSGraphTensorData
//     initWithMTLBuffer:shape:dataType: interprets the raw MTLBuffer as dense
//     row-major data for the given shape, so a transposed/strided view would
//     silently produce wrong results; and as an _out op we must not create
//     implicit copies, so non-contiguous inputs are rejected outright.
//
// NOTE(review): storage offsets cannot be checked here (no API available), so
// a view with a non-zero offset but contiguous strides is not rejected —
// callers should pass base tensors.
//
// Compiled MPSGraphs are cached by {op="bmm", B, M, K, N, dtype} and reused on
// subsequent calls with the same configuration.
//
// Returns Error::Ok on success, Error::InvalidArgument for bad inputs, and
// Error::Internal for Metal/MPS failures.
AOTITorchError aoti_torch_mps_bmm_out(
    AOTITensorHandle out,
    AOTITensorHandle self,
    AOTITensorHandle mat2) {
  ET_LOG(Debug, "aoti_torch_mps_bmm_out: Starting with out=%p, self=%p, mat2=%p",
      out, self, mat2);

  // Validate non-null handles before any dereference.
  if (!out || !self || !mat2) {
    ET_LOG(Error, "aoti_torch_mps_bmm_out: null tensor handles");
    return Error::InvalidArgument;
  }

  @autoreleasepool {
    try {
      // Convert opaque AOTI handles to ExecuTorch tensors.
      auto out_tensor = reinterpret_cast<Tensor*>(out);
      auto self_tensor = reinterpret_cast<Tensor*>(self);
      auto mat2_tensor = reinterpret_cast<Tensor*>(mat2);

      ET_LOG(Debug, "aoti_torch_mps_bmm_out: Converted tensor handles to ET tensors");

      // Validate tensor dimensions - bmm requires 3-D tensors.
      // All log arguments are cast explicitly: dim()/sizes() element types are
      // not guaranteed to match %d/%lld, and a width mismatch in a variadic
      // call is undefined behavior.
      if (self_tensor->dim() != 3 || mat2_tensor->dim() != 3 || out_tensor->dim() != 3) {
        ET_LOG(Error, "aoti_torch_mps_bmm_out: tensors must be 3-D. "
            "Got self.dim=%d (shape=[%lld,%lld,%lld]), "
            "mat2.dim=%d (shape=[%lld,%lld,%lld]), "
            "out.dim=%d (shape=[%lld,%lld,%lld])",
            (int)self_tensor->dim(),
            self_tensor->dim() > 0 ? (long long)self_tensor->sizes()[0] : 0LL,
            self_tensor->dim() > 1 ? (long long)self_tensor->sizes()[1] : 0LL,
            self_tensor->dim() > 2 ? (long long)self_tensor->sizes()[2] : 0LL,
            (int)mat2_tensor->dim(),
            mat2_tensor->dim() > 0 ? (long long)mat2_tensor->sizes()[0] : 0LL,
            mat2_tensor->dim() > 1 ? (long long)mat2_tensor->sizes()[1] : 0LL,
            mat2_tensor->dim() > 2 ? (long long)mat2_tensor->sizes()[2] : 0LL,
            (int)out_tensor->dim(),
            out_tensor->dim() > 0 ? (long long)out_tensor->sizes()[0] : 0LL,
            out_tensor->dim() > 1 ? (long long)out_tensor->sizes()[1] : 0LL,
            out_tensor->dim() > 2 ? (long long)out_tensor->sizes()[2] : 0LL);
        return Error::InvalidArgument;
      }

      int64_t B = self_tensor->sizes()[0]; // batch size
      int64_t M = self_tensor->sizes()[1]; // rows of self
      int64_t K = self_tensor->sizes()[2]; // cols of self / rows of mat2
      int64_t N = mat2_tensor->sizes()[2]; // cols of mat2

      // Validate shape constraints:
      // self: [B, M, K], mat2: [B, K, N], out: [B, M, N]
      if (mat2_tensor->sizes()[0] != B) {
        ET_LOG(Error, "aoti_torch_mps_bmm_out: batch size mismatch. "
            "Expected mat2[0]=%lld to match self[0]=%lld. "
            "self.shape=[%lld,%lld,%lld], mat2.shape=[%lld,%lld,%lld]",
            (long long)mat2_tensor->sizes()[0], (long long)B,
            (long long)B, (long long)M, (long long)K,
            (long long)mat2_tensor->sizes()[0],
            (long long)mat2_tensor->sizes()[1],
            (long long)mat2_tensor->sizes()[2]);
        return Error::InvalidArgument;
      }

      if (mat2_tensor->sizes()[1] != K) {
        ET_LOG(Error, "aoti_torch_mps_bmm_out: incompatible matrix dimensions for bmm. "
            "Expected mat2[1]=%lld to match self[2]=%lld. "
            "Cannot multiply [%lld,%lld,%lld] @ [%lld,%lld,%lld]",
            (long long)mat2_tensor->sizes()[1], (long long)K,
            (long long)B, (long long)M, (long long)K,
            (long long)mat2_tensor->sizes()[0],
            (long long)mat2_tensor->sizes()[1],
            (long long)mat2_tensor->sizes()[2]);
        return Error::InvalidArgument;
      }

      if (out_tensor->sizes()[0] != B || out_tensor->sizes()[1] != M || out_tensor->sizes()[2] != N) {
        ET_LOG(Error, "aoti_torch_mps_bmm_out: output shape mismatch. "
            "Expected out.shape=[%lld,%lld,%lld], got [%lld,%lld,%lld]",
            (long long)B, (long long)M, (long long)N,
            (long long)out_tensor->sizes()[0],
            (long long)out_tensor->sizes()[1],
            (long long)out_tensor->sizes()[2]);
        return Error::InvalidArgument;
      }

      // Validate dtype consistency: all three tensors must agree.
      int32_t self_dtype = static_cast<int32_t>(self_tensor->scalar_type());
      int32_t mat2_dtype = static_cast<int32_t>(mat2_tensor->scalar_type());
      int32_t out_dtype = static_cast<int32_t>(out_tensor->scalar_type());

      if (self_dtype != mat2_dtype || self_dtype != out_dtype) {
        ET_LOG(Error, "aoti_torch_mps_bmm_out: dtype mismatch. "
            "All tensors must have same dtype. Got self.dtype=%d, mat2.dtype=%d, out.dtype=%d",
            self_dtype, mat2_dtype, out_dtype);
        return Error::InvalidArgument;
      }

      int32_t dtype = self_dtype;

      // Layout validation: BMM requires strictly contiguous 3-D tensors.
      // For shape [B, M, K], contiguous row-major strides MUST be [M*K, K, 1].
      // See the function header for why anything else must be rejected
      // (MPSGraphTensorData reads the MTLBuffer as dense row-major data, and
      // an _out op must not make implicit copies).
      int64_t* self_strides = self_tensor->strides();
      int64_t* mat2_strides = mat2_tensor->strides();
      int64_t* out_strides = out_tensor->strides();

      // self must be contiguous [B, M, K] with strides [M*K, K, 1].
      if (self_strides[2] != 1 || self_strides[1] != K || self_strides[0] != M * K) {
        ET_LOG(Error, "aoti_torch_mps_bmm_out: self tensor must be contiguous. "
            "Only dense row-major layout supported; transposed/view tensors are unsupported. "
            "Expected strides=[%lld,%lld,1] for shape=[%lld,%lld,%lld], got strides=[%lld,%lld,%lld].",
            (long long)(M * K), (long long)K,
            (long long)B, (long long)M, (long long)K,
            (long long)self_strides[0], (long long)self_strides[1], (long long)self_strides[2]);
        return Error::InvalidArgument;
      }

      // mat2 must be contiguous [B, K, N] with strides [K*N, N, 1].
      if (mat2_strides[2] != 1 || mat2_strides[1] != N || mat2_strides[0] != K * N) {
        ET_LOG(Error, "aoti_torch_mps_bmm_out: mat2 tensor must be contiguous. "
            "Only dense row-major layout supported; transposed/view tensors are unsupported. "
            "Expected strides=[%lld,%lld,1] for shape=[%lld,%lld,%lld], got strides=[%lld,%lld,%lld].",
            (long long)(K * N), (long long)N,
            (long long)B, (long long)K, (long long)N,
            (long long)mat2_strides[0], (long long)mat2_strides[1], (long long)mat2_strides[2]);
        return Error::InvalidArgument;
      }

      // out must be contiguous [B, M, N] with strides [M*N, N, 1].
      if (out_strides[2] != 1 || out_strides[1] != N || out_strides[0] != M * N) {
        ET_LOG(Error, "aoti_torch_mps_bmm_out: out tensor must be contiguous. "
            "Only dense row-major layout supported; transposed/view tensors are unsupported. "
            "Expected strides=[%lld,%lld,1] for shape=[%lld,%lld,%lld], got strides=[%lld,%lld,%lld].",
            (long long)(M * N), (long long)N,
            (long long)B, (long long)M, (long long)N,
            (long long)out_strides[0], (long long)out_strides[1], (long long)out_strides[2]);
        return Error::InvalidArgument;
      }

      ET_LOG(Debug, "aoti_torch_mps_bmm_out: Validated shapes and strides. "
          "batch=%lld, self=[%lld,%lld], mat2=[%lld,%lld], out=[%lld,%lld]",
          (long long)B, (long long)M, (long long)K,
          (long long)K, (long long)N, (long long)M, (long long)N);

      // Get Metal stream and device.
      ETMetalStream* stream = getCurrentMetalStream();
      if (!stream) {
        ET_LOG(Error, "aoti_torch_mps_bmm_out: Failed to get current Metal stream");
        return Error::Internal;
      }

      id<MTLDevice> device = get_metal_device();
      if (!device) {
        ET_LOG(Error, "aoti_torch_mps_bmm_out: Failed to get Metal device");
        return Error::Internal;
      }
      (void)device; // Used for validation, consistent with other ops

      // Get Metal buffers backing the input and output tensors.
      id<MTLBuffer> self_buffer = get_mtl_buffer(self_tensor, "aoti_torch_mps_bmm_out", "self");
      id<MTLBuffer> mat2_buffer = get_mtl_buffer(mat2_tensor, "aoti_torch_mps_bmm_out", "mat2");
      id<MTLBuffer> out_buffer = get_mtl_buffer(out_tensor, "aoti_torch_mps_bmm_out", "out");

      if (!self_buffer || !mat2_buffer || !out_buffer) {
        ET_LOG(Error, "aoti_torch_mps_bmm_out: Failed to get Metal buffers. "
            "self_buffer=%p, mat2_buffer=%p, out_buffer=%p",
            (void*)self_buffer, (void*)mat2_buffer, (void*)out_buffer);
        return Error::Internal;
      }

      ET_LOG(Debug, "aoti_torch_mps_bmm_out: Using Metal buffers - self=%p, mat2=%p, out=%p",
          (void*)self_buffer, (void*)mat2_buffer, (void*)out_buffer);

      // End any existing kernel coalescing to ensure clean state
      // (consistent with mm_out and conv pattern).
      stream->endKernelCoalescing();

      // Map dtype to MPS type. Only FLOAT32 and BFLOAT16 are supported in the
      // Metal backend (see utils.h); FLOAT16 is not in SupportedDTypes.
      MPSDataType mps_dtype;
      if (dtype == static_cast<int32_t>(SupportedDTypes::FLOAT32)) {
        mps_dtype = MPSDataTypeFloat32;
      } else if (dtype == static_cast<int32_t>(SupportedDTypes::BFLOAT16)) {
        mps_dtype = MPSDataTypeBFloat16;
      } else {
        ET_LOG(Error, "aoti_torch_mps_bmm_out: Unsupported data type: %d. "
            "Supported types: FLOAT32 (%d), BFLOAT16 (%d)",
            dtype,
            static_cast<int32_t>(SupportedDTypes::FLOAT32),
            static_cast<int32_t>(SupportedDTypes::BFLOAT16));
        return Error::InvalidArgument;
      }

      ET_LOG(Debug, "aoti_torch_mps_bmm_out: dtype=%d, mps_dtype=%d",
          dtype, (int)mps_dtype);

      // Logical shapes for graph placeholders and tensor data.
      NSArray<NSNumber*>* selfShape = @[ @(B), @(M), @(K) ];
      NSArray<NSNumber*>* mat2Shape = @[ @(B), @(K), @(N) ];
      NSArray<NSNumber*>* outShape = @[ @(B), @(M), @(N) ];

      // Cache key: {op_name, {B, M, K, N}, dtype, transpose_flag}. Allows graph
      // reuse when the same BMM shape/dtype is requested repeatedly.
      GraphCacheKey cache_key;
      cache_key.op_name = "bmm";
      cache_key.shape_params = {B, M, K, N};
      cache_key.dtype = dtype;
      cache_key.transpose_flag = false; // BMM has no transpose handling

      MPSGraph* mpsGraph = nullptr;
      MPSGraphTensor* outputTensor = nil;
      MPSGraphTensor* selfPlaceholder = nil;
      MPSGraphTensor* mat2Placeholder = nil;

      auto cache_it = graph_cache.find(cache_key);
      if (cache_it != graph_cache.end()) {
        // Cache hit - reuse the compiled graph and its tensor references.
        CachedGraph& cached = cache_it->second;
        mpsGraph = cached.graph;
        selfPlaceholder = cached.input1;
        mat2Placeholder = cached.input2;
        outputTensor = cached.output;

        cache_stats.hits++;
        cache_stats.logStats();
        ET_LOG(Debug, "aoti_torch_mps_bmm_out: Using cached MPSGraph (cache hit, %zu total hits)", cache_stats.hits);
      } else {
        // Cache miss - create and compile a new graph.
        mpsGraph = [MPSGraph new];
        cache_stats.misses++;
        cache_stats.logStats();
        ET_LOG(Debug, "aoti_torch_mps_bmm_out: Created new MPSGraph instance (cache miss, %zu total misses)", cache_stats.misses);

        ET_LOG(Debug, "aoti_torch_mps_bmm_out: Creating placeholders with shapes self:[%lld,%lld,%lld] mat2:[%lld,%lld,%lld]",
            (long long)B, (long long)M, (long long)K,
            (long long)B, (long long)K, (long long)N);

        // 3-D placeholders for the batched matrices.
        selfPlaceholder = [mpsGraph placeholderWithShape:selfShape
                                                dataType:mps_dtype
                                                    name:@"self"];
        mat2Placeholder = [mpsGraph placeholderWithShape:mat2Shape
                                                dataType:mps_dtype
                                                    name:@"mat2"];

        ET_LOG(Debug, "aoti_torch_mps_bmm_out: Created input placeholders");

        // matrixMultiplication handles the batched case natively for 3-D
        // tensors: [B,M,K] @ [B,K,N] -> [B,M,N].
        outputTensor = [mpsGraph matrixMultiplicationWithPrimaryTensor:selfPlaceholder
                                                       secondaryTensor:mat2Placeholder
                                                                  name:@"bmm_result"];

        ET_LOG(Debug, "aoti_torch_mps_bmm_out: Successfully created batched matrix multiplication tensor");

        // Cache the compiled graph and tensor references for reuse.
        CachedGraph cached_graph;
        cached_graph.graph = mpsGraph;
        cached_graph.input1 = selfPlaceholder;
        cached_graph.input2 = mat2Placeholder;
        cached_graph.input3 = nil; // No third input for BMM
        cached_graph.output = outputTensor;
        graph_cache[cache_key] = cached_graph;

        ET_LOG(Debug, "aoti_torch_mps_bmm_out: Cached compiled MPSGraph for future reuse");
      }

      // Feeds for graph execution: wrap the MTLBuffers with shape/dtype info.
      // Initialize tensor-data pointers to nil for safe cleanup in the
      // exception path.
      NSMutableDictionary* feeds = [NSMutableDictionary dictionary];
      MPSGraphTensorData* selfData = nil;
      MPSGraphTensorData* mat2Data = nil;
      MPSGraphTensorData* outputData = nil;

      selfData = [[MPSGraphTensorData alloc] initWithMTLBuffer:self_buffer
                                                         shape:selfShape
                                                      dataType:mps_dtype];
      mat2Data = [[MPSGraphTensorData alloc] initWithMTLBuffer:mat2_buffer
                                                         shape:mat2Shape
                                                      dataType:mps_dtype];

      feeds[selfPlaceholder] = selfData;
      feeds[mat2Placeholder] = mat2Data;

      outputData = [[MPSGraphTensorData alloc] initWithMTLBuffer:out_buffer
                                                           shape:outShape
                                                        dataType:mps_dtype];

      NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = @{
        outputTensor : outputData
      };

      ET_LOG(Debug, "aoti_torch_mps_bmm_out: Executing MPSGraph");

      // Execute the batched matrix multiplication. MPS raises NSExceptions on
      // failure, which must not escape across the C ABI boundary.
      @try {
        stream->executeMPSGraph(mpsGraph, feeds, results, SyncType::COMMIT);
      } @catch (NSException* exception) {
        ET_LOG(Error, "aoti_torch_mps_bmm_out: NSException caught during executeMPSGraph: %s - %s",
            [[exception name] UTF8String], [[exception reason] UTF8String]);
        // Guard releases against nil (manual retain/release; this file is
        // compiled without ARC).
        if (selfData) [selfData release];
        if (mat2Data) [mat2Data release];
        if (outputData) [outputData release];
        return Error::Internal;
      }

      ET_LOG(Debug, "aoti_torch_mps_bmm_out: MPSGraph execution completed successfully");

      // Release the MPSGraphTensorData wrappers (alloc'd above).
      [selfData release];
      [mat2Data release];
      [outputData release];

      ET_LOG(Debug, "aoti_torch_mps_bmm_out: Executed successfully for batch size %lld", (long long)B);
      return Error::Ok;

    } catch (const std::exception& e) {
      ET_LOG(Error, "aoti_torch_mps_bmm_out exception: %s", e.what());
      return Error::Internal;
    } catch (...) {
      ET_LOG(Error, "aoti_torch_mps_bmm_out: unknown exception");
      return Error::Internal;
    }
  }
}
964+
629965AOTITorchError aoti_torch_mps_convolution (
630966 AOTITensorHandle input,
631967 AOTITensorHandle weight,
0 commit comments