diff --git a/CMakePresets.json b/CMakePresets.json
index 2b1512ac121..c265e28e6c8 100644
--- a/CMakePresets.json
+++ b/CMakePresets.json
@@ -183,7 +183,9 @@
       ],
       "cacheVariables": {
         "CMAKE_BUILD_TYPE": "Debug",
-        "CMAKE_INSTALL_PREFIX": "${sourceDir}/cmake-out"
+        "CMAKE_INSTALL_PREFIX": "${sourceDir}/cmake-out",
+        "EXECUTORCH_ENABLE_LOGGING": "ON",
+        "ET_MIN_LOG_LEVEL": "Info"
       }
     },
     {
diff --git a/backends/apple/metal/metal_backend.py b/backends/apple/metal/metal_backend.py
index 1d86cfb8447..a6fb516c948 100644
--- a/backends/apple/metal/metal_backend.py
+++ b/backends/apple/metal/metal_backend.py
@@ -32,6 +32,7 @@ def get_device_name(cls) -> str:
     def get_supported_fallback_kernels(cls) -> Dict[str, Any]:
         return {
             "aoti_torch_mps_addmm_out": None,
+            "aoti_torch_mps_bmm_out": None,
             "aoti_torch_mps_convolution": None,
             "aoti_torch_mps_mm_out": None,
             "at::_ops::_scaled_dot_product_attention_math_for_mps::call": None,
diff --git a/backends/apple/metal/runtime/shims/et_metal_ops.h b/backends/apple/metal/runtime/shims/et_metal_ops.h
index 78bdb419ea4..a5aca7a427b 100644
--- a/backends/apple/metal/runtime/shims/et_metal_ops.h
+++ b/backends/apple/metal/runtime/shims/et_metal_ops.h
@@ -27,6 +27,28 @@ AOTITorchError aoti_torch_mps_mm_out(
     AOTITensorHandle self,
     AOTITensorHandle mat2);
 
+/**
+ * ExecutorTorch implementation of aoti_torch_mps_addmm_out.
+ * Performs: out = beta * input + alpha * (mat1 @ mat2)
+ */
+AOTITorchError aoti_torch_mps_addmm_out(
+    AOTITensorHandle out,
+    AOTITensorHandle input,
+    AOTITensorHandle mat1,
+    AOTITensorHandle mat2,
+    int64_t beta,
+    int64_t alpha);
+
+/**
+ * ExecutorTorch implementation of aoti_torch_mps_bmm_out.
+ * Performs batched matrix multiplication: out = self @ mat2
+ * All tensors must be 3-D with matching batch dimensions.
+ */
+AOTITorchError aoti_torch_mps_bmm_out(
+    AOTITensorHandle out,
+    AOTITensorHandle self,
+    AOTITensorHandle mat2);
+
 /**
  * ExecutorTorch implementation of aoti_torch_mps_convolution.
  * Performs 2D convolution operation - matches PyTorch AOTI signature
diff --git a/backends/apple/metal/runtime/shims/et_metal_ops.mm b/backends/apple/metal/runtime/shims/et_metal_ops.mm
index da54dafb334..88d79e4275f 100644
--- a/backends/apple/metal/runtime/shims/et_metal_ops.mm
+++ b/backends/apple/metal/runtime/shims/et_metal_ops.mm
@@ -97,14 +97,18 @@ void logStats() {
 static CacheStats cache_stats;
 
 // Helper function to get Metal buffer from the global mapping
+// All tensors should have their data_ptr directly in the map since we materialize
+// all offset/broadcast cases in aoti_torch__reinterpret_tensor
 static id<MTLBuffer> get_mtl_buffer(Tensor* tensor, const char* op_name, const char* tensor_name) {
   void* data_ptr = tensor->mutable_data_ptr();
+
   auto it = ptr_to_mtl_buffer.find(data_ptr);
-  if (it == ptr_to_mtl_buffer.end()) {
-    ET_LOG(Error, "%s: %s tensor not found in Metal buffer mapping", op_name, tensor_name);
-    throw std::runtime_error(std::string(tensor_name) + " tensor not found in Metal buffer mapping");
+  if (it != ptr_to_mtl_buffer.end()) {
+    return it->second;
   }
-  return it->second;
+
+  ET_LOG(Error, "%s: %s tensor not found in Metal buffer mapping", op_name, tensor_name);
+  throw std::runtime_error(std::string(tensor_name) + " tensor not found in Metal buffer mapping");
 }
 
 // Helper function to allocate a Metal buffer and register it in the global mapping.
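// Illustrative sketch (not part of the patch): how the shims below are expected to
// consume get_mtl_buffer(). The helper throws when a tensor's data_ptr has no
// registered MTLBuffer, so each op wraps the lookups in try/catch and converts the
// failure into an AOTITorchError. The op name "example_mps_op" is hypothetical;
// the other identifiers mirror the surrounding file.
//
//   AOTITorchError example_mps_op(AOTITensorHandle out, AOTITensorHandle self) {
//     @autoreleasepool {
//       try {
//         auto* out_tensor = reinterpret_cast<Tensor*>(out);
//         auto* self_tensor = reinterpret_cast<Tensor*>(self);
//         // Lookups succeed only for pointers registered when the buffer was allocated.
//         id<MTLBuffer> self_buffer = get_mtl_buffer(self_tensor, "example_mps_op", "self");
//         id<MTLBuffer> out_buffer = get_mtl_buffer(out_tensor, "example_mps_op", "out");
//         (void)self_buffer;
//         (void)out_buffer;
//         return Error::Ok;
//       } catch (const std::exception& e) {
//         ET_LOG(Error, "example_mps_op: %s", e.what());
//         return Error::Internal;
//       }
//     }
//   }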
@@ -626,6 +630,489 @@ AOTITorchError aoti_torch_mps_mm_out( } } +AOTITorchError aoti_torch_mps_bmm_out( + AOTITensorHandle out, + AOTITensorHandle self, + AOTITensorHandle mat2) { + + // Validate non-null handles + if (!out || !self || !mat2) { + ET_LOG(Error, "aoti_torch_mps_bmm_out: null tensor handles"); + return Error::InvalidArgument; + } + + @autoreleasepool { + try { + // Convert AOTITensorHandle to ExecutorTorch tensors + auto out_tensor = reinterpret_cast(out); + auto self_tensor = reinterpret_cast(self); + auto mat2_tensor = reinterpret_cast(mat2); + + // Validate tensor dimensions - bmm requires 3-D tensors + if (self_tensor->dim() != 3 || mat2_tensor->dim() != 3 || out_tensor->dim() != 3) { + ET_LOG(Error, "aoti_torch_mps_bmm_out: tensors must be 3-D. " + "Got self.dim=%zd (shape=[%d,%d,%d]), " + "mat2.dim=%zd (shape=[%d,%d,%d]), " + "out.dim=%zd (shape=[%d,%d,%d])", + self_tensor->dim(), + self_tensor->dim() > 0 ? (int)self_tensor->sizes()[0] : 0, + self_tensor->dim() > 1 ? (int)self_tensor->sizes()[1] : 0, + self_tensor->dim() > 2 ? (int)self_tensor->sizes()[2] : 0, + mat2_tensor->dim(), + mat2_tensor->dim() > 0 ? (int)mat2_tensor->sizes()[0] : 0, + mat2_tensor->dim() > 1 ? (int)mat2_tensor->sizes()[1] : 0, + mat2_tensor->dim() > 2 ? (int)mat2_tensor->sizes()[2] : 0, + out_tensor->dim(), + out_tensor->dim() > 0 ? (int)out_tensor->sizes()[0] : 0, + out_tensor->dim() > 1 ? (int)out_tensor->sizes()[1] : 0, + out_tensor->dim() > 2 ? (int)out_tensor->sizes()[2] : 0); + return Error::InvalidArgument; + } + + int64_t B = self_tensor->sizes()[0]; // batch size + int64_t M = self_tensor->sizes()[1]; // rows of self + int64_t K = self_tensor->sizes()[2]; // cols of self / rows of mat2 + int64_t N = mat2_tensor->sizes()[2]; // cols of mat2 + + // Validate shape constraints + // self: [B, M, K], mat2: [B, K, N], out: [B, M, N] + if (mat2_tensor->sizes()[0] != B) { + ET_LOG(Error, "aoti_torch_mps_bmm_out: batch size mismatch. " + "Expected mat2[0]=%d to match self[0]=%lld. " + "self.shape=[%lld,%lld,%lld], mat2.shape=[%d,%d,%d]", + (int)mat2_tensor->sizes()[0], (long long)B, + (long long)B, (long long)M, (long long)K, + (int)mat2_tensor->sizes()[0], (int)mat2_tensor->sizes()[1], (int)mat2_tensor->sizes()[2]); + return Error::InvalidArgument; + } + + if (mat2_tensor->sizes()[1] != K) { + ET_LOG(Error, "aoti_torch_mps_bmm_out: incompatible matrix dimensions for bmm. " + "Expected mat2[1]=%d to match self[2]=%lld. " + "Cannot multiply [%lld,%lld,%lld] @ [%d,%d,%d]", + (int)mat2_tensor->sizes()[1], (long long)K, + (long long)B, (long long)M, (long long)K, + (int)mat2_tensor->sizes()[0], (int)mat2_tensor->sizes()[1], (int)mat2_tensor->sizes()[2]); + return Error::InvalidArgument; + } + + if (out_tensor->sizes()[0] != B || out_tensor->sizes()[1] != M || out_tensor->sizes()[2] != N) { + ET_LOG(Error, "aoti_torch_mps_bmm_out: output shape mismatch. " + "Expected out.shape=[%lld,%lld,%lld], got [%d,%d,%d]", + (long long)B, (long long)M, (long long)N, + (int)out_tensor->sizes()[0], (int)out_tensor->sizes()[1], (int)out_tensor->sizes()[2]); + return Error::InvalidArgument; + } + + // Validate dtype consistency + int32_t self_dtype = static_cast(self_tensor->scalar_type()); + int32_t mat2_dtype = static_cast(mat2_tensor->scalar_type()); + int32_t out_dtype = static_cast(out_tensor->scalar_type()); + + if (self_dtype != mat2_dtype || self_dtype != out_dtype) { + ET_LOG(Error, "aoti_torch_mps_bmm_out: dtype mismatch. " + "All tensors must have same dtype. 
Got self.dtype=%d, mat2.dtype=%d, out.dtype=%d", + self_dtype, mat2_dtype, out_dtype); + return Error::InvalidArgument; + } + + int32_t dtype = self_dtype; + + // Validate layout: BMM requires strictly contiguous 3D tensors + // For shape [B, M, K], contiguous strides MUST be [M*K, K, 1] + // + // Why strict contiguity is required: + // - MPSGraphTensorData initWithMTLBuffer:shape:dataType: interprets the MTLBuffer + // as containing dense row-major data for the given shape + // - Non-contiguous layouts (transposed, views with strides, etc.) have different + // memory layouts that don't match what MPS expects + // - This would result in SILENT WRONG RESULTS + // - This is an _out op: we must NOT create implicit copies + // - Policy: Reject non-contiguous inputs explicitly (transposed/view tensors unsupported) + // + // Limitation: This implementation does not explicitly check storage offset (no API available). + // Tensors with non-zero storage offsets are not explicitly rejected but may work if they + // happen to have contiguous strides. Users should ensure tensors are base tensors without offsets. + auto self_strides = self_tensor->strides(); + auto mat2_strides = mat2_tensor->strides(); + auto out_strides = out_tensor->strides(); + + // Check self tensor is contiguous [B, M, K] with strides [M*K, K, 1] + if (self_strides[2] != 1 || self_strides[1] != K || self_strides[0] != M * K) { + ET_LOG(Error, "aoti_torch_mps_bmm_out: self tensor must be contiguous. " + "Only dense row-major layout supported; transposed/view tensors are unsupported. " + "Expected strides=[%lld,%lld,1] for shape=[%lld,%lld,%lld], got strides=[%d,%d,%d].", + (long long)(M * K), (long long)K, (long long)B, (long long)M, (long long)K, + self_strides[0], self_strides[1], self_strides[2]); + return Error::InvalidArgument; + } + + // Check mat2 tensor is contiguous [B, K, N] with strides [K*N, N, 1] + if (mat2_strides[2] != 1 || mat2_strides[1] != N || mat2_strides[0] != K * N) { + ET_LOG(Error, "aoti_torch_mps_bmm_out: mat2 tensor must be contiguous. " + "Only dense row-major layout supported; transposed/view tensors are unsupported. " + "Expected strides=[%lld,%lld,1] for shape=[%lld,%lld,%lld], got strides=[%d,%d,%d].", + (long long)(K * N), (long long)N, (long long)B, (long long)K, (long long)N, + mat2_strides[0], mat2_strides[1], mat2_strides[2]); + return Error::InvalidArgument; + } + + // Check out tensor is contiguous [B, M, N] with strides [M*N, N, 1] + if (out_strides[2] != 1 || out_strides[1] != N || out_strides[0] != M * N) { + ET_LOG(Error, "aoti_torch_mps_bmm_out: out tensor must be contiguous. " + "Only dense row-major layout supported; transposed/view tensors are unsupported. 
" + "Expected strides=[%lld,%lld,1] for shape=[%lld,%lld,%lld], got strides=[%d,%d,%d].", + (long long)(M * N), (long long)N, (long long)B, (long long)M, (long long)N, + out_strides[0], out_strides[1], out_strides[2]); + return Error::InvalidArgument; + } + + // Get Metal stream and device + ETMetalStream* stream = getCurrentMetalStream(); + if (!stream) { + ET_LOG(Error, "aoti_torch_mps_bmm_out: Failed to get current Metal stream"); + return Error::Internal; + } + + id device = get_metal_device(); + if (!device) { + ET_LOG(Error, "aoti_torch_mps_bmm_out: Failed to get Metal device"); + return Error::Internal; + } + (void)device; // Used for validation, consistent with other ops + + // Get Metal buffers for input and output tensors + id self_buffer = get_mtl_buffer(self_tensor, "aoti_torch_mps_bmm_out", "self"); + id mat2_buffer = get_mtl_buffer(mat2_tensor, "aoti_torch_mps_bmm_out", "mat2"); + id out_buffer = get_mtl_buffer(out_tensor, "aoti_torch_mps_bmm_out", "out"); + + // Validate buffers are non-null + if (!self_buffer || !mat2_buffer || !out_buffer) { + ET_LOG(Error, "aoti_torch_mps_bmm_out: Failed to get Metal buffers. " + "self_buffer=%p, mat2_buffer=%p, out_buffer=%p", + self_buffer, mat2_buffer, out_buffer); + return Error::Internal; + } + + // End any existing kernel coalescing to ensure clean state + // (consistent with mm_out and conv pattern) + stream->endKernelCoalescing(); + + // Map dtype to MPS type and validate support + // Note: Only FLOAT32 and BFLOAT16 are supported in Metal backend (see utils.h) + // FLOAT16 is not in SupportedDTypes enum and is not supported + MPSDataType mps_dtype; + + if (dtype == static_cast(SupportedDTypes::FLOAT32)) { + mps_dtype = MPSDataTypeFloat32; + } else if (dtype == static_cast(SupportedDTypes::BFLOAT16)) { + mps_dtype = MPSDataTypeBFloat16; + } else { + ET_LOG(Error, "aoti_torch_mps_bmm_out: Unsupported data type: %d. 
" + "Supported types: FLOAT32 (%d), BFLOAT16 (%d)", + dtype, + static_cast(SupportedDTypes::FLOAT32), + static_cast(SupportedDTypes::BFLOAT16)); + return Error::InvalidArgument; + } + + // Define shapes for graph placeholders and tensor data + NSArray* selfShape = @[@(B), @(M), @(K)]; + NSArray* mat2Shape = @[@(B), @(K), @(N)]; + NSArray* outShape = @[@(B), @(M), @(N)]; + + // Create cache key for this batched matrix multiplication + // Cache key includes: op_name, shape params {B, M, K, N}, dtype, transpose_flag + // This allows reuse when same BMM shape/dtype is called repeatedly + GraphCacheKey cache_key; + cache_key.op_name = "bmm"; + cache_key.shape_params = {B, M, K, N}; + cache_key.dtype = dtype; + cache_key.transpose_flag = false; // BMM has no transpose handling + + // Check if we have a cached graph + MPSGraph* mpsGraph = nullptr; + MPSGraphTensor* outputTensor = nil; + MPSGraphTensor* selfPlaceholder = nil; + MPSGraphTensor* mat2Placeholder = nil; + + auto cache_it = graph_cache.find(cache_key); + if (cache_it != graph_cache.end()) { + // Cache hit - reuse compiled graph and tensor references + CachedGraph& cached = cache_it->second; + mpsGraph = cached.graph; + selfPlaceholder = cached.input1; + mat2Placeholder = cached.input2; + outputTensor = cached.output; + + cache_stats.hits++; + cache_stats.logStats(); + + } else { + // Cache miss - create and compile new graph + mpsGraph = [MPSGraph new]; + cache_stats.misses++; + cache_stats.logStats(); + + // Create 3D placeholders for batched matrices + // These represent the logical shapes for the batched matrix multiplication + selfPlaceholder = [mpsGraph placeholderWithShape:selfShape + dataType:mps_dtype + name:@"self"]; + mat2Placeholder = [mpsGraph placeholderWithShape:mat2Shape + dataType:mps_dtype + name:@"mat2"]; + + // MPSGraph matrixMultiplication handles batched case natively when given 3D tensors + // For 3D inputs [B,M,K] @ [B,K,N] -> [B,M,N] + outputTensor = [mpsGraph matrixMultiplicationWithPrimaryTensor:selfPlaceholder + secondaryTensor:mat2Placeholder + name:@"bmm_result"]; + + // Cache the compiled graph and tensor references for reuse + CachedGraph cached_graph; + cached_graph.graph = mpsGraph; + cached_graph.input1 = selfPlaceholder; + cached_graph.input2 = mat2Placeholder; + cached_graph.input3 = nil; // No third input for BMM + cached_graph.output = outputTensor; + graph_cache[cache_key] = cached_graph; + + } // End of cache miss/hit block + + // Create feeds dictionary for graph execution + NSMutableDictionary* feeds = [NSMutableDictionary dictionary]; + + // Create MPSGraphTensorData objects for input tensors + // These wrap the MTLBuffers with the shape information + // Initialize to nil for safe cleanup in exception path + MPSGraphTensorData* selfData = nil; + MPSGraphTensorData* mat2Data = nil; + MPSGraphTensorData* outputData = nil; + + selfData = [[MPSGraphTensorData alloc] initWithMTLBuffer:self_buffer + shape:selfShape + dataType:mps_dtype]; + mat2Data = [[MPSGraphTensorData alloc] initWithMTLBuffer:mat2_buffer + shape:mat2Shape + dataType:mps_dtype]; + + feeds[selfPlaceholder] = selfData; + feeds[mat2Placeholder] = mat2Data; + + // Create output tensor data + outputData = [[MPSGraphTensorData alloc] initWithMTLBuffer:out_buffer + shape:outShape + dataType:mps_dtype]; + + // Build results dictionary + NSDictionary* results = @{ + outputTensor: outputData + }; + + // Execute the batched matrix multiplication + @try { + stream->executeMPSGraph(mpsGraph, feeds, results, SyncType::COMMIT); + } @catch 
(NSException *exception) { + ET_LOG(Error, "aoti_torch_mps_bmm_out: NSException caught during executeMPSGraph: %s - %s", + [[exception name] UTF8String], [[exception reason] UTF8String]); + // Guard releases against nil + if (selfData) [selfData release]; + if (mat2Data) [mat2Data release]; + if (outputData) [outputData release]; + return Error::Internal; + } + + // Release MPSGraphTensorData objects + [selfData release]; + [mat2Data release]; + [outputData release]; + + return Error::Ok; + + } catch (const std::exception& e) { + ET_LOG(Error, "aoti_torch_mps_bmm_out exception: %s", e.what()); + return Error::Internal; + } catch (...) { + ET_LOG(Error, "aoti_torch_mps_bmm_out: unknown exception"); + return Error::Internal; + } + } +} + +AOTITorchError aoti_torch_mps_addmm_out( + AOTITensorHandle out, + AOTITensorHandle input, + AOTITensorHandle mat1, + AOTITensorHandle mat2, + int64_t beta, + int64_t alpha) { + + if (!out || !input || !mat1 || !mat2) { + ET_LOG(Error, "aoti_torch_mps_addmm_out: null tensor handles"); + return Error::InvalidArgument; + } + + @autoreleasepool { + try { + // Convert AOTITensorHandle to ExecutorTorch tensors + auto out_tensor = reinterpret_cast(out); + auto input_tensor = reinterpret_cast(input); + auto mat1_tensor = reinterpret_cast(mat1); + auto mat2_tensor = reinterpret_cast(mat2); + + // Validate tensor dimensions + if (mat1_tensor->dim() != 2 || mat2_tensor->dim() != 2) { + ET_LOG(Error, "aoti_torch_mps_addmm_out: mat1 and mat2 must be 2-D, got %d and %d", + (int)mat1_tensor->dim(), (int)mat2_tensor->dim()); + return Error::InvalidArgument; + } + + int64_t M = mat1_tensor->sizes()[0]; + int64_t K = mat1_tensor->sizes()[1]; + int64_t N = mat2_tensor->sizes()[1]; + + if (mat1_tensor->sizes()[1] != mat2_tensor->sizes()[0]) { + ET_LOG(Error, "aoti_torch_mps_addmm_out: incompatible matrix sizes"); + return Error::InvalidArgument; + } + + // Get Metal stream + ETMetalStream* stream = getCurrentMetalStream(); + if (!stream) { + ET_LOG(Error, "aoti_torch_mps_addmm_out: Failed to get current Metal stream"); + return Error::Internal; + } + + id device = get_metal_device(); + if (!device) { + ET_LOG(Error, "aoti_torch_mps_addmm_out: Failed to get Metal device"); + return Error::Internal; + } + + // Get Metal buffers + id input_buffer = get_mtl_buffer(input_tensor, "aoti_torch_mps_addmm_out", "input"); + id mat1_buffer = get_mtl_buffer(mat1_tensor, "aoti_torch_mps_addmm_out", "mat1"); + id mat2_buffer = get_mtl_buffer(mat2_tensor, "aoti_torch_mps_addmm_out", "mat2"); + id out_buffer = get_mtl_buffer(out_tensor, "aoti_torch_mps_addmm_out", "out"); + + stream->endKernelCoalescing(); + + // Determine data type + int32_t dtype = static_cast(mat1_tensor->scalar_type()); + MPSDataType mps_dtype; + size_t element_size; + + if (dtype == static_cast(SupportedDTypes::FLOAT32)) { + mps_dtype = MPSDataTypeFloat32; + element_size = sizeof(float); + } else if (dtype == static_cast(SupportedDTypes::BFLOAT16)) { + mps_dtype = MPSDataTypeBFloat16; + element_size = sizeof(uint16_t); + } else { + ET_LOG(Error, "aoti_torch_mps_addmm_out: Unsupported data type: %d", dtype); + return Error::InvalidArgument; + } + + // Create MPSGraph for addmm: out = beta * input + alpha * (mat1 @ mat2) + MPSGraph* mpsGraph = [MPSGraph new]; + + NSArray* mat1Shape = @[@(M), @(K)]; + NSArray* mat2Shape = @[@(K), @(N)]; + NSArray* outShape = @[@(M), @(N)]; + + // Handle input shape - it could be 1D (bias) or 2D + NSArray* inputShape; + if (input_tensor->dim() == 1) { + inputShape = 
@[@(input_tensor->sizes()[0])]; + } else { + inputShape = @[@(input_tensor->sizes()[0]), @(input_tensor->sizes()[1])]; + } + + MPSGraphTensor* mat1Placeholder = [mpsGraph placeholderWithShape:mat1Shape + dataType:mps_dtype + name:@"mat1"]; + MPSGraphTensor* mat2Placeholder = [mpsGraph placeholderWithShape:mat2Shape + dataType:mps_dtype + name:@"mat2"]; + MPSGraphTensor* inputPlaceholder = [mpsGraph placeholderWithShape:inputShape + dataType:mps_dtype + name:@"input"]; + + // Compute mat1 @ mat2 + MPSGraphTensor* mmResult = [mpsGraph matrixMultiplicationWithPrimaryTensor:mat1Placeholder + secondaryTensor:mat2Placeholder + name:@"mm"]; + + // Scale mm result by alpha + MPSGraphTensor* alphaConst = [mpsGraph constantWithScalar:(double)alpha + dataType:mps_dtype]; + MPSGraphTensor* scaledMM = [mpsGraph multiplicationWithPrimaryTensor:mmResult + secondaryTensor:alphaConst + name:@"scaled_mm"]; + + // Scale input by beta + MPSGraphTensor* betaConst = [mpsGraph constantWithScalar:(double)beta + dataType:mps_dtype]; + MPSGraphTensor* scaledInput = [mpsGraph multiplicationWithPrimaryTensor:inputPlaceholder + secondaryTensor:betaConst + name:@"scaled_input"]; + + // Add: out = beta * input + alpha * (mat1 @ mat2) + MPSGraphTensor* addResult = [mpsGraph additionWithPrimaryTensor:scaledInput + secondaryTensor:scaledMM + name:@"addmm_result"]; + + // Create tensor data + MPSGraphTensorData* mat1Data = [[MPSGraphTensorData alloc] initWithMTLBuffer:mat1_buffer + shape:mat1Shape + dataType:mps_dtype]; + MPSGraphTensorData* mat2Data = [[MPSGraphTensorData alloc] initWithMTLBuffer:mat2_buffer + shape:mat2Shape + dataType:mps_dtype]; + MPSGraphTensorData* inputData = [[MPSGraphTensorData alloc] initWithMTLBuffer:input_buffer + shape:inputShape + dataType:mps_dtype]; + MPSGraphTensorData* outputData = [[MPSGraphTensorData alloc] initWithMTLBuffer:out_buffer + shape:outShape + dataType:mps_dtype]; + + NSMutableDictionary* feeds = [NSMutableDictionary dictionary]; + feeds[mat1Placeholder] = mat1Data; + feeds[mat2Placeholder] = mat2Data; + feeds[inputPlaceholder] = inputData; + + NSDictionary* results = @{addResult: outputData}; + + @try { + stream->executeMPSGraph(mpsGraph, feeds, results, SyncType::COMMIT); + } @catch (NSException *exception) { + ET_LOG(Error, "aoti_torch_mps_addmm_out: NSException: %s - %s", + [[exception name] UTF8String], [[exception reason] UTF8String]); + [mat1Data release]; + [mat2Data release]; + [inputData release]; + [outputData release]; + return Error::Internal; + } + + [mat1Data release]; + [mat2Data release]; + [inputData release]; + [outputData release]; + + return Error::Ok; + + } catch (const std::exception& e) { + ET_LOG(Error, "aoti_torch_mps_addmm_out exception: %s", e.what()); + return Error::Internal; + } catch (...) 
{ + ET_LOG(Error, "aoti_torch_mps_addmm_out: unknown exception"); + return Error::Internal; + } + } +} + AOTITorchError aoti_torch_mps_convolution( AOTITensorHandle input, AOTITensorHandle weight, @@ -827,10 +1314,13 @@ AOTITorchError aoti_torch_mps_convolution( } ET_LOG(Debug, "aoti_torch_mps_convolution: mps_dtype=%d, element_size=%zu", mps_dtype, element_size); + // Get weight's input channel dimension from the weight tensor (not from input) + // For grouped convolutions, weight shape is [C_out, C_in/groups, kH, kW] + int64_t weight_C_in = weight_tensor->sizes()[1]; // This handles grouped convs correctly // Define tensor shapes for placeholders (needed for both cache hit and miss) NSArray* inputShape = @[@(N), @(C_in), @(H_in), @(W_in)]; - NSArray* weightShape = @[@(C_out), @(C_in), @(kernel_h), @(kernel_w)]; + NSArray* weightShape = @[@(C_out), @(weight_C_in), @(kernel_h), @(kernel_w)]; // Create cache key for this convolution GraphCacheKey cache_key; diff --git a/examples/models/parakeet/CMakePresets.json b/examples/models/parakeet/CMakePresets.json index ea93d257ba7..a915bc643cf 100644 --- a/examples/models/parakeet/CMakePresets.json +++ b/examples/models/parakeet/CMakePresets.json @@ -34,7 +34,10 @@ "displayName": "Parakeet runner (Metal)", "inherits": ["parakeet-base"], "cacheVariables": { - "EXECUTORCH_BUILD_METAL": "ON" + "CMAKE_BUILD_TYPE": "Debug", + "EXECUTORCH_BUILD_METAL": "ON", + "EXECUTORCH_ENABLE_LOGGING": "ON", + "ET_MIN_LOG_LEVEL": "Info" }, "condition": { "lhs": "${hostSystemName}", diff --git a/examples/models/parakeet/README.md b/examples/models/parakeet/README.md index b27bc1f8a91..7a611c90e82 100644 --- a/examples/models/parakeet/README.md +++ b/examples/models/parakeet/README.md @@ -25,26 +25,47 @@ python export_parakeet_tdt.py --audio /path/to/audio.wav | Argument | Description | |----------|-------------| | `--output-dir` | Output directory for exports (default: `./parakeet_tdt_exports`) | -| `--backend` | Backend for acceleration: `portable`, `xnnpack`, `cuda`, `cuda-windows` (default: `portable`) | +| `--backend` | Backend for acceleration: `portable`, `xnnpack`, `metal`, `cuda`, `cuda-windows` (default: `portable`) | +| `--dtype` | Data type: `fp32`, `bf16` (default: `fp32`). Metal backend supports fp32 and bf16 only. | | `--audio` | Path to audio file for transcription test | **Note:** The preprocessor is always lowered with the portable backend regardless of the `--backend` setting. 
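A bf16 Metal export, for example, combines the two flags above (a sketch; the output directory and audio path are placeholders):

```bash
python export_parakeet_tdt.py --backend metal --dtype bf16 \
    --output-dir ./parakeet_metal_bf16 \
    --audio /path/to/audio.wav
```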
+### Metal Export (macOS) + +```bash +python export_parakeet_tdt.py --backend metal --output-dir ./parakeet_metal +``` + +This generates: +- `parakeet_tdt.pte` - The compiled model +- `aoti_metal_blob.ptd` - Metal kernel blob required at runtime +- `tokenizer.model` - SentencePiece tokenizer + ## C++ Runner ### Building -First, build ExecuTorch with the LLM preset from the executorch root directory: +First, build ExecuTorch with the appropriate preset from the executorch root directory: ```bash +# For CPU/XNNPACK cmake --workflow --preset llm-release + +# For Metal (macOS) +cmake --workflow --preset llm-debug-metal ``` Then build the parakeet runner: ```bash cd examples/models/parakeet + +# CPU/XNNPACK build cmake --workflow --preset parakeet-cpu + +# Metal build +cmake --workflow --preset parakeet-metal ``` Available presets: @@ -57,10 +78,18 @@ Available presets: From the executorch root directory: ```bash +# CPU/XNNPACK ./cmake-out/examples/models/parakeet/parakeet_runner \ --model_path examples/models/parakeet/parakeet_tdt_exports/parakeet_tdt.pte \ --audio_path /path/to/audio.wav \ --tokenizer_path examples/models/parakeet/parakeet_tdt_exports/tokenizer.model + +# Metal (include .ptd data file) +DYLD_LIBRARY_PATH=/usr/lib ./cmake-out/examples/models/parakeet/parakeet_runner \ + --model_path examples/models/parakeet/parakeet_metal/parakeet_tdt.pte \ + --data_path examples/models/parakeet/parakeet_metal/aoti_metal_blob.ptd \ + --audio_path /path/to/audio.wav \ + --tokenizer_path examples/models/parakeet/parakeet_metal/tokenizer.model ``` ### Runner Arguments @@ -70,4 +99,4 @@ From the executorch root directory: | `--model_path` | Path to Parakeet model (.pte) | | `--audio_path` | Path to input audio file (.wav) | | `--tokenizer_path` | Path to tokenizer file (default: `tokenizer.json`) | -| `--data_path` | Path to data file (.ptd) for delegate data (optional, required for CUDA) | +| `--data_path` | Path to data file (.ptd) for delegate data (required for Metal/CUDA) | diff --git a/examples/models/parakeet/export_parakeet_tdt.py b/examples/models/parakeet/export_parakeet_tdt.py index 92e32ca30bf..2511869398e 100644 --- a/examples/models/parakeet/export_parakeet_tdt.py +++ b/examples/models/parakeet/export_parakeet_tdt.py @@ -7,6 +7,7 @@ import tempfile import torch + import torchaudio from executorch.exir import ( EdgeCompileConfig, @@ -363,48 +364,78 @@ def export_all(model): return programs, metadata -def lower_to_executorch(programs, metadata=None, backend="portable"): - partitioner = {} +def _create_xnnpack_partitioners(programs): + """Create XNNPACK partitioners for all programs except preprocessor.""" + from executorch.backends.xnnpack.partition.xnnpack_partitioner import ( + XnnpackPartitioner, + ) - if backend == "xnnpack": - from executorch.backends.xnnpack.partition.xnnpack_partitioner import ( - XnnpackPartitioner, - ) + print("\nLowering to ExecuTorch with XNNPACK...") + partitioner = {} + for key in programs.keys(): + if key == "preprocessor": + partitioner[key] = [] + else: + partitioner[key] = [XnnpackPartitioner()] + return partitioner, programs - print("\nLowering to ExecuTorch with XNNPACK...") - for key in programs.keys(): - if key == "preprocessor": - partitioner[key] = [] - else: - partitioner[key] = [XnnpackPartitioner()] - elif backend in ("cuda", "cuda-windows"): - from executorch.backends.cuda.cuda_backend import CudaBackend - from executorch.backends.cuda.cuda_partitioner import CudaPartitioner - from executorch.exir.backend.compile_spec_schema import 
CompileSpec - from torch._inductor.decomposition import conv1d_to_conv2d +def _create_metal_partitioners(programs): + """Create Metal partitioners for all programs except preprocessor.""" + from executorch.backends.apple.metal.metal_backend import MetalBackend + from executorch.backends.apple.metal.metal_partitioner import MetalPartitioner - print( - f"\nLowering to ExecuTorch with CUDA{' (Windows)' if backend == 'cuda-windows' else ''}..." - ) + print("\nLowering to ExecuTorch with Metal...") + partitioner = {} + for key in programs.keys(): + if key == "preprocessor": + partitioner[key] = [] + else: + compile_specs = [MetalBackend.generate_method_name_compile_spec(key)] + partitioner[key] = [MetalPartitioner(compile_specs)] + return partitioner, programs + + +def _create_cuda_partitioners(programs, is_windows=False): + """Create CUDA partitioners for all programs except preprocessor.""" + from executorch.backends.cuda.cuda_backend import CudaBackend + from executorch.backends.cuda.cuda_partitioner import CudaPartitioner + from executorch.exir.backend.compile_spec_schema import CompileSpec + from torch._inductor.decomposition import conv1d_to_conv2d + + print(f"\nLowering to ExecuTorch with CUDA{' (Windows)' if is_windows else ''}...") + + # Run decompositions for non-preprocessor programs + updated_programs = {} + for key, ep in programs.items(): + if key != "preprocessor": + updated_programs[key] = ep.run_decompositions( + {torch.ops.aten.conv1d.default: conv1d_to_conv2d} + ) + else: + updated_programs[key] = ep - for key, ep in programs.items(): - if key != "preprocessor": - programs[key] = ep.run_decompositions( - {torch.ops.aten.conv1d.default: conv1d_to_conv2d} - ) + partitioner = {} + for key in updated_programs.keys(): + if key == "preprocessor": + partitioner[key] = [] + else: + compile_specs = [CudaBackend.generate_method_name_compile_spec(key)] + if is_windows: + compile_specs.append(CompileSpec("platform", "windows".encode("utf-8"))) + partitioner[key] = [CudaPartitioner(compile_specs)] + return partitioner, updated_programs - for key in programs.keys(): - if key == "preprocessor": - partitioner[key] = [] - else: - compile_specs = [CudaBackend.generate_method_name_compile_spec(key)] - if backend == "cuda-windows": - compile_specs.append( - CompileSpec("platform", "windows".encode("utf-8")) - ) - partitioner[key] = [CudaPartitioner(compile_specs)] +def lower_to_executorch(programs, metadata=None, backend="portable"): + if backend == "xnnpack": + partitioner, programs = _create_xnnpack_partitioners(programs) + elif backend == "metal": + partitioner, programs = _create_metal_partitioners(programs) + elif backend in ("cuda", "cuda-windows"): + partitioner, programs = _create_cuda_partitioners( + programs, is_windows=(backend == "cuda-windows") + ) else: print("\nLowering to ExecuTorch...") partitioner = [] @@ -442,11 +473,22 @@ def main(): "--backend", type=str, default="portable", - choices=["portable", "xnnpack", "cuda", "cuda-windows"], + choices=["portable", "xnnpack", "metal", "cuda", "cuda-windows"], help="Backend for acceleration (default: portable)", ) + parser.add_argument( + "--dtype", + type=str, + default="fp32", + choices=["fp32", "fp16", "bf16"], + help="Model dtype for Metal/CUDA backends (default: fp32)", + ) args = parser.parse_args() + # Validate dtype for Metal backend + if args.backend == "metal" and args.dtype == "fp16": + parser.error("Metal backend only supports fp32 and bf16, not fp16") + os.makedirs(args.output_dir, exist_ok=True) print("Extracting 
tokenizer...") @@ -455,6 +497,14 @@ def main(): print("Loading model...") model = load_model() + # Convert model to specified dtype for Metal/CUDA backends + if args.dtype == "bf16": + print("Converting model to bfloat16...") + model = model.to(torch.bfloat16) + elif args.dtype == "fp16": + print("Converting model to float16...") + model = model.to(torch.float16) + print("\nExporting components...") programs, metadata = export_all(model)