14 changes: 9 additions & 5 deletions backends/apple/metal/runtime/shims/et_metal.mm
@@ -680,12 +680,16 @@ int metal_copy_memory(void* dst, const void* src, size_t nbytes, bool src_is_dev
 
 // Commit methods
 void ETMetalStream::commit() {
-  if (enableCommitAndContinue_ && commandBuffer_) {
-    // Use commit-and-continue for better performance
-    commitAndContinue();
-  } else {
-    flush();
+  if (!commandBuffer_) {
+    ET_LOG(Error, "ETMetalStream::commit: No command buffer to commit");
+    return;
   }
+
+  [commandBuffer_ commit];
+  ET_LOG(Debug, "ETMetalStream::commit: Committed buffer %p", commandBuffer_);
+
+  [commandBuffer_ release];
+  commandBuffer_ = nil;
 }
 
 void ETMetalStream::commitAndWait() {
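For context on the pattern the new `commit()` follows: the explicit `[commandBuffer_ release]` suggests these shims are built without ARC, so every command buffer the stream retains must be balanced by a release once it has been committed. A minimal sketch of that lifecycle against the plain Metal API (the `queue` and `buffer` names here are illustrative only, not part of this PR):

```objc
// Illustrative sketch only (not from this PR): commit-then-release lifecycle of a
// Metal command buffer under manual reference counting, mirroring the new commit().
#import <Metal/Metal.h>

static void commit_and_release(id<MTLCommandQueue> queue) {
  // Lazily create a command buffer and take a strong reference (+1 under MRC).
  id<MTLCommandBuffer> buffer = [[queue commandBuffer] retain];

  // ... encode work into `buffer` here ...

  // A command buffer may be committed exactly once and cannot be reused afterwards.
  [buffer commit];

  // Optional: block until the GPU finishes, as commitAndWait() does.
  [buffer waitUntilCompleted];

  // Balance the retain so the committed buffer can be deallocated.
  [buffer release];
  buffer = nil;
}
```

Dropping commit-and-continue means each buffer is submitted exactly once and then discarded, which is what makes the unconditional release at the end of `commit()` safe.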
57 changes: 40 additions & 17 deletions backends/apple/metal/runtime/shims/et_metal_ops.mm
@@ -270,7 +270,7 @@ AOTITorchError aoti_torch_mps_mm_out(
 
     @try {
       // Use stream helper to encode and synchronize correctly
-      stream->executeMPSGraph(mpsGraph, feeds, results, SyncType::COMMIT_AND_CONTINUE);
+      stream->executeMPSGraph(mpsGraph, feeds, results, SyncType::COMMIT);
     } @catch (NSException *exception) {
       ET_LOG(Error, "aoti_torch_mps_mm_out: NSException caught during executeMPSGraph: %s - %s",
              [[exception name] UTF8String], [[exception reason] UTF8String]);
@@ -279,6 +279,14 @@ AOTITorchError aoti_torch_mps_mm_out(
 
     ET_LOG(Debug, "aoti_torch_mps_mm_out: MPSGraph execution completed successfully");
 
+    // Release MPSGraph to prevent memory leak
+    [mpsGraph release];
+    mpsGraph = nil;
+
+    [selfData release];
+    [mat2Data release];
+    [outputData release];
+
     ET_LOG(Debug, "aoti_torch_mps_mm_out: Executed successfully");
     return Error::Ok;
 
@@ -616,14 +624,16 @@ AOTITorchError aoti_torch_mps_convolution(
     feeds[inputPlaceholder] = inputData;
     feeds[weightPlaceholder] = weightData;
 
+    MPSGraphTensorData* biasData = nil;
+
     // Add bias data to feeds if provided
     if (bias_tensor && biasPlaceholder) {
       id<MTLBuffer> bias_buffer = get_mtl_buffer(bias_tensor, "aoti_torch_mps_convolution", "bias");
 
       NSArray<NSNumber*>* biasShape = @[@(C_out)];
-      MPSGraphTensorData* biasData = [[MPSGraphTensorData alloc] initWithMTLBuffer:bias_buffer
-                                                                              shape:biasShape
-                                                                           dataType:mps_dtype];
+      biasData = [[MPSGraphTensorData alloc] initWithMTLBuffer:bias_buffer
+                                                          shape:biasShape
+                                                       dataType:mps_dtype];
 
       feeds[biasPlaceholder] = biasData;
       ET_LOG(Debug, "aoti_torch_mps_convolution: Added bias tensor to feeds");
@@ -650,7 +660,7 @@ AOTITorchError aoti_torch_mps_convolution(
 
     @try {
       // Use stream helper to encode and synchronize correctly
-      stream->executeMPSGraph(mpsGraph, feeds, results, SyncType::COMMIT_AND_CONTINUE);
+      stream->executeMPSGraph(mpsGraph, feeds, results, SyncType::COMMIT);
     } @catch (NSException *exception) {
       ET_LOG(Error, "aoti_torch_mps_convolution: NSException caught during executeMPSGraph: %s - %s",
              [[exception name] UTF8String], [[exception reason] UTF8String]);
@@ -743,6 +753,15 @@ AOTITorchError aoti_torch_mps_convolution(
     extern std::unordered_map<void*, int32_t> memory_to_n_tensor;
     memory_to_n_tensor[tensor_data] = 1;
 
+    // Release MPSGraph to prevent memory leak
+    [mpsGraph release];
+    mpsGraph = nil;
+
+    [inputData release];
+    [weightData release];
+    if (biasData) [biasData release];
+    [outputData release];
+
     ET_LOG(Debug, "aoti_torch_mps_convolution: Created output tensor with %zu elements using MPSGraph", actual_numel);
 
     ET_LOG(Debug, "aoti_torch_mps_convolution: Executed successfully");
@@ -992,14 +1011,6 @@ AOTITorchError aoti_torch_mps__scaled_dot_product_attention_math_for_mps(
   ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Implementing using MPSGraph scaledDotProductAttention");
 
   @try {
-    // Check if scaledDotProductAttentionWithQueryTensor is available
-    MPSGraph* testGraph = [MPSGraph new];
-    if (![testGraph respondsToSelector:@selector(scaledDotProductAttentionWithQueryTensor:keyTensor:valueTensor:maskTensor:scale:name:)]) {
-      ET_LOG(Error, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: scaledDotProductAttentionWithQueryTensor API not available on this system");
-      throw std::runtime_error("scaledDotProductAttentionWithQueryTensor API not available on this system");
-    }
-    ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: scaledDotProductAttentionWithQueryTensor API is available");
-
     // Create MPSGraph for scaled dot product attention
     MPSGraph* mpsGraph = [MPSGraph new];
     ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Created MPSGraph instance");
@@ -1246,6 +1257,8 @@ AOTITorchError aoti_torch_mps__scaled_dot_product_attention_math_for_mps(
     feeds[valuePlaceholder] = valueData;
     ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Added input tensors to feeds");
 
+    MPSGraphTensorData* maskData = nil;
+
     // Add explicit mask data to feeds if provided
     if (explicitMaskPlaceholder && attn_mask && *attn_mask) {
       auto* mask_tensor = reinterpret_cast<Tensor*>(*attn_mask);
@@ -1257,9 +1270,9 @@ AOTITorchError aoti_torch_mps__scaled_dot_product_attention_math_for_mps(
         [maskShapeArray addObject:@(mask_tensor->sizes()[i])];
       }
 
-      MPSGraphTensorData* maskData = [[MPSGraphTensorData alloc] initWithMTLBuffer:mask_buffer
-                                                                              shape:maskShapeArray
-                                                                           dataType:mps_dtype];
+      maskData = [[MPSGraphTensorData alloc] initWithMTLBuffer:mask_buffer
+                                                          shape:maskShapeArray
+                                                       dataType:mps_dtype];
       feeds[explicitMaskPlaceholder] = maskData;
       ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Added explicit mask tensor to feeds");
     }
@@ -1275,9 +1288,19 @@
 
     // Execute via shared stream and keep results on GPU
     ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Executing MPSGraph using stream");
-    stream->executeMPSGraph(mpsGraph, feeds, results, SyncType::COMMIT_AND_CONTINUE);
+    stream->executeMPSGraph(mpsGraph, feeds, results, SyncType::COMMIT);
    ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: MPSGraph execution completed successfully");
 
+    // Release MPSGraph to prevent memory leak
+    [mpsGraph release];
+    mpsGraph = nil;
+
+    [queryData release];
+    [keyData release];
+    [valueData release];
+    if (maskData) [maskData release];
+    [outputData release];
+
   } @catch (NSException *exception) {
     ET_LOG(Error, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: NSException caught: %s - %s",
            [[exception name] UTF8String], [[exception reason] UTF8String]);
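The release calls added across these ops all follow the same alloc → feed → execute → release shape for `MPSGraphTensorData`. A minimal sketch of that shape under manual reference counting (the `placeholder` and `buffer` names, the shape, and the elided execution step are illustrative only, not part of this PR):

```objc
// Illustrative sketch only (not from this PR): alloc -> feed -> execute -> release
// for MPSGraphTensorData under manual reference counting. The graph execution step
// is elided; in the backend it goes through the stream helper shown in the diff above.
#import <MetalPerformanceShadersGraph/MetalPerformanceShadersGraph.h>

static void feed_and_release(MPSGraphTensor* placeholder, id<MTLBuffer> buffer) {
  // Wrap an existing MTLBuffer without copying; alloc/init returns a +1 reference.
  MPSGraphTensorData* data =
      [[MPSGraphTensorData alloc] initWithMTLBuffer:buffer
                                              shape:@[ @1, @8 ]
                                           dataType:MPSDataTypeFloat32];

  NSMutableDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds =
      [NSMutableDictionary dictionary];
  feeds[placeholder] = data;

  // ... run the MPSGraph with `feeds` and hold on to the results as needed ...

  // Once the work has been committed, balance the alloc so the wrapper does not leak.
  [data release];
  data = nil;
}
```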