Commit 526eb18

Metal backend: eliminate memory leak (#15343)
This pull request refactors the Metal backend's stream commit logic and improves memory management for MPSGraph operations. The most significant changes are the rewrite of the "commit" SyncType to always release the command buffer after commit, and the addition of explicit resource releases to prevent memory leaks after MPSGraph execution.

**Stream commit logic changes:**

* The `ETMetalStream::commit` method now always commits and releases the command buffer directly, removing support for the "commit and continue" optimization and adding error logging if no command buffer is present.
* All MPSGraph execution calls in the ops shim now use `SyncType::COMMIT` instead of `SyncType::COMMIT_AND_CONTINUE`.

**Memory management improvements:**

* After MPSGraph execution in all ops (`mm_out`, `convolution`, and `scaled_dot_product_attention_math_for_mps`), the code now explicitly releases the MPSGraph and all associated `MPSGraphTensorData` objects to prevent memory leaks.

**Code simplification and clarity:**

* The unused API availability check for `scaledDotProductAttentionWithQueryTensor` was removed, streamlining the attention math implementation.
* Tensor data variables (`biasData`, `maskData`) are now declared outside conditional blocks for proper release and scope management.
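For context, here is a minimal sketch of the ownership pattern the ops now follow under manual retain/release. It is not ExecuTorch source: `run_and_release_example`, the buffers, and the shapes are illustrative placeholders, and the backend-internal stream call is shown only as a comment; the MPSGraph/MPSGraphTensorData calls are plain MetalPerformanceShadersGraph API.

```objectivec
// Illustrative sketch only (assumed names, not the backend's code).
// Objects created with `new` / alloc-init carry a +1 reference and must be
// released once the work has been encoded/committed.
#import <Metal/Metal.h>
#import <MetalPerformanceShadersGraph/MetalPerformanceShadersGraph.h>

void run_and_release_example(id<MTLBuffer> inBuf, id<MTLBuffer> biasBuf /* may be nil */) {
  MPSGraph* graph = [MPSGraph new];  // +1, owned here

  MPSGraphTensorData* inputData =
      [[MPSGraphTensorData alloc] initWithMTLBuffer:inBuf
                                               shape:@[ @1, @8 ]
                                            dataType:MPSDataTypeFloat32];  // +1

  // Declared outside the conditional so it can be released unconditionally,
  // mirroring the biasData/maskData change in this commit.
  MPSGraphTensorData* biasData = nil;
  if (biasBuf) {
    biasData = [[MPSGraphTensorData alloc] initWithMTLBuffer:biasBuf
                                                        shape:@[ @8 ]
                                                     dataType:MPSDataTypeFloat32];
  }

  // ... build the graph, assemble feeds/results, then execute via the
  // backend's stream helper, e.g.:
  // stream->executeMPSGraph(graph, feeds, results, SyncType::COMMIT);

  // Balance every +1 once execution has been committed.
  [graph release];
  [inputData release];
  if (biasData) [biasData release];
}
```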
1 parent 81a3acc commit 526eb18

File tree

2 files changed: +49 additions, -22 deletions

backends/apple/metal/runtime/shims/et_metal.mm

Lines changed: 9 additions & 5 deletions
@@ -680,12 +680,16 @@ int metal_copy_memory(void* dst, const void* src, size_t nbytes, bool src_is_dev
 
 // Commit methods
 void ETMetalStream::commit() {
-  if (enableCommitAndContinue_ && commandBuffer_) {
-    // Use commit-and-continue for better performance
-    commitAndContinue();
-  } else {
-    flush();
+  if (!commandBuffer_) {
+    ET_LOG(Error, "ETMetalStream::commit: No command buffer to commit");
+    return;
   }
+
+  [commandBuffer_ commit];
+  ET_LOG(Debug, "ETMetalStream::commit: Committed buffer %p", commandBuffer_);
+
+  [commandBuffer_ release];
+  commandBuffer_ = nil;
 }
 
 void ETMetalStream::commitAndWait() {
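As background for the rewritten commit path, the following is a small, self-contained sketch in plain Metal API (not backend code; the device, queue, and buffer names are placeholders) of why a command buffer can be released immediately after commit: once committed, Metal keeps the buffer alive until the GPU has finished executing it.

```objectivec
// Self-contained illustration under manual retain/release; not part of the
// backend. Releasing our reference right after -commit is safe because the
// committed command buffer is retained by Metal until execution completes.
#import <Metal/Metal.h>

void commit_then_release_example(void) {
  id<MTLDevice> device = MTLCreateSystemDefaultDevice();         // +1 (Create rule)
  id<MTLCommandQueue> queue = [device newCommandQueue];          // +1 (new)
  id<MTLCommandBuffer> cmdBuf = [[queue commandBuffer] retain];  // take ownership

  // ... encode blit/compute work into cmdBuf here ...

  [cmdBuf commit];   // schedule the GPU work
  [cmdBuf release];  // mirrors the new commit(): release right after commit
  cmdBuf = nil;

  [queue release];
  [device release];
}
```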

backends/apple/metal/runtime/shims/et_metal_ops.mm

Lines changed: 40 additions & 17 deletions
@@ -270,7 +270,7 @@ AOTITorchError aoti_torch_mps_mm_out(
 
   @try {
     // Use stream helper to encode and synchronize correctly
-    stream->executeMPSGraph(mpsGraph, feeds, results, SyncType::COMMIT_AND_CONTINUE);
+    stream->executeMPSGraph(mpsGraph, feeds, results, SyncType::COMMIT);
   } @catch (NSException *exception) {
     ET_LOG(Error, "aoti_torch_mps_mm_out: NSException caught during executeMPSGraph: %s - %s",
            [[exception name] UTF8String], [[exception reason] UTF8String]);
@@ -279,6 +279,14 @@ AOTITorchError aoti_torch_mps_mm_out(
 
   ET_LOG(Debug, "aoti_torch_mps_mm_out: MPSGraph execution completed successfully");
 
+  // Release MPSGraph to prevent memory leak
+  [mpsGraph release];
+  mpsGraph = nil;
+
+  [selfData release];
+  [mat2Data release];
+  [outputData release];
+
   ET_LOG(Debug, "aoti_torch_mps_mm_out: Executed successfully");
   return Error::Ok;
 
@@ -616,14 +624,16 @@ AOTITorchError aoti_torch_mps_convolution(
   feeds[inputPlaceholder] = inputData;
   feeds[weightPlaceholder] = weightData;
 
+  MPSGraphTensorData* biasData = nil;
+
   // Add bias data to feeds if provided
   if (bias_tensor && biasPlaceholder) {
     id<MTLBuffer> bias_buffer = get_mtl_buffer(bias_tensor, "aoti_torch_mps_convolution", "bias");
 
     NSArray<NSNumber*>* biasShape = @[@(C_out)];
-    MPSGraphTensorData* biasData = [[MPSGraphTensorData alloc] initWithMTLBuffer:bias_buffer
-                                                                            shape:biasShape
-                                                                         dataType:mps_dtype];
+    biasData = [[MPSGraphTensorData alloc] initWithMTLBuffer:bias_buffer
+                                                        shape:biasShape
+                                                     dataType:mps_dtype];
 
     feeds[biasPlaceholder] = biasData;
     ET_LOG(Debug, "aoti_torch_mps_convolution: Added bias tensor to feeds");
@@ -650,7 +660,7 @@ AOTITorchError aoti_torch_mps_convolution(
 
   @try {
     // Use stream helper to encode and synchronize correctly
-    stream->executeMPSGraph(mpsGraph, feeds, results, SyncType::COMMIT_AND_CONTINUE);
+    stream->executeMPSGraph(mpsGraph, feeds, results, SyncType::COMMIT);
   } @catch (NSException *exception) {
     ET_LOG(Error, "aoti_torch_mps_convolution: NSException caught during executeMPSGraph: %s - %s",
            [[exception name] UTF8String], [[exception reason] UTF8String]);
@@ -743,6 +753,15 @@ AOTITorchError aoti_torch_mps_convolution(
   extern std::unordered_map<void*, int32_t> memory_to_n_tensor;
   memory_to_n_tensor[tensor_data] = 1;
 
+  // Release MPSGraph to prevent memory leak
+  [mpsGraph release];
+  mpsGraph = nil;
+
+  [inputData release];
+  [weightData release];
+  if (biasData) [biasData release];
+  [outputData release];
+
   ET_LOG(Debug, "aoti_torch_mps_convolution: Created output tensor with %zu elements using MPSGraph", actual_numel);
 
   ET_LOG(Debug, "aoti_torch_mps_convolution: Executed successfully");
@@ -992,14 +1011,6 @@ AOTITorchError aoti_torch_mps__scaled_dot_product_attention_math_for_mps(
   ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Implementing using MPSGraph scaledDotProductAttention");
 
   @try {
-    // Check if scaledDotProductAttentionWithQueryTensor is available
-    MPSGraph* testGraph = [MPSGraph new];
-    if (![testGraph respondsToSelector:@selector(scaledDotProductAttentionWithQueryTensor:keyTensor:valueTensor:maskTensor:scale:name:)]) {
-      ET_LOG(Error, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: scaledDotProductAttentionWithQueryTensor API not available on this system");
-      throw std::runtime_error("scaledDotProductAttentionWithQueryTensor API not available on this system");
-    }
-    ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: scaledDotProductAttentionWithQueryTensor API is available");
-
     // Create MPSGraph for scaled dot product attention
     MPSGraph* mpsGraph = [MPSGraph new];
     ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Created MPSGraph instance");
@@ -1246,6 +1257,8 @@ AOTITorchError aoti_torch_mps__scaled_dot_product_attention_math_for_mps(
     feeds[valuePlaceholder] = valueData;
     ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Added input tensors to feeds");
 
+    MPSGraphTensorData* maskData = nil;
+
     // Add explicit mask data to feeds if provided
     if (explicitMaskPlaceholder && attn_mask && *attn_mask) {
       auto* mask_tensor = reinterpret_cast<Tensor*>(*attn_mask);
@@ -1257,9 +1270,9 @@ AOTITorchError aoti_torch_mps__scaled_dot_product_attention_math_for_mps(
         [maskShapeArray addObject:@(mask_tensor->sizes()[i])];
       }
 
-      MPSGraphTensorData* maskData = [[MPSGraphTensorData alloc] initWithMTLBuffer:mask_buffer
-                                                                              shape:maskShapeArray
-                                                                           dataType:mps_dtype];
+      maskData = [[MPSGraphTensorData alloc] initWithMTLBuffer:mask_buffer
+                                                          shape:maskShapeArray
+                                                       dataType:mps_dtype];
       feeds[explicitMaskPlaceholder] = maskData;
       ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Added explicit mask tensor to feeds");
     }
@@ -1275,9 +1288,19 @@ AOTITorchError aoti_torch_mps__scaled_dot_product_attention_math_for_mps(
 
     // Execute via shared stream and keep results on GPU
     ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Executing MPSGraph using stream");
-    stream->executeMPSGraph(mpsGraph, feeds, results, SyncType::COMMIT_AND_CONTINUE);
+    stream->executeMPSGraph(mpsGraph, feeds, results, SyncType::COMMIT);
     ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: MPSGraph execution completed successfully");
 
+    // Release MPSGraph to prevent memory leak
+    [mpsGraph release];
+    mpsGraph = nil;
+
+    [queryData release];
+    [keyData release];
+    [valueData release];
+    if (maskData) [maskData release];
+    [outputData release];
+
   } @catch (NSException *exception) {
     ET_LOG(Error, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: NSException caught: %s - %s",
            [[exception name] UTF8String], [[exception reason] UTF8String]);
