@@ -270,7 +270,7 @@ AOTITorchError aoti_torch_mps_mm_out(

   @try {
     // Use stream helper to encode and synchronize correctly
-    stream->executeMPSGraph(mpsGraph, feeds, results, SyncType::COMMIT_AND_CONTINUE);
+    stream->executeMPSGraph(mpsGraph, feeds, results, SyncType::COMMIT);
   } @catch (NSException *exception) {
     ET_LOG(Error, "aoti_torch_mps_mm_out: NSException caught during executeMPSGraph: %s - %s",
            [[exception name] UTF8String], [[exception reason] UTF8String]);
@@ -279,6 +279,14 @@ AOTITorchError aoti_torch_mps_mm_out(

     ET_LOG(Debug, "aoti_torch_mps_mm_out: MPSGraph execution completed successfully");

+    // Release MPSGraph to prevent memory leak
+    [mpsGraph release];
+    mpsGraph = nil;
+
+    [selfData release];
+    [mat2Data release];
+    [outputData release];
+
     ET_LOG(Debug, "aoti_torch_mps_mm_out: Executed successfully");
     return Error::Ok;

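The added releases balance the `+1` references created earlier in the function: under manual reference counting (no ARC), objects returned by `+new` and `+alloc`/`-init...` are owned by the caller. A minimal, self-contained sketch of that ownership pattern, assuming compilation without ARC; the device, buffer, and graph names here are illustrative, not the shim's:

```objc
#import <Metal/Metal.h>
#import <MetalPerformanceShadersGraph/MetalPerformanceShadersGraph.h>

// MRC sketch: every +new / +alloc object is released once the graph run
// has been committed.
static void run_identity_graph() {
  id<MTLDevice> device = MTLCreateSystemDefaultDevice();
  id<MTLCommandQueue> queue = [device newCommandQueue];                          // +1

  MPSGraph* graph = [MPSGraph new];                                              // +1
  MPSGraphTensor* x = [graph placeholderWithShape:@[ @4 ]
                                         dataType:MPSDataTypeFloat32
                                             name:nil];
  MPSGraphTensor* y = [graph identityWithTensor:x name:nil];

  float values[4] = {1.f, 2.f, 3.f, 4.f};
  id<MTLBuffer> buf = [device newBufferWithBytes:values
                                          length:sizeof(values)
                                         options:MTLResourceStorageModeShared];  // +1
  MPSGraphTensorData* xData =
      [[MPSGraphTensorData alloc] initWithMTLBuffer:buf
                                              shape:@[ @4 ]
                                           dataType:MPSDataTypeFloat32];         // +1

  NSDictionary* results = [graph runWithMTLCommandQueue:queue
                                                  feeds:@{x : xData}
                                          targetTensors:@[ y ]
                                       targetOperations:nil];
  (void)results;  // returned autoreleased; no explicit release needed

  // Balance every owning reference created above.
  [xData release];
  [buf release];
  [graph release];
  [queue release];
}
```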
@@ -616,14 +624,16 @@ AOTITorchError aoti_torch_mps_convolution(
     feeds[inputPlaceholder] = inputData;
     feeds[weightPlaceholder] = weightData;

+    MPSGraphTensorData* biasData = nil;
+
     // Add bias data to feeds if provided
     if (bias_tensor && biasPlaceholder) {
       id<MTLBuffer> bias_buffer = get_mtl_buffer(bias_tensor, "aoti_torch_mps_convolution", "bias");

       NSArray<NSNumber*>* biasShape = @[@(C_out)];
-      MPSGraphTensorData* biasData = [[MPSGraphTensorData alloc] initWithMTLBuffer:bias_buffer
-                                                                              shape:biasShape
-                                                                           dataType:mps_dtype];
+      biasData = [[MPSGraphTensorData alloc] initWithMTLBuffer:bias_buffer
+                                                         shape:biasShape
+                                                      dataType:mps_dtype];

       feeds[biasPlaceholder] = biasData;
       ET_LOG(Debug, "aoti_torch_mps_convolution: Added bias tensor to feeds");
@@ -650,7 +660,7 @@ AOTITorchError aoti_torch_mps_convolution(

   @try {
     // Use stream helper to encode and synchronize correctly
-    stream->executeMPSGraph(mpsGraph, feeds, results, SyncType::COMMIT_AND_CONTINUE);
+    stream->executeMPSGraph(mpsGraph, feeds, results, SyncType::COMMIT);
   } @catch (NSException *exception) {
     ET_LOG(Error, "aoti_torch_mps_convolution: NSException caught during executeMPSGraph: %s - %s",
            [[exception name] UTF8String], [[exception reason] UTF8String]);
@@ -743,6 +753,15 @@ AOTITorchError aoti_torch_mps_convolution(
     extern std::unordered_map<void*, int32_t> memory_to_n_tensor;
     memory_to_n_tensor[tensor_data] = 1;

+    // Release MPSGraph to prevent memory leak
+    [mpsGraph release];
+    mpsGraph = nil;
+
+    [inputData release];
+    [weightData release];
+    if (biasData) [biasData release];
+    [outputData release];
+
     ET_LOG(Debug, "aoti_torch_mps_convolution: Created output tensor with %zu elements using MPSGraph", actual_numel);

     ET_LOG(Debug, "aoti_torch_mps_convolution: Executed successfully");
@@ -992,14 +1011,6 @@ AOTITorchError aoti_torch_mps__scaled_dot_product_attention_math_for_mps(
   ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Implementing using MPSGraph scaledDotProductAttention");

   @try {
-    // Check if scaledDotProductAttentionWithQueryTensor is available
-    MPSGraph* testGraph = [MPSGraph new];
-    if (![testGraph respondsToSelector:@selector(scaledDotProductAttentionWithQueryTensor:keyTensor:valueTensor:maskTensor:scale:name:)]) {
-      ET_LOG(Error, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: scaledDotProductAttentionWithQueryTensor API not available on this system");
-      throw std::runtime_error("scaledDotProductAttentionWithQueryTensor API not available on this system");
-    }
-    ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: scaledDotProductAttentionWithQueryTensor API is available");
-
     // Create MPSGraph for scaled dot product attention
     MPSGraph* mpsGraph = [MPSGraph new];
     ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Created MPSGraph instance");
@@ -1246,6 +1257,8 @@ AOTITorchError aoti_torch_mps__scaled_dot_product_attention_math_for_mps(
     feeds[valuePlaceholder] = valueData;
     ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Added input tensors to feeds");

+    MPSGraphTensorData* maskData = nil;
+
     // Add explicit mask data to feeds if provided
     if (explicitMaskPlaceholder && attn_mask && *attn_mask) {
       auto* mask_tensor = reinterpret_cast<Tensor*>(*attn_mask);
@@ -1257,9 +1270,9 @@ AOTITorchError aoti_torch_mps__scaled_dot_product_attention_math_for_mps(
         [maskShapeArray addObject:@(mask_tensor->sizes()[i])];
       }

-      MPSGraphTensorData* maskData = [[MPSGraphTensorData alloc] initWithMTLBuffer:mask_buffer
-                                                                              shape:maskShapeArray
-                                                                           dataType:mps_dtype];
+      maskData = [[MPSGraphTensorData alloc] initWithMTLBuffer:mask_buffer
+                                                         shape:maskShapeArray
+                                                      dataType:mps_dtype];
       feeds[explicitMaskPlaceholder] = maskData;
       ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Added explicit mask tensor to feeds");
     }
@@ -1275,9 +1288,19 @@ AOTITorchError aoti_torch_mps__scaled_dot_product_attention_math_for_mps(

     // Execute via shared stream and keep results on GPU
     ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Executing MPSGraph using stream");
-    stream->executeMPSGraph(mpsGraph, feeds, results, SyncType::COMMIT_AND_CONTINUE);
+    stream->executeMPSGraph(mpsGraph, feeds, results, SyncType::COMMIT);
     ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: MPSGraph execution completed successfully");

+    // Release MPSGraph to prevent memory leak
+    [mpsGraph release];
+    mpsGraph = nil;
+
+    [queryData release];
+    [keyData release];
+    [valueData release];
+    if (maskData) [maskData release];
+    [outputData release];
+
   } @catch (NSException *exception) {
     ET_LOG(Error, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: NSException caught: %s - %s",
            [[exception name] UTF8String], [[exception reason] UTF8String]);