diff --git a/backends/aoti/common_shims.cpp b/backends/aoti/common_shims.cpp
index abfde86db6d..7c88e4cfb5b 100644
--- a/backends/aoti/common_shims.cpp
+++ b/backends/aoti/common_shims.cpp
@@ -218,7 +218,7 @@ AOTI_SHIM_EXPORT AOTITorchError
 aoti_torch_get_storage_size(Tensor* tensor, int64_t* ret_size) {
   (void)tensor;
   (void)ret_size;
-  throw std::runtime_error("Not implemented");
+  throw std::runtime_error("Not implemented: aoti_torch_get_storage_size");
   return Error::Internal;
 }
 
@@ -226,7 +226,8 @@ AOTI_SHIM_EXPORT AOTITorchError
 aoti_torch_clone_preserve_strides(Tensor* self, Tensor** ret_new_tensor) {
   (void)self;
   (void)ret_new_tensor;
-  throw std::runtime_error("Not implemented");
+  throw std::runtime_error(
+      "Not implemented: aoti_torch_clone_preserve_strides");
   return Error::Internal;
 }
 
@@ -234,7 +235,7 @@ AOTI_SHIM_EXPORT AOTITorchError
 aoti_torch_clone(Tensor* self, Tensor** ret_new_tensor) {
   (void)self;
   (void)ret_new_tensor;
-  throw std::runtime_error("Not implemented");
+  throw std::runtime_error("Not implemented: aoti_torch_clone");
   return Error::Internal;
 }
 
@@ -257,7 +258,8 @@ AOTI_SHIM_EXPORT AOTITorchError aoti_torch_create_tensor_from_blob(
   (void)device_type;
   (void)device_index;
   (void)ret_new_tensor;
-  throw std::runtime_error("Not implemented");
+  throw std::runtime_error(
+      "Not implemented: aoti_torch_create_tensor_from_blob");
   return Error::Internal;
 }
 
diff --git a/backends/apple/metal/metal_backend.py b/backends/apple/metal/metal_backend.py
index 1d86cfb8447..fde0410cca3 100644
--- a/backends/apple/metal/metal_backend.py
+++ b/backends/apple/metal/metal_backend.py
@@ -31,7 +31,7 @@ def get_device_name(cls) -> str:
     @classmethod
     def get_supported_fallback_kernels(cls) -> Dict[str, Any]:
         return {
-            "aoti_torch_mps_addmm_out": None,
+            "aoti_torch_mps_bmm_out": None,
             "aoti_torch_mps_convolution": None,
             "aoti_torch_mps_mm_out": None,
             "at::_ops::_scaled_dot_product_attention_math_for_mps::call": None,
diff --git a/backends/apple/metal/runtime/shims/et_metal.h b/backends/apple/metal/runtime/shims/et_metal.h
index 1c61499b242..e4d71fed72e 100644
--- a/backends/apple/metal/runtime/shims/et_metal.h
+++ b/backends/apple/metal/runtime/shims/et_metal.h
@@ -379,6 +379,7 @@ int metal_copy_memory(
     bool src_is_device,
     bool dst_is_device);
 void metal_cleanup_resources();
+void metal_buffer_nocopy(void* ptr, size_t nbytes, bool map_ptr_to_buffer);
 
 // Helper functions to access Metal objects
 MTLDevice_t get_metal_device();
diff --git a/backends/apple/metal/runtime/shims/et_metal.mm b/backends/apple/metal/runtime/shims/et_metal.mm
index f7d37c152ce..4f4464a534c 100644
--- a/backends/apple/metal/runtime/shims/et_metal.mm
+++ b/backends/apple/metal/runtime/shims/et_metal.mm
@@ -113,6 +113,18 @@ void metal_cleanup_resources() {
   }
 }
 
+void metal_buffer_nocopy(void* ptr, size_t nbytes, bool map_ptr_to_buffer) {
+  id<MTLDevice> device = get_metal_device();
+  id<MTLBuffer> subBuffer = [device newBufferWithBytesNoCopy:ptr
+                                                      length:nbytes
+                                                     options:MTLResourceCPUCacheModeWriteCombined | MTLResourceStorageModeShared
+                                                 deallocator:nil];
+
+  if (map_ptr_to_buffer) {
+    ptr_to_mtl_buffer[ptr] = subBuffer; // Map contents to buffer
+  }
+}
+
 bool metal_is_device_pointer(void* ptr) {
   return ptr_to_mtl_buffer.find(ptr) != ptr_to_mtl_buffer.end();
 }
diff --git a/backends/apple/metal/runtime/shims/et_metal_ops.h b/backends/apple/metal/runtime/shims/et_metal_ops.h
index 78bdb419ea4..fcc6dfc03da 100644
--- a/backends/apple/metal/runtime/shims/et_metal_ops.h
+++ b/backends/apple/metal/runtime/shims/et_metal_ops.h
@@ -27,6 +27,16 @@ AOTITorchError aoti_torch_mps_mm_out(
     AOTITensorHandle self,
     AOTITensorHandle mat2);
 
+/**
+ * ExecutorTorch implementation of aoti_torch_mps_bmm_out.
+ * Performs batched matrix multiplication: out = self @ mat2
+ * All tensors must be 3-D with matching batch dimensions.
+ */
+AOTITorchError aoti_torch_mps_bmm_out(
+    AOTITensorHandle out,
+    AOTITensorHandle self,
+    AOTITensorHandle mat2);
+
 /**
  * ExecutorTorch implementation of aoti_torch_mps_convolution.
  * Performs 2D convolution operation - matches PyTorch AOTI signature
diff --git a/backends/apple/metal/runtime/shims/et_metal_ops.mm b/backends/apple/metal/runtime/shims/et_metal_ops.mm
index da54dafb334..5b413728de5 100644
--- a/backends/apple/metal/runtime/shims/et_metal_ops.mm
+++ b/backends/apple/metal/runtime/shims/et_metal_ops.mm
@@ -626,6 +626,316 @@ AOTITorchError aoti_torch_mps_mm_out(
   }
 }
 
+AOTITorchError aoti_torch_mps_bmm_out(
+    AOTITensorHandle out,
+    AOTITensorHandle self,
+    AOTITensorHandle mat2) {
+
+  // Validate non-null handles
+  if (!out || !self || !mat2) {
+    ET_LOG(Error, "aoti_torch_mps_bmm_out: null tensor handles");
+    return Error::InvalidArgument;
+  }
+
+  @autoreleasepool {
+    try {
+      // Convert AOTITensorHandle to ExecutorTorch tensors
+      auto out_tensor = reinterpret_cast<Tensor*>(out);
+      auto self_tensor = reinterpret_cast<Tensor*>(self);
+      auto mat2_tensor = reinterpret_cast<Tensor*>(mat2);
+
+      // Validate tensor dimensions - bmm requires 3-D tensors
+      if (self_tensor->dim() != 3 || mat2_tensor->dim() != 3 || out_tensor->dim() != 3) {
+        ET_LOG(Error, "aoti_torch_mps_bmm_out: tensors must be 3-D. "
+            "Got self.dim=%zd (shape=[%d,%d,%d]), "
+            "mat2.dim=%zd (shape=[%d,%d,%d]), "
+            "out.dim=%zd (shape=[%d,%d,%d])",
+            self_tensor->dim(),
+            self_tensor->dim() > 0 ? (int)self_tensor->sizes()[0] : 0,
+            self_tensor->dim() > 1 ? (int)self_tensor->sizes()[1] : 0,
+            self_tensor->dim() > 2 ? (int)self_tensor->sizes()[2] : 0,
+            mat2_tensor->dim(),
+            mat2_tensor->dim() > 0 ? (int)mat2_tensor->sizes()[0] : 0,
+            mat2_tensor->dim() > 1 ? (int)mat2_tensor->sizes()[1] : 0,
+            mat2_tensor->dim() > 2 ? (int)mat2_tensor->sizes()[2] : 0,
+            out_tensor->dim(),
+            out_tensor->dim() > 0 ? (int)out_tensor->sizes()[0] : 0,
+            out_tensor->dim() > 1 ? (int)out_tensor->sizes()[1] : 0,
+            out_tensor->dim() > 2 ? (int)out_tensor->sizes()[2] : 0);
+        return Error::InvalidArgument;
+      }
+
+      int64_t B = self_tensor->sizes()[0]; // batch size
+      int64_t M = self_tensor->sizes()[1]; // rows of self
+      int64_t K = self_tensor->sizes()[2]; // cols of self / rows of mat2
+      int64_t N = mat2_tensor->sizes()[2]; // cols of mat2
+
+      // Validate shape constraints
+      // self: [B, M, K], mat2: [B, K, N], out: [B, M, N]
+      if (mat2_tensor->sizes()[0] != B) {
+        ET_LOG(Error, "aoti_torch_mps_bmm_out: batch size mismatch. "
+            "Expected mat2[0]=%d to match self[0]=%lld. "
+            "self.shape=[%lld,%lld,%lld], mat2.shape=[%d,%d,%d]",
+            (int)mat2_tensor->sizes()[0], (long long)B,
+            (long long)B, (long long)M, (long long)K,
+            (int)mat2_tensor->sizes()[0], (int)mat2_tensor->sizes()[1], (int)mat2_tensor->sizes()[2]);
+        return Error::InvalidArgument;
+      }
+
+      if (mat2_tensor->sizes()[1] != K) {
+        ET_LOG(Error, "aoti_torch_mps_bmm_out: incompatible matrix dimensions for bmm. "
+            "Expected mat2[1]=%d to match self[2]=%lld. "
+            "Cannot multiply [%lld,%lld,%lld] @ [%d,%d,%d]",
+            (int)mat2_tensor->sizes()[1], (long long)K,
+            (long long)B, (long long)M, (long long)K,
+            (int)mat2_tensor->sizes()[0], (int)mat2_tensor->sizes()[1], (int)mat2_tensor->sizes()[2]);
+        return Error::InvalidArgument;
+      }
+
+      if (out_tensor->sizes()[0] != B || out_tensor->sizes()[1] != M || out_tensor->sizes()[2] != N) {
+        ET_LOG(Error, "aoti_torch_mps_bmm_out: output shape mismatch. "
+            "Expected out.shape=[%lld,%lld,%lld], got [%d,%d,%d]",
+            (long long)B, (long long)M, (long long)N,
+            (int)out_tensor->sizes()[0], (int)out_tensor->sizes()[1], (int)out_tensor->sizes()[2]);
+        return Error::InvalidArgument;
+      }
+
+      // Validate dtype consistency
+      int32_t self_dtype = static_cast<int32_t>(self_tensor->scalar_type());
+      int32_t mat2_dtype = static_cast<int32_t>(mat2_tensor->scalar_type());
+      int32_t out_dtype = static_cast<int32_t>(out_tensor->scalar_type());
+
+      if (self_dtype != mat2_dtype || self_dtype != out_dtype) {
+        ET_LOG(Error, "aoti_torch_mps_bmm_out: dtype mismatch. "
+            "All tensors must have same dtype. Got self.dtype=%d, mat2.dtype=%d, out.dtype=%d",
+            self_dtype, mat2_dtype, out_dtype);
+        return Error::InvalidArgument;
+      }
+
+      int32_t dtype = self_dtype;
+
+      // Validate layout: BMM requires strictly contiguous 3D tensors
+      // For shape [B, M, K], contiguous strides MUST be [M*K, K, 1]
+      //
+      // Why strict contiguity is required:
+      // - MPSGraphTensorData initWithMTLBuffer:shape:dataType: interprets the MTLBuffer
+      //   as containing dense row-major data for the given shape
+      // - Non-contiguous layouts (transposed, views with strides, etc.) have different
+      //   memory layouts that don't match what MPS expects
+      // - This would result in SILENT WRONG RESULTS
+      // - This is an _out op: we must NOT create implicit copies
+      // - Policy: Reject non-contiguous inputs explicitly (transposed/view tensors unsupported)
+      //
+      // Limitation: This implementation does not explicitly check storage offset (no API available).
+      // Tensors with non-zero storage offsets are not explicitly rejected but may work if they
+      // happen to have contiguous strides. Users should ensure tensors are base tensors without offsets.
+      auto self_strides = self_tensor->strides();
+      auto mat2_strides = mat2_tensor->strides();
+      auto out_strides = out_tensor->strides();
+
+      // Check self tensor is contiguous [B, M, K] with strides [M*K, K, 1]
+      if (self_strides[2] != 1 || self_strides[1] != K || self_strides[0] != M * K) {
+        ET_LOG(Error, "aoti_torch_mps_bmm_out: self tensor must be contiguous. "
+            "Only dense row-major layout supported; transposed/view tensors are unsupported. "
+            "Expected strides=[%lld,%lld,1] for shape=[%lld,%lld,%lld], got strides=[%d,%d,%d].",
+            (long long)(M * K), (long long)K, (long long)B, (long long)M, (long long)K,
+            self_strides[0], self_strides[1], self_strides[2]);
+        return Error::InvalidArgument;
+      }
+
+      // Check mat2 tensor is contiguous [B, K, N] with strides [K*N, N, 1]
+      if (mat2_strides[2] != 1 || mat2_strides[1] != N || mat2_strides[0] != K * N) {
+        ET_LOG(Error, "aoti_torch_mps_bmm_out: mat2 tensor must be contiguous. "
+            "Only dense row-major layout supported; transposed/view tensors are unsupported. "
+            "Expected strides=[%lld,%lld,1] for shape=[%lld,%lld,%lld], got strides=[%d,%d,%d].",
+            (long long)(K * N), (long long)N, (long long)B, (long long)K, (long long)N,
+            mat2_strides[0], mat2_strides[1], mat2_strides[2]);
+        return Error::InvalidArgument;
+      }
+
+      // Check out tensor is contiguous [B, M, N] with strides [M*N, N, 1]
+      if (out_strides[2] != 1 || out_strides[1] != N || out_strides[0] != M * N) {
+        ET_LOG(Error, "aoti_torch_mps_bmm_out: out tensor must be contiguous. "
+            "Only dense row-major layout supported; transposed/view tensors are unsupported. "
+            "Expected strides=[%lld,%lld,1] for shape=[%lld,%lld,%lld], got strides=[%d,%d,%d].",
+            (long long)(M * N), (long long)N, (long long)B, (long long)M, (long long)N,
+            out_strides[0], out_strides[1], out_strides[2]);
+        return Error::InvalidArgument;
+      }
+
+      // Get Metal stream and device
+      ETMetalStream* stream = getCurrentMetalStream();
+      if (!stream) {
+        ET_LOG(Error, "aoti_torch_mps_bmm_out: Failed to get current Metal stream");
+        return Error::Internal;
+      }
+
+      id<MTLDevice> device = get_metal_device();
+      if (!device) {
+        ET_LOG(Error, "aoti_torch_mps_bmm_out: Failed to get Metal device");
+        return Error::Internal;
+      }
+      (void)device; // Used for validation, consistent with other ops
+
+      // Get Metal buffers for input and output tensors
+      id<MTLBuffer> self_buffer = get_mtl_buffer(self_tensor, "aoti_torch_mps_bmm_out", "self");
+      id<MTLBuffer> mat2_buffer = get_mtl_buffer(mat2_tensor, "aoti_torch_mps_bmm_out", "mat2");
+      id<MTLBuffer> out_buffer = get_mtl_buffer(out_tensor, "aoti_torch_mps_bmm_out", "out");
+
+      // Validate buffers are non-null
+      if (!self_buffer || !mat2_buffer || !out_buffer) {
+        ET_LOG(Error, "aoti_torch_mps_bmm_out: Failed to get Metal buffers. "
+            "self_buffer=%p, mat2_buffer=%p, out_buffer=%p",
+            self_buffer, mat2_buffer, out_buffer);
+        return Error::Internal;
+      }
+
+      // End any existing kernel coalescing to ensure clean state
+      // (consistent with mm_out and conv pattern)
+      stream->endKernelCoalescing();
+
+      // Map dtype to MPS type and validate support
+      // Note: Only FLOAT32 and BFLOAT16 are supported in Metal backend (see utils.h)
+      // FLOAT16 is not in SupportedDTypes enum and is not supported
+      MPSDataType mps_dtype;
+
+      if (dtype == static_cast<int32_t>(SupportedDTypes::FLOAT32)) {
+        mps_dtype = MPSDataTypeFloat32;
+      } else if (dtype == static_cast<int32_t>(SupportedDTypes::BFLOAT16)) {
+        mps_dtype = MPSDataTypeBFloat16;
+      } else {
+        ET_LOG(Error, "aoti_torch_mps_bmm_out: Unsupported data type: %d. "
+            "Supported types: FLOAT32 (%d), BFLOAT16 (%d)",
+            dtype,
+            static_cast<int32_t>(SupportedDTypes::FLOAT32),
+            static_cast<int32_t>(SupportedDTypes::BFLOAT16));
+        return Error::InvalidArgument;
+      }
+
+      // Define shapes for graph placeholders and tensor data
+      NSArray* selfShape = @[@(B), @(M), @(K)];
+      NSArray* mat2Shape = @[@(B), @(K), @(N)];
+      NSArray* outShape = @[@(B), @(M), @(N)];
+
+      // Create cache key for this batched matrix multiplication
+      // Cache key includes: op_name, shape params {B, M, K, N}, dtype, transpose_flag
+      // This allows reuse when same BMM shape/dtype is called repeatedly
+      GraphCacheKey cache_key;
+      cache_key.op_name = "bmm";
+      cache_key.shape_params = {B, M, K, N};
+      cache_key.dtype = dtype;
+      cache_key.transpose_flag = false; // BMM has no transpose handling
+
+      // Check if we have a cached graph
+      MPSGraph* mpsGraph = nullptr;
+      MPSGraphTensor* outputTensor = nil;
+      MPSGraphTensor* selfPlaceholder = nil;
+      MPSGraphTensor* mat2Placeholder = nil;
+
+      auto cache_it = graph_cache.find(cache_key);
+      if (cache_it != graph_cache.end()) {
+        // Cache hit - reuse compiled graph and tensor references
+        CachedGraph& cached = cache_it->second;
+        mpsGraph = cached.graph;
+        selfPlaceholder = cached.input1;
+        mat2Placeholder = cached.input2;
+        outputTensor = cached.output;
+
+        cache_stats.hits++;
+        cache_stats.logStats();
+
+      } else {
+        // Cache miss - create and compile new graph
+        mpsGraph = [MPSGraph new];
+        cache_stats.misses++;
+        cache_stats.logStats();
+
+        // Create 3D placeholders for batched matrices
+        // These represent the logical shapes for the batched matrix multiplication
+        selfPlaceholder = [mpsGraph placeholderWithShape:selfShape
+                                                dataType:mps_dtype
+                                                    name:@"self"];
+        mat2Placeholder = [mpsGraph placeholderWithShape:mat2Shape
+                                                dataType:mps_dtype
+                                                    name:@"mat2"];
+
+        // MPSGraph matrixMultiplication handles batched case natively when given 3D tensors
+        // For 3D inputs [B,M,K] @ [B,K,N] -> [B,M,N]
+        outputTensor = [mpsGraph matrixMultiplicationWithPrimaryTensor:selfPlaceholder
+                                                       secondaryTensor:mat2Placeholder
+                                                                  name:@"bmm_result"];
+
+        // Cache the compiled graph and tensor references for reuse
+        CachedGraph cached_graph;
+        cached_graph.graph = mpsGraph;
+        cached_graph.input1 = selfPlaceholder;
+        cached_graph.input2 = mat2Placeholder;
+        cached_graph.input3 = nil; // No third input for BMM
+        cached_graph.output = outputTensor;
+        graph_cache[cache_key] = cached_graph;
+
+      } // End of cache miss/hit block
+
+      // Create feeds dictionary for graph execution
+      NSMutableDictionary* feeds = [NSMutableDictionary dictionary];
+
+      // Create MPSGraphTensorData objects for input tensors
+      // These wrap the MTLBuffers with the shape information
+      // Initialize to nil for safe cleanup in exception path
+      MPSGraphTensorData* selfData = nil;
+      MPSGraphTensorData* mat2Data = nil;
+      MPSGraphTensorData* outputData = nil;
+
+      selfData = [[MPSGraphTensorData alloc] initWithMTLBuffer:self_buffer
+                                                         shape:selfShape
+                                                      dataType:mps_dtype];
+      mat2Data = [[MPSGraphTensorData alloc] initWithMTLBuffer:mat2_buffer
+                                                         shape:mat2Shape
+                                                      dataType:mps_dtype];
+
+      feeds[selfPlaceholder] = selfData;
+      feeds[mat2Placeholder] = mat2Data;
+
+      // Create output tensor data
+      outputData = [[MPSGraphTensorData alloc] initWithMTLBuffer:out_buffer
+                                                           shape:outShape
+                                                        dataType:mps_dtype];
+
+      // Build results dictionary
+      NSDictionary* results = @{
+        outputTensor : outputData
+      };
+
+      // Execute the batched matrix multiplication
+      @try {
+        stream->executeMPSGraph(mpsGraph, feeds, results, SyncType::COMMIT);
+      } @catch (NSException *exception) {
+        ET_LOG(Error, "aoti_torch_mps_bmm_out: NSException caught during executeMPSGraph: %s - %s",
+            [[exception name] UTF8String], [[exception reason] UTF8String]);
+        // Guard releases against nil
+        if (selfData) [selfData release];
+        if (mat2Data) [mat2Data release];
+        if (outputData) [outputData release];
+        return Error::Internal;
+      }
+
+      // Release MPSGraphTensorData objects
+      [selfData release];
+      [mat2Data release];
+      [outputData release];
+
+      return Error::Ok;
+
+    } catch (const std::exception& e) {
+      ET_LOG(Error, "aoti_torch_mps_bmm_out exception: %s", e.what());
+      return Error::Internal;
+    } catch (...) {
+      ET_LOG(Error, "aoti_torch_mps_bmm_out: unknown exception");
+      return Error::Internal;
+    }
+  }
+}
+
 AOTITorchError aoti_torch_mps_convolution(
     AOTITensorHandle input,
     AOTITensorHandle weight,
@@ -827,10 +1137,13 @@ AOTITorchError aoti_torch_mps_convolution(
     }
     ET_LOG(Debug, "aoti_torch_mps_convolution: mps_dtype=%d, element_size=%zu", mps_dtype, element_size);
 
+    // Get weight's input channel dimension from the weight tensor (not from input)
+    // For grouped convolutions, weight shape is [C_out, C_in/groups, kH, kW]
+    int64_t weight_C_in = weight_tensor->sizes()[1]; // This handles grouped convs correctly
     // Define tensor shapes for placeholders (needed for both cache hit and miss)
     NSArray* inputShape = @[@(N), @(C_in), @(H_in), @(W_in)];
-    NSArray* weightShape = @[@(C_out), @(C_in), @(kernel_h), @(kernel_w)];
+    NSArray* weightShape = @[@(C_out), @(weight_C_in), @(kernel_h), @(kernel_w)];
 
     // Create cache key for this convolution
     GraphCacheKey cache_key;
diff --git a/backends/apple/metal/runtime/shims/memory.cpp b/backends/apple/metal/runtime/shims/memory.cpp
index ebb5b7642e1..eae8e62beef 100644
--- a/backends/apple/metal/runtime/shims/memory.cpp
+++ b/backends/apple/metal/runtime/shims/memory.cpp
@@ -430,9 +430,6 @@ AOTITorchError aoti_torch__reinterpret_tensor(
       InvalidArgument,
       "aoti_torch__reinterpret_tensor failed: ret_new_tensor is null");
 
-  // Check if storage_offset is not 0 - return error if not
-  ET_CHECK_OK_OR_RETURN_ERROR(validate_storage_offset(storage_offset));
-
   // Get the device info from the source tensor to perform device_index
   // validation
   int32_t device_type = 0;
@@ -470,6 +467,10 @@ AOTITorchError aoti_torch__reinterpret_tensor(
       "Memory address %p is not being tracked by reference counting system",
       data_ptr);
 
+  // Handle storage offset by adjusting the data pointer
+  void* adjusted_data = static_cast<char*>(data_ptr) +
+      (storage_offset * dtype_to_element_size(dtype));
+
   // Convert sizes using utility function from utils.h
   std::vector<executorch::aten::SizesType> sizes = convert_sizes_to_vector(ndim, sizes_ptr);
 
@@ -480,7 +481,7 @@ AOTITorchError aoti_torch__reinterpret_tensor(
   // Create new tensor view that reinterprets the same memory with different
   // shape/strides This creates a view, not a copy - the data pointer is shared
   std::shared_ptr<Tensor> tensor = executorch::extension::from_blob(
-      data_ptr, // Reuse the same memory from source tensor
+      adjusted_data, // Use adjusted data pointer with storage offset applied
      sizes, // New sizes with explicit SizesType
      strides, // New strides with explicit StridesType
      dtype_to_scalar_type(dtype) // Convert dtype with explicit type casting
@@ -496,11 +497,24 @@ AOTITorchError aoti_torch__reinterpret_tensor(
 
   *ret_new_tensor = tensor.get();
 
+  if (adjusted_data != data_ptr) {
+    ET_LOG(
+        Debug,
+        "aoti_torch__reinterpret_tensor: Adjusted original_data=%p, storage_offset=%lld, element_size=%zu, adjusted_data=%p",
+        data_ptr,
+        storage_offset,
+        dtype_to_element_size(dtype),
+        adjusted_data);
+
+    metal_buffer_nocopy(adjusted_data, tensor->nbytes(), true);
+  }
+
   // Increment the reference count for this memory address only if it is owned
   // by tensor
-  memory_to_n_tensor[data_ptr] = memory_to_n_tensor[data_ptr] == NOT_OWN
+  memory_to_n_tensor[adjusted_data] =
+      memory_to_n_tensor[adjusted_data] == NOT_OWN
       ? NOT_OWN
-      : memory_to_n_tensor[data_ptr] + 1;
+      : memory_to_n_tensor[adjusted_data] + 1;
 
   ET_LOG(Debug, "aoti_torch__reinterpret_tensor: successfull");
   return Error::Ok;
 
@@ -509,10 +523,92 @@ AOTITorchError aoti_torch_new_tensor_handle(
     Tensor* orig_handle,
     Tensor** new_handle) {
-  (void)orig_handle;
-  (void)new_handle;
-  throw std::runtime_error("Not implemented");
-  return Error::Internal;
+  ET_LOG(Debug, "aoti_torch_new_tensor_handle: entered");
+
+  // Validate input parameters
+  ET_CHECK_OR_RETURN_ERROR(
+      orig_handle != nullptr,
+      InvalidArgument,
+      "aoti_torch_new_tensor_handle failed: orig_handle is null");
+
+  ET_CHECK_OR_RETURN_ERROR(
+      new_handle != nullptr,
+      InvalidArgument,
+      "aoti_torch_new_tensor_handle failed: new_handle is null");
+
+  // Get metadata from the original tensor
+  int64_t* sizes_ptr;
+  int64_t* strides_ptr;
+  int32_t dtype;
+  int32_t device_type;
+  int32_t device_index;
+
+  ET_CHECK_OK_OR_RETURN_ERROR(aoti_torch_get_sizes(orig_handle, &sizes_ptr));
+  ET_CHECK_OK_OR_RETURN_ERROR(
+      aoti_torch_get_strides(orig_handle, &strides_ptr));
+  ET_CHECK_OK_OR_RETURN_ERROR(aoti_torch_get_dtype(orig_handle, &dtype));
+  ET_CHECK_OK_OR_RETURN_ERROR(
+      aoti_torch_get_device_type(orig_handle, &device_type));
+  ET_CHECK_OK_OR_RETURN_ERROR(
+      aoti_torch_get_device_index(orig_handle, &device_index));
+
+  int64_t ndim = orig_handle->dim();
+
+  // Validate dtype
+  ET_CHECK_OK_OR_RETURN_ERROR(validate_dtype(dtype));
+
+  // Ensure device_index is always 0
+  ET_CHECK_OR_RETURN_ERROR(
+      device_index == 0,
+      InvalidArgument,
+      "device_index must be 0, got: %d",
+      device_index);
+
+  // Get the original data pointer from the source tensor
+  void* data_ptr = orig_handle->mutable_data_ptr();
+  ET_CHECK_OR_RETURN_ERROR(
+      data_ptr != nullptr,
+      InvalidArgument,
+      "Source tensor has null data pointer");
+
+  // Check if the given memory is in the map
+  auto memory_it = memory_to_n_tensor.find(data_ptr);
+  ET_CHECK_OR_RETURN_ERROR(
+      memory_it != memory_to_n_tensor.end(),
+      InvalidArgument,
+      "Memory address %p is not being tracked by reference counting system",
+      data_ptr);
+
+  // Convert sizes and strides to vectors
+  auto sizes = convert_sizes_to_vector(ndim, sizes_ptr);
+  auto strides = convert_strides_to_vector(ndim, sizes_ptr, strides_ptr);
+
+  // Create new tensor that shares the same memory as the original
+  // This is similar to PyTorch's Tensor copy constructor - creates a new
+  // tensor object that shares the same underlying storage
+  std::shared_ptr<Tensor> tensor = executorch::extension::from_blob(
+      data_ptr, // Share the same memory from source tensor
+      sizes, // Same sizes as original
+      strides, // Same strides as original
+      dtype_to_scalar_type(dtype) // Same dtype as original
+  );
+
+  ET_CHECK_OR_RETURN_ERROR(
+      tensor != nullptr, InvalidArgument, "Failed to create new tensor handle");
+
+  // Store the tensor so it doesn't get destroyed
+  tensors.insert(tensor);
+
+  *new_handle = tensor.get();
+
+  // Increment the reference count for this memory address only if it is owned
+  // by tensor
+  memory_to_n_tensor[data_ptr] = memory_to_n_tensor[data_ptr] == NOT_OWN
+      ? NOT_OWN
+      : memory_to_n_tensor[data_ptr] + 1;
+
+  ET_LOG(Debug, "aoti_torch_new_tensor_handle: successfull");
+  return Error::Ok;
 }
 
 // Cleanup function for clearing global state
diff --git a/examples/models/parakeet/README.md b/examples/models/parakeet/README.md
index b27bc1f8a91..7a611c90e82 100644
--- a/examples/models/parakeet/README.md
+++ b/examples/models/parakeet/README.md
@@ -25,26 +25,47 @@ python export_parakeet_tdt.py --audio /path/to/audio.wav
 | Argument | Description |
 |----------|-------------|
 | `--output-dir` | Output directory for exports (default: `./parakeet_tdt_exports`) |
-| `--backend` | Backend for acceleration: `portable`, `xnnpack`, `cuda`, `cuda-windows` (default: `portable`) |
+| `--backend` | Backend for acceleration: `portable`, `xnnpack`, `metal`, `cuda`, `cuda-windows` (default: `portable`) |
+| `--dtype` | Data type: `fp32`, `bf16` (default: `fp32`). Metal backend supports fp32 and bf16 only. |
 | `--audio` | Path to audio file for transcription test |
 
 **Note:** The preprocessor is always lowered with the portable backend regardless of the `--backend` setting.
 
+### Metal Export (macOS)
+
+```bash
+python export_parakeet_tdt.py --backend metal --output-dir ./parakeet_metal
+```
+
+This generates:
+- `parakeet_tdt.pte` - The compiled model
+- `aoti_metal_blob.ptd` - Metal kernel blob required at runtime
+- `tokenizer.model` - SentencePiece tokenizer
+
 ## C++ Runner
 
 ### Building
 
-First, build ExecuTorch with the LLM preset from the executorch root directory:
+First, build ExecuTorch with the appropriate preset from the executorch root directory:
 
 ```bash
+# For CPU/XNNPACK
 cmake --workflow --preset llm-release
+
+# For Metal (macOS)
+cmake --workflow --preset llm-debug-metal
 ```
 
 Then build the parakeet runner:
 
 ```bash
 cd examples/models/parakeet
+
+# CPU/XNNPACK build
 cmake --workflow --preset parakeet-cpu
+
+# Metal build
+cmake --workflow --preset parakeet-metal
 ```
 
 Available presets:
@@ -57,10 +78,18 @@ Available presets:
 From the executorch root directory:
 
 ```bash
+# CPU/XNNPACK
 ./cmake-out/examples/models/parakeet/parakeet_runner \
   --model_path examples/models/parakeet/parakeet_tdt_exports/parakeet_tdt.pte \
   --audio_path /path/to/audio.wav \
   --tokenizer_path examples/models/parakeet/parakeet_tdt_exports/tokenizer.model
+
+# Metal (include .ptd data file)
+DYLD_LIBRARY_PATH=/usr/lib ./cmake-out/examples/models/parakeet/parakeet_runner \
+  --model_path examples/models/parakeet/parakeet_metal/parakeet_tdt.pte \
+  --data_path examples/models/parakeet/parakeet_metal/aoti_metal_blob.ptd \
+  --audio_path /path/to/audio.wav \
+  --tokenizer_path examples/models/parakeet/parakeet_metal/tokenizer.model
 ```
 
 ### Runner Arguments
@@ -70,4 +99,4 @@ | Argument | Description |
 | `--model_path` | Path to Parakeet model (.pte) |
 | `--audio_path` | Path to input audio file (.wav) |
 | `--tokenizer_path` | Path to tokenizer file (default: `tokenizer.json`) |
-| `--data_path` | Path to data file (.ptd) for delegate data (optional, required for CUDA) |
+| `--data_path` | Path to data file (.ptd) for delegate data (required for Metal/CUDA) |
diff --git a/examples/models/parakeet/export_parakeet_tdt.py b/examples/models/parakeet/export_parakeet_tdt.py
index 92e32ca30bf..a53f6920e3e 100644
--- a/examples/models/parakeet/export_parakeet_tdt.py
+++ b/examples/models/parakeet/export_parakeet_tdt.py
@@ -7,6 +7,7 @@
 import tempfile
 
 import torch
+
 import torchaudio
 from executorch.exir import (
     EdgeCompileConfig,
@@ -363,48 +364,102 @@ def export_all(model):
     return programs, metadata
 
 
-def lower_to_executorch(programs, metadata=None, backend="portable"):
+def _create_xnnpack_partitioners(programs):
+    """Create XNNPACK partitioners for all programs except preprocessor."""
+    from executorch.backends.xnnpack.partition.xnnpack_partitioner import (
+        XnnpackPartitioner,
+    )
+
+    print("\nLowering to ExecuTorch with XNNPACK...")
     partitioner = {}
+    for key in programs.keys():
+        if key == "preprocessor":
+            partitioner[key] = []
+        else:
+            partitioner[key] = [XnnpackPartitioner()]
+    return partitioner, programs
+
+
+def _linear_bias_decomposition(input, weight, bias=None):
+    """Decompose linear with bias into matmul + add."""
+    # linear(input, weight) = input @ weight.T
+    # Use matmul instead of mm to handle batched inputs (3D+)
+    weight_t = torch.ops.aten.t.default(weight)
+    out = torch.ops.aten.matmul.default(input, weight_t)
+    if bias is not None:
+        return torch.ops.aten.add.Tensor(out, bias)
+    return out
+
+
+def _create_metal_partitioners(programs):
+    """Create Metal partitioners for all programs except preprocessor."""
+    from executorch.backends.apple.metal.metal_backend import MetalBackend
+    from executorch.backends.apple.metal.metal_partitioner import MetalPartitioner
+
+    print("\nLowering to ExecuTorch with Metal...")
+
+    # Run decompositions for non-preprocessor programs
+    updated_programs = {}
+    for key, ep in programs.items():
+        # print(f"Running decompositions for {key}")
+        # print(ep.graph_module)
+        if key != "preprocessor":
+            updated_programs[key] = ep.run_decompositions(
+                {torch.ops.aten.linear.default: _linear_bias_decomposition}
+            )
+        else:
+            updated_programs[key] = ep
 
-    if backend == "xnnpack":
-        from executorch.backends.xnnpack.partition.xnnpack_partitioner import (
-            XnnpackPartitioner,
-        )
+    partitioner = {}
+    for key in updated_programs.keys():
+        if key == "preprocessor":
+            partitioner[key] = []
+        else:
+            compile_specs = [MetalBackend.generate_method_name_compile_spec(key)]
+            partitioner[key] = [MetalPartitioner(compile_specs)]
+    return partitioner, updated_programs
+
+
+def _create_cuda_partitioners(programs, is_windows=False):
+    """Create CUDA partitioners for all programs except preprocessor."""
+    from executorch.backends.cuda.cuda_backend import CudaBackend
+    from executorch.backends.cuda.cuda_partitioner import CudaPartitioner
+    from executorch.exir.backend.compile_spec_schema import CompileSpec
+    from torch._inductor.decomposition import conv1d_to_conv2d
+
+    print(f"\nLowering to ExecuTorch with CUDA{' (Windows)' if is_windows else ''}...")
+
+    # Run decompositions for non-preprocessor programs
+    updated_programs = {}
+    for key, ep in programs.items():
+        if key != "preprocessor":
+            updated_programs[key] = ep.run_decompositions(
+                {torch.ops.aten.conv1d.default: conv1d_to_conv2d}
+            )
+        else:
+            updated_programs[key] = ep
 
-        print("\nLowering to ExecuTorch with XNNPACK...")
-        for key in programs.keys():
-            if key == "preprocessor":
-                partitioner[key] = []
-            else:
-                partitioner[key] = [XnnpackPartitioner()]
+    partitioner = {}
+    for key in updated_programs.keys():
+        if key == "preprocessor":
+            partitioner[key] = []
+        else:
+            compile_specs = [CudaBackend.generate_method_name_compile_spec(key)]
+            if is_windows:
+                compile_specs.append(CompileSpec("platform", "windows".encode("utf-8")))
+            partitioner[key] = [CudaPartitioner(compile_specs)]
+    return partitioner, updated_programs
 
-    elif backend in ("cuda", "cuda-windows"):
-        from executorch.backends.cuda.cuda_backend import CudaBackend
-        from executorch.backends.cuda.cuda_partitioner import CudaPartitioner
-        from executorch.exir.backend.compile_spec_schema import CompileSpec
-        from torch._inductor.decomposition import conv1d_to_conv2d
 
-        print(
-            f"\nLowering to ExecuTorch with CUDA{' (Windows)' if backend == 'cuda-windows' else ''}..."
+def lower_to_executorch(programs, metadata=None, backend="portable"):
+    if backend == "xnnpack":
+        partitioner, programs = _create_xnnpack_partitioners(programs)
+    elif backend == "metal":
+        partitioner, programs = _create_metal_partitioners(programs)
+    elif backend in ("cuda", "cuda-windows"):
+        partitioner, programs = _create_cuda_partitioners(
+            programs, is_windows=(backend == "cuda-windows")
         )
-
-        for key, ep in programs.items():
-            if key != "preprocessor":
-                programs[key] = ep.run_decompositions(
-                    {torch.ops.aten.conv1d.default: conv1d_to_conv2d}
-                )
-
-        for key in programs.keys():
-            if key == "preprocessor":
-                partitioner[key] = []
-            else:
-                compile_specs = [CudaBackend.generate_method_name_compile_spec(key)]
-                if backend == "cuda-windows":
-                    compile_specs.append(
-                        CompileSpec("platform", "windows".encode("utf-8"))
-                    )
-                partitioner[key] = [CudaPartitioner(compile_specs)]
-
     else:
         print("\nLowering to ExecuTorch...")
         partitioner = []
@@ -442,11 +497,22 @@ def main():
         "--backend",
         type=str,
         default="portable",
-        choices=["portable", "xnnpack", "cuda", "cuda-windows"],
+        choices=["portable", "xnnpack", "metal", "cuda", "cuda-windows"],
         help="Backend for acceleration (default: portable)",
     )
+    parser.add_argument(
+        "--dtype",
+        type=str,
+        default="fp32",
+        choices=["fp32", "fp16", "bf16"],
+        help="Model dtype for Metal/CUDA backends (default: fp32)",
+    )
     args = parser.parse_args()
 
+    # Validate dtype for Metal backend
+    if args.backend == "metal" and args.dtype == "fp16":
+        parser.error("Metal backend only supports fp32 and bf16, not fp16")
+
     os.makedirs(args.output_dir, exist_ok=True)
 
     print("Extracting tokenizer...")
@@ -455,6 +521,14 @@ def main():
     print("Loading model...")
     model = load_model()
 
+    # Convert model to specified dtype for Metal/CUDA backends
+    if args.dtype == "bf16":
+        print("Converting model to bfloat16...")
+        model = model.to(torch.bfloat16)
+    elif args.dtype == "fp16":
+        print("Converting model to float16...")
+        model = model.to(torch.float16)
+
     print("\nExporting components...")
     programs, metadata = export_all(model)
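
Note on the new Metal lowering path in `export_parakeet_tdt.py`: the sketch below illustrates what the `_linear_bias_decomposition` entry does when passed to `run_decompositions`, using a toy module. It is a minimal, standalone illustration; the module, tensor shapes, and variable names are made up and are not part of this patch.

```python
# Standalone sketch (not part of the patch): how the custom aten.linear
# decomposition behaves when applied through ExportedProgram.run_decompositions.
import torch
from torch.export import export


def _linear_bias_decomposition(input, weight, bias=None):
    # linear(input, weight, bias) = input @ weight.T (+ bias)
    weight_t = torch.ops.aten.t.default(weight)
    out = torch.ops.aten.matmul.default(input, weight_t)
    if bias is not None:
        return torch.ops.aten.add.Tensor(out, bias)
    return out


class TinyProjection(torch.nn.Module):  # hypothetical toy module
    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(16, 8)

    def forward(self, x):
        return self.proj(x)


# Batched (3-D) input, the case the new aoti_torch_mps_bmm_out shim targets.
ep = export(TinyProjection(), (torch.randn(2, 4, 16),))
ep = ep.run_decompositions(
    {torch.ops.aten.linear.default: _linear_bias_decomposition}
)
print(ep.graph_module.code)  # graph now contains t/matmul/add instead of linear
```

Keeping `matmul` (rather than `mm`) in the decomposition is what lets 3-D activations stay expressible as batched matrix multiplication, which the Metal backend can then serve through its `aoti_torch_mps_mm_out`/`aoti_torch_mps_bmm_out` fallback kernels, depending on how the shapes lower.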