diff --git a/backends/apple/metal/runtime/shims/et_metal.mm b/backends/apple/metal/runtime/shims/et_metal.mm index cae8f96c0d2..2ba058de40a 100644 --- a/backends/apple/metal/runtime/shims/et_metal.mm +++ b/backends/apple/metal/runtime/shims/et_metal.mm @@ -680,12 +680,16 @@ int metal_copy_memory(void* dst, const void* src, size_t nbytes, bool src_is_dev // Commit methods void ETMetalStream::commit() { - if (enableCommitAndContinue_ && commandBuffer_) { - // Use commit-and-continue for better performance - commitAndContinue(); - } else { - flush(); + if (!commandBuffer_) { + ET_LOG(Error, "ETMetalStream::commit: No command buffer to commit"); + return; } + + [commandBuffer_ commit]; + ET_LOG(Debug, "ETMetalStream::commit: Committed buffer %p", commandBuffer_); + + [commandBuffer_ release]; + commandBuffer_ = nil; } void ETMetalStream::commitAndWait() { diff --git a/backends/apple/metal/runtime/shims/et_metal_ops.mm b/backends/apple/metal/runtime/shims/et_metal_ops.mm index 94bc3219306..7e1fa66ac7c 100644 --- a/backends/apple/metal/runtime/shims/et_metal_ops.mm +++ b/backends/apple/metal/runtime/shims/et_metal_ops.mm @@ -270,7 +270,7 @@ AOTITorchError aoti_torch_mps_mm_out( @try { // Use stream helper to encode and synchronize correctly - stream->executeMPSGraph(mpsGraph, feeds, results, SyncType::COMMIT_AND_CONTINUE); + stream->executeMPSGraph(mpsGraph, feeds, results, SyncType::COMMIT); } @catch (NSException *exception) { ET_LOG(Error, "aoti_torch_mps_mm_out: NSException caught during executeMPSGraph: %s - %s", [[exception name] UTF8String], [[exception reason] UTF8String]); @@ -279,6 +279,14 @@ AOTITorchError aoti_torch_mps_mm_out( ET_LOG(Debug, "aoti_torch_mps_mm_out: MPSGraph execution completed successfully"); + // Release MPSGraph to prevent memory leak + [mpsGraph release]; + mpsGraph = nil; + + [selfData release]; + [mat2Data release]; + [outputData release]; + ET_LOG(Debug, "aoti_torch_mps_mm_out: Executed successfully"); return Error::Ok; @@ -616,14 +624,16 @@ AOTITorchError aoti_torch_mps_convolution( feeds[inputPlaceholder] = inputData; feeds[weightPlaceholder] = weightData; + MPSGraphTensorData* biasData = nil; + // Add bias data to feeds if provided if (bias_tensor && biasPlaceholder) { id bias_buffer = get_mtl_buffer(bias_tensor, "aoti_torch_mps_convolution", "bias"); NSArray* biasShape = @[@(C_out)]; - MPSGraphTensorData* biasData = [[MPSGraphTensorData alloc] initWithMTLBuffer:bias_buffer - shape:biasShape - dataType:mps_dtype]; + biasData = [[MPSGraphTensorData alloc] initWithMTLBuffer:bias_buffer + shape:biasShape + dataType:mps_dtype]; feeds[biasPlaceholder] = biasData; ET_LOG(Debug, "aoti_torch_mps_convolution: Added bias tensor to feeds"); @@ -650,7 +660,7 @@ AOTITorchError aoti_torch_mps_convolution( @try { // Use stream helper to encode and synchronize correctly - stream->executeMPSGraph(mpsGraph, feeds, results, SyncType::COMMIT_AND_CONTINUE); + stream->executeMPSGraph(mpsGraph, feeds, results, SyncType::COMMIT); } @catch (NSException *exception) { ET_LOG(Error, "aoti_torch_mps_convolution: NSException caught during executeMPSGraph: %s - %s", [[exception name] UTF8String], [[exception reason] UTF8String]); @@ -743,6 +753,15 @@ AOTITorchError aoti_torch_mps_convolution( extern std::unordered_map memory_to_n_tensor; memory_to_n_tensor[tensor_data] = 1; + // Release MPSGraph to prevent memory leak + [mpsGraph release]; + mpsGraph = nil; + + [inputData release]; + [weightData release]; + if (biasData) [biasData release]; + [outputData release]; + ET_LOG(Debug, "aoti_torch_mps_convolution: Created output tensor with %zu elements using MPSGraph", actual_numel); ET_LOG(Debug, "aoti_torch_mps_convolution: Executed successfully"); @@ -992,14 +1011,6 @@ AOTITorchError aoti_torch_mps__scaled_dot_product_attention_math_for_mps( ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Implementing using MPSGraph scaledDotProductAttention"); @try { - // Check if scaledDotProductAttentionWithQueryTensor is available - MPSGraph* testGraph = [MPSGraph new]; - if (![testGraph respondsToSelector:@selector(scaledDotProductAttentionWithQueryTensor:keyTensor:valueTensor:maskTensor:scale:name:)]) { - ET_LOG(Error, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: scaledDotProductAttentionWithQueryTensor API not available on this system"); - throw std::runtime_error("scaledDotProductAttentionWithQueryTensor API not available on this system"); - } - ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: scaledDotProductAttentionWithQueryTensor API is available"); - // Create MPSGraph for scaled dot product attention MPSGraph* mpsGraph = [MPSGraph new]; ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Created MPSGraph instance"); @@ -1246,6 +1257,8 @@ AOTITorchError aoti_torch_mps__scaled_dot_product_attention_math_for_mps( feeds[valuePlaceholder] = valueData; ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Added input tensors to feeds"); + MPSGraphTensorData* maskData = nil; + // Add explicit mask data to feeds if provided if (explicitMaskPlaceholder && attn_mask && *attn_mask) { auto* mask_tensor = reinterpret_cast(*attn_mask); @@ -1257,9 +1270,9 @@ AOTITorchError aoti_torch_mps__scaled_dot_product_attention_math_for_mps( [maskShapeArray addObject:@(mask_tensor->sizes()[i])]; } - MPSGraphTensorData* maskData = [[MPSGraphTensorData alloc] initWithMTLBuffer:mask_buffer - shape:maskShapeArray - dataType:mps_dtype]; + maskData = [[MPSGraphTensorData alloc] initWithMTLBuffer:mask_buffer + shape:maskShapeArray + dataType:mps_dtype]; feeds[explicitMaskPlaceholder] = maskData; ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Added explicit mask tensor to feeds"); } @@ -1275,9 +1288,19 @@ AOTITorchError aoti_torch_mps__scaled_dot_product_attention_math_for_mps( // Execute via shared stream and keep results on GPU ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Executing MPSGraph using stream"); - stream->executeMPSGraph(mpsGraph, feeds, results, SyncType::COMMIT_AND_CONTINUE); + stream->executeMPSGraph(mpsGraph, feeds, results, SyncType::COMMIT); ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: MPSGraph execution completed successfully"); + // Release MPSGraph to prevent memory leak + [mpsGraph release]; + mpsGraph = nil; + + [queryData release]; + [keyData release]; + [valueData release]; + if (maskData) [maskData release]; + [outputData release]; + } @catch (NSException *exception) { ET_LOG(Error, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: NSException caught: %s - %s", [[exception name] UTF8String], [[exception reason] UTF8String]);