Metal backend: enable Parakeet #16562
base: main
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/16562
Note: Links to docs will display an error until the docs builds have been completed.
❌ 1 Cancelled Job, 1 Unrelated Failure as of commit 69398ff with merge base dbf3c37.
CANCELLED JOB: The following job was cancelled. Please retry.
UNSTABLE: The following job is marked as unstable, possibly due to flakiness on trunk.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Pull request overview
This PR enables support for the Parakeet model in the Metal backend by implementing the necessary operators, fixing existing issues, and adding export/lowering infrastructure. The changes build upon existing Metal backend functionality to support this specific ASR model.
Changes:
- Added Metal backend support to the Parakeet export script, with a custom linear decomposition to avoid addmm and handle tensor reinterpretation
- Implemented aoti_torch_mps_bmm_out for batched matrix multiplication and fixed grouped convolution input channel handling
- Enhanced the shim layer to support non-zero tensor storage offsets and new tensor handle creation
Reviewed changes
Copilot reviewed 9 out of 9 changed files in this pull request and generated 6 comments.
Show a summary per file
| File | Description |
|---|---|
| examples/models/parakeet/export_parakeet_tdt.py | Added Metal backend support with custom decomposition and dtype handling |
| examples/models/parakeet/README.md | Updated documentation for Metal export and runner usage |
| backends/apple/metal/runtime/shims/memory.cpp | Implemented storage offset support and new tensor handle creation |
| backends/apple/metal/runtime/shims/et_metal_ops.mm | Added bmm_out implementation and fixed grouped convolution |
| backends/apple/metal/runtime/shims/et_metal_ops.h | Added bmm_out function declaration |
| backends/apple/metal/runtime/shims/et_metal.mm | Added metal_buffer_nocopy function |
| backends/apple/metal/runtime/shims/et_metal.h | Added metal_buffer_nocopy declaration |
| backends/apple/metal/metal_backend.py | Updated supported fallback kernels list |
| backends/aoti/common_shims.cpp | Improved error messages for stubbed functions |
```cpp
if (self_strides[2] != 1 || self_strides[1] != K || self_strides[0] != M * K) {
  ET_LOG(Error, "aoti_torch_mps_bmm_out: self tensor must be contiguous. "
      "Only dense row-major layout supported; transposed/view tensors are unsupported. "
      "Expected strides=[%lld,%lld,1] for shape=[%lld,%lld,%lld], got strides=[%d,%d,%d].",
      (long long)(M * K), (long long)K, (long long)B, (long long)M, (long long)K,
      self_strides[0], self_strides[1], self_strides[2]);
```
Copilot AI, Jan 13, 2026
The format specifiers for stride values are incorrect. The strides are int64_t but are being printed with %d instead of %lld. This will cause truncation on 64-bit values and incorrect error messages. The same issue exists for mat2_strides and out_strides error messages.
```cpp
if (mat2_strides[2] != 1 || mat2_strides[1] != N || mat2_strides[0] != K * N) {
  ET_LOG(Error, "aoti_torch_mps_bmm_out: mat2 tensor must be contiguous. "
      "Only dense row-major layout supported; transposed/view tensors are unsupported. "
      "Expected strides=[%lld,%lld,1] for shape=[%lld,%lld,%lld], got strides=[%d,%d,%d].",
      (long long)(K * N), (long long)N, (long long)B, (long long)K, (long long)N,
      mat2_strides[0], mat2_strides[1], mat2_strides[2]);
```
Copilot AI, Jan 13, 2026
The format specifiers for stride values should be %lld instead of %d. Strides are int64_t values, not int32_t.
```cpp
if (out_strides[2] != 1 || out_strides[1] != N || out_strides[0] != M * N) {
  ET_LOG(Error, "aoti_torch_mps_bmm_out: out tensor must be contiguous. "
      "Only dense row-major layout supported; transposed/view tensors are unsupported. "
      "Expected strides=[%lld,%lld,1] for shape=[%lld,%lld,%lld], got strides=[%d,%d,%d].",
      (long long)(M * N), (long long)N, (long long)B, (long long)M, (long long)N,
      out_strides[0], out_strides[1], out_strides[2]);
```
Copilot AI, Jan 13, 2026
```diff
-        : memory_to_n_tensor[data_ptr] + 1;
+        : memory_to_n_tensor[adjusted_data] + 1;

   ET_LOG(Debug, "aoti_torch__reinterpret_tensor: successfull");
```
Copilot AI, Jan 13, 2026
Corrected spelling of 'successfull' to 'successful'.
```diff
-  ET_LOG(Debug, "aoti_torch__reinterpret_tensor: successfull");
+  ET_LOG(Debug, "aoti_torch__reinterpret_tensor: successful");
```
```cpp
      ? NOT_OWN
      : memory_to_n_tensor[data_ptr] + 1;

  ET_LOG(Debug, "aoti_torch_new_tensor_handle: successfull");
```
Copilot AI, Jan 13, 2026
Corrected spelling of 'successfull' to 'successful'.
```diff
-  ET_LOG(Debug, "aoti_torch_new_tensor_handle: successfull");
+  ET_LOG(Debug, "aoti_torch_new_tensor_handle: successful");
```
```cpp
      dtype_to_element_size(dtype),
      adjusted_data);

  metal_buffer_nocopy(adjusted_data, tensor->nbytes(), true);
```
Copilot AI, Jan 13, 2026
Use-after-free risk due to incorrect lifetime management of reinterpreted tensors with non-zero storage_offset.

In aoti_torch__reinterpret_tensor, adjusted_data is computed from the original tensor's data_ptr (line 471) and a new tensor view is created on that interior pointer (lines 483–488), but the reference counting and buffer mapping treat adjusted_data as an independent allocation: metal_buffer_nocopy wraps adjusted_data in a new MTLBuffer via newBufferWithBytesNoCopy and registers it in ptr_to_mtl_buffer, and memory_to_n_tensor[adjusted_data] is incremented (lines 512–517) rather than sharing the original data_ptr's entry.

When the original tensor is later deleted first, aoti_torch_delete_tensor_object will see data_ptr with refcount 1 and free the underlying memory (either by metal_deallocate_buffer for MPS buffers or free for CPU buffers), while the reinterpreted tensor still exists and its MTLBuffer (and mutable_data_ptr() pointing at adjusted_data) continue to be used by Metal ops via get_mtl_buffer / ETMetalKernelFunction::setArg.

This leaves adjusted_data and its MTLBuffer pointing at freed memory, so subsequent Metal graph executions or kernel launches can read/write out-of-bounds or freed memory, enabling memory corruption and potential process compromise when models or AOT runtimes use _reinterpret_tensor with non-zero storage offsets.
Severity: HIGH. Confidence: 8
```python
    return partitioner, programs


def _linear_bias_decomposition(input, weight, bias=None):
```
We naturally decompose this; if it's not being decomposed, it's because the AOTI backend probably specifies not to decompose it. So we should probably just disable that for Metal.
```objc
  @autoreleasepool {
    try {
      // Convert AOTITensorHandle to ExecutorTorch tensors
      auto out_tensor = reinterpret_cast<Tensor*>(out);
```
Nit: It would be useful to wrap this cast in a helper function since we plan on changing the tensor definition backing the shim.
```objc
    // Create cache key for this batched matrix multiplication
    // Cache key includes: op_name, shape params {B, M, K, N}, dtype, transpose_flag
    // This allows reuse when same BMM shape/dtype is called repeatedly
```
Does this happen often with transformers? I would think the shape mostly changes each iteration with growing context length
```objc
  (void)device; // Used for validation, consistent with other ops

  // Get Metal buffers for input and output tensors
  id<MTLBuffer> self_buffer = get_mtl_buffer(self_tensor, "aoti_torch_mps_bmm_out", "self");
```
Just sort of curious why do you pass the fn name and arg name, are these keys to a dict?
just for logging
```objc
    ET_LOG(Error, "aoti_torch_mps_bmm_out: Failed to get Metal device");
    return Error::Internal;
  }
  (void)device; // Used for validation, consistent with other ops
```
What's going on here?
```objc
  // Validate tensor dimensions - bmm requires 3-D tensors
  if (self_tensor->dim() != 3 || mat2_tensor->dim() != 3 || out_tensor->dim() != 3) {
    ET_LOG(Error, "aoti_torch_mps_bmm_out: tensors must be 3-D. "
```
If the failure is that the number of dims is wrong, why is it useful to print the shape? Especially since dims could be too big?
mergennachin left a comment
See inline comments
```shell
cmake --workflow --preset llm-release

# For Metal (macOS)
cmake --workflow --preset llm-debug-metal
```
Let's switch back to llm-release-metal in the instructions for consistency
```objc
                     length:nbytes
                    options:MTLResourceCPUCacheModeWriteCombined | MTLResourceStorageModeShared
                 deallocator:nil];
```
Check subBuffer for nullity/falsiness.
```diff
   ET_LOG(
       Debug,
       "aoti_torch__reinterpret_tensor: Adjusted original_data=%p, storage_offset=%lld, element_size=%zu, adjusted_data=%p",
       data_ptr,
       storage_offset,
       dtype_to_element_size(dtype),
       adjusted_data);

   metal_buffer_nocopy(adjusted_data, tensor->nbytes(), true);
 }

   // Increment the reference count for this memory address only if it is owned
   // by tensor
-  memory_to_n_tensor[data_ptr] = memory_to_n_tensor[data_ptr] == NOT_OWN
+  memory_to_n_tensor[adjusted_data] =
+      memory_to_n_tensor[adjusted_data] == NOT_OWN
       ? NOT_OWN
-      : memory_to_n_tensor[data_ptr] + 1;
+      : memory_to_n_tensor[adjusted_data] + 1;
```
The Copilot review looks legit:

```cpp
if (adjusted_data != data_ptr) {
  ET_LOG(Debug, "...");
  metal_buffer_nocopy(adjusted_data, tensor->nbytes(), true);
  memory_to_n_tensor[adjusted_data] = NOT_OWN;
}
if (memory_to_n_tensor[data_ptr] != NOT_OWN) {
  memory_to_n_tensor[data_ptr] += 1;
}
```
```diff
-  (void)new_handle;
-  throw std::runtime_error("Not implemented");
-  return Error::Internal;
+  ET_LOG(Debug, "aoti_torch_new_tensor_handle: entered");
```
Should you handle both zero and non-zero offset in this method?
What do you mean? aoti_torch_new_tensor_handle doesn't take an offset.
This pull request builds on top of #16499 to introduce support for the Parakeet model in the Metal backend. The most important changes are grouped below:
Parakeet export/lowering:

Operator updates:
- aoti_torch_mps_bmm_out to support batched matrix multiplication (bmm) in the Metal backend
- aoti_torch_mps_convolution by reading the correct dimension from the weight tensor

Shim layer updates:
- aoti_torch_new_tensor_handle
- aoti_torch__reinterpret_tensor by adjusting the data pointer instead of rejecting non-zero offsets, and updating memory tracking and Metal buffer mapping logic accordingly
- metal_buffer_nocopy function to map arbitrary memory pointers into Metal buffers, supporting cases where the data pointer is offset
aoti_torch_new_tensor_handleaoti_torch__reinterpret_tensorby adjusting the data pointer instead of rejecting non-zero offsets, and updating memory tracking and Metal buffer mapping logic accordingly.metal_buffer_nocopyfunction to map arbitrary memory pointers into Metal buffers, supporting cases where the data pointer is offset.