Skip to content

Commit 136f908

Browse files
Update
[ghstack-poisoned]
1 parent 00ba522 commit 136f908

File tree

2 files changed

+49
-22
lines changed

2 files changed

+49
-22
lines changed

backends/apple/metal/runtime/shims/et_metal.mm

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -680,12 +680,16 @@ int metal_copy_memory(void* dst, const void* src, size_t nbytes, bool src_is_dev
680680

681681
// Commit methods
682682
void ETMetalStream::commit() {
683-
if (enableCommitAndContinue_ && commandBuffer_) {
684-
// Use commit-and-continue for better performance
685-
commitAndContinue();
686-
} else {
687-
flush();
683+
if (!commandBuffer_) {
684+
ET_LOG(Error, "ETMetalStream::commit: No command buffer to commit");
685+
return;
688686
}
687+
688+
[commandBuffer_ commit];
689+
ET_LOG(Debug, "ETMetalStream::commit: Committed buffer %p", commandBuffer_);
690+
691+
[commandBuffer_ release];
692+
commandBuffer_ = nil;
689693
}
690694

691695
void ETMetalStream::commitAndWait() {

backends/apple/metal/runtime/shims/et_metal_ops.mm

Lines changed: 40 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -270,7 +270,7 @@ AOTITorchError aoti_torch_mps_mm_out(
270270

271271
@try {
272272
// Use stream helper to encode and synchronize correctly
273-
stream->executeMPSGraph(mpsGraph, feeds, results, SyncType::COMMIT_AND_CONTINUE);
273+
stream->executeMPSGraph(mpsGraph, feeds, results, SyncType::COMMIT);
274274
} @catch (NSException *exception) {
275275
ET_LOG(Error, "aoti_torch_mps_mm_out: NSException caught during executeMPSGraph: %s - %s",
276276
[[exception name] UTF8String], [[exception reason] UTF8String]);
@@ -279,6 +279,14 @@ AOTITorchError aoti_torch_mps_mm_out(
279279

280280
ET_LOG(Debug, "aoti_torch_mps_mm_out: MPSGraph execution completed successfully");
281281

282+
// Release MPSGraph to prevent memory leak
283+
[mpsGraph release];
284+
mpsGraph = nil;
285+
286+
[selfData release];
287+
[mat2Data release];
288+
[outputData release];
289+
282290
ET_LOG(Debug, "aoti_torch_mps_mm_out: Executed successfully");
283291
return Error::Ok;
284292

@@ -616,14 +624,16 @@ AOTITorchError aoti_torch_mps_convolution(
616624
feeds[inputPlaceholder] = inputData;
617625
feeds[weightPlaceholder] = weightData;
618626

627+
MPSGraphTensorData* biasData = nil;
628+
619629
// Add bias data to feeds if provided
620630
if (bias_tensor && biasPlaceholder) {
621631
id<MTLBuffer> bias_buffer = get_mtl_buffer(bias_tensor, "aoti_torch_mps_convolution", "bias");
622632

623633
NSArray<NSNumber*>* biasShape = @[@(C_out)];
624-
MPSGraphTensorData* biasData = [[MPSGraphTensorData alloc] initWithMTLBuffer:bias_buffer
625-
shape:biasShape
626-
dataType:mps_dtype];
634+
biasData = [[MPSGraphTensorData alloc] initWithMTLBuffer:bias_buffer
635+
shape:biasShape
636+
dataType:mps_dtype];
627637

628638
feeds[biasPlaceholder] = biasData;
629639
ET_LOG(Debug, "aoti_torch_mps_convolution: Added bias tensor to feeds");
@@ -650,7 +660,7 @@ AOTITorchError aoti_torch_mps_convolution(
650660

651661
@try {
652662
// Use stream helper to encode and synchronize correctly
653-
stream->executeMPSGraph(mpsGraph, feeds, results, SyncType::COMMIT_AND_CONTINUE);
663+
stream->executeMPSGraph(mpsGraph, feeds, results, SyncType::COMMIT);
654664
} @catch (NSException *exception) {
655665
ET_LOG(Error, "aoti_torch_mps_convolution: NSException caught during executeMPSGraph: %s - %s",
656666
[[exception name] UTF8String], [[exception reason] UTF8String]);
@@ -743,6 +753,15 @@ AOTITorchError aoti_torch_mps_convolution(
743753
extern std::unordered_map<void*, int32_t> memory_to_n_tensor;
744754
memory_to_n_tensor[tensor_data] = 1;
745755

756+
// Release MPSGraph to prevent memory leak
757+
[mpsGraph release];
758+
mpsGraph = nil;
759+
760+
[inputData release];
761+
[weightData release];
762+
if (biasData) [biasData release];
763+
[outputData release];
764+
746765
ET_LOG(Debug, "aoti_torch_mps_convolution: Created output tensor with %zu elements using MPSGraph", actual_numel);
747766

748767
ET_LOG(Debug, "aoti_torch_mps_convolution: Executed successfully");
@@ -992,14 +1011,6 @@ AOTITorchError aoti_torch_mps__scaled_dot_product_attention_math_for_mps(
9921011
ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Implementing using MPSGraph scaledDotProductAttention");
9931012

9941013
@try {
995-
// Check if scaledDotProductAttentionWithQueryTensor is available
996-
MPSGraph* testGraph = [MPSGraph new];
997-
if (![testGraph respondsToSelector:@selector(scaledDotProductAttentionWithQueryTensor:keyTensor:valueTensor:maskTensor:scale:name:)]) {
998-
ET_LOG(Error, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: scaledDotProductAttentionWithQueryTensor API not available on this system");
999-
throw std::runtime_error("scaledDotProductAttentionWithQueryTensor API not available on this system");
1000-
}
1001-
ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: scaledDotProductAttentionWithQueryTensor API is available");
1002-
10031014
// Create MPSGraph for scaled dot product attention
10041015
MPSGraph* mpsGraph = [MPSGraph new];
10051016
ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Created MPSGraph instance");
@@ -1246,6 +1257,8 @@ AOTITorchError aoti_torch_mps__scaled_dot_product_attention_math_for_mps(
12461257
feeds[valuePlaceholder] = valueData;
12471258
ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Added input tensors to feeds");
12481259

1260+
MPSGraphTensorData* maskData = nil;
1261+
12491262
// Add explicit mask data to feeds if provided
12501263
if (explicitMaskPlaceholder && attn_mask && *attn_mask) {
12511264
auto* mask_tensor = reinterpret_cast<Tensor*>(*attn_mask);
@@ -1257,9 +1270,9 @@ AOTITorchError aoti_torch_mps__scaled_dot_product_attention_math_for_mps(
12571270
[maskShapeArray addObject:@(mask_tensor->sizes()[i])];
12581271
}
12591272

1260-
MPSGraphTensorData* maskData = [[MPSGraphTensorData alloc] initWithMTLBuffer:mask_buffer
1261-
shape:maskShapeArray
1262-
dataType:mps_dtype];
1273+
maskData = [[MPSGraphTensorData alloc] initWithMTLBuffer:mask_buffer
1274+
shape:maskShapeArray
1275+
dataType:mps_dtype];
12631276
feeds[explicitMaskPlaceholder] = maskData;
12641277
ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Added explicit mask tensor to feeds");
12651278
}
@@ -1275,9 +1288,19 @@ AOTITorchError aoti_torch_mps__scaled_dot_product_attention_math_for_mps(
12751288

12761289
// Execute via shared stream and keep results on GPU
12771290
ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Executing MPSGraph using stream");
1278-
stream->executeMPSGraph(mpsGraph, feeds, results, SyncType::COMMIT_AND_CONTINUE);
1291+
stream->executeMPSGraph(mpsGraph, feeds, results, SyncType::COMMIT);
12791292
ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: MPSGraph execution completed successfully");
12801293

1294+
// Release MPSGraph to prevent memory leak
1295+
[mpsGraph release];
1296+
mpsGraph = nil;
1297+
1298+
[queryData release];
1299+
[keyData release];
1300+
[valueData release];
1301+
if (maskData) [maskData release];
1302+
[outputData release];
1303+
12811304
} @catch (NSException *exception) {
12821305
ET_LOG(Error, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: NSException caught: %s - %s",
12831306
[[exception name] UTF8String], [[exception reason] UTF8String]);

0 commit comments

Comments
 (0)