diff --git a/CMakeLists.txt b/CMakeLists.txt index 10e2eb437e3..f5091a2af2e 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -605,15 +605,23 @@ if(EXECUTORCH_BUILD_CORTEX_M) list(APPEND _executorch_backends coretex_m_backend) endif() -if(EXECUTORCH_BUILD_CUDA) - # Build common AOTI functionality (required for CUDA) +# Build common AOTI functionality if needed by CUDA or Metal backends +if(EXECUTORCH_BUILD_CUDA OR EXECUTORCH_BUILD_METAL) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backends/aoti) +endif() + +if(EXECUTORCH_BUILD_CUDA) # Build CUDA-specific AOTI functionality add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backends/cuda) # Add aoti_cuda to backends - it already depends on aoti_common list(APPEND _executorch_backends aoti_cuda) endif() +if(EXECUTORCH_BUILD_METAL) + add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backends/apple/metal) + list(APPEND _executorch_backends metal_backend) +endif() + if(EXECUTORCH_BUILD_EXTENSION_APPLE) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/extension/apple) endif() diff --git a/backends/aoti/CMakeLists.txt b/backends/aoti/CMakeLists.txt index 2c836101c5e..c8a4d30c2bf 100644 --- a/backends/aoti/CMakeLists.txt +++ b/backends/aoti/CMakeLists.txt @@ -42,9 +42,13 @@ target_compile_options( $<$>:-fexceptions -frtti -fPIC> ) # Ensure symbols are exported properly -target_link_options( - aoti_common PUBLIC $<$>:-Wl,--export-dynamic> -) +if(APPLE) + target_link_options(aoti_common PUBLIC -Wl,-export_dynamic) +else() + target_link_options( + aoti_common PUBLIC $<$>:-Wl,--export-dynamic> + ) +endif() # Link against ExecuTorch libraries and standard libraries target_link_libraries(aoti_common PUBLIC extension_tensor ${CMAKE_DL_LIBS}) diff --git a/backends/apple/metal/CMakeLists.txt b/backends/apple/metal/CMakeLists.txt new file mode 100644 index 00000000000..7bdf142041d --- /dev/null +++ b/backends/apple/metal/CMakeLists.txt @@ -0,0 +1,120 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +# +# Build AOTI Metal backend for runtime. +# +# ### Editing this file ### +# +# This file should be formatted with +# ~~~ +# cmake-format -i CMakeLists.txt +# ~~~ +# It should also be cmake-lint clean. +# +cmake_minimum_required(VERSION 3.29) + +set(CMAKE_CXX_STANDARD 17) +set(CMAKE_CXX_STANDARD_REQUIRED ON) + +set(CMAKE_EXPORT_COMPILE_COMMANDS ON) + +if(NOT APPLE) + message(FATAL_ERROR "Metal backend requires macOS") +endif() + +# Source root directory for executorch. +if(NOT EXECUTORCH_ROOT) + set(EXECUTORCH_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/../..) 
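+  # The fallback above is only used when this directory is configured standalone;
+  # a top-level build is expected to define EXECUTORCH_ROOT before adding this
+  # subdirectory, in which case the if(NOT EXECUTORCH_ROOT) guard skips it.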
+endif() + +include(${EXECUTORCH_ROOT}/tools/cmake/Utils.cmake) +# Use full torch package to get library paths, but only link specific libraries +find_package_torch() + +set(_aoti_metal_sources + runtime/metal_backend.cpp + runtime/shims/memory.cpp + runtime/shims/et_metal.mm + runtime/shims/et_metal_ops.mm + runtime/shims/shim_mps.mm + runtime/shims/tensor_attribute.cpp + runtime/shims/utils.cpp +) + +add_library(metal_backend STATIC ${_aoti_metal_sources}) +target_include_directories( + metal_backend + PUBLIC $ $ + # PyTorch AOTI headers from ExecutorTorch's torch detection + ${TORCH_INCLUDE_DIRS} +) + +# Link Metal framework +find_library(METAL_LIBRARY Metal REQUIRED) +find_library(FOUNDATION_LIBRARY Foundation REQUIRED) +find_library(METALPERFORMANCESHADERS_LIBRARY MetalPerformanceShaders REQUIRED) +find_library( + METALPERFORMANCESHADERSGRAPH_LIBRARY MetalPerformanceShadersGraph REQUIRED +) +target_link_libraries( + metal_backend + PUBLIC ${METAL_LIBRARY} ${FOUNDATION_LIBRARY} + ${METALPERFORMANCESHADERS_LIBRARY} + ${METALPERFORMANCESHADERSGRAPH_LIBRARY} +) + +target_compile_options(metal_backend PUBLIC -fexceptions -frtti -fPIC) + +target_link_options(metal_backend PUBLIC -Wl,-export_dynamic) + +# Find PyTorch's OpenMP library specifically for libtorch-less AOTI +get_torch_base_path(TORCH_BASE_PATH) +find_library( + TORCH_OMP_LIBRARY + NAMES omp libomp + PATHS "${TORCH_BASE_PATH}/lib" + NO_DEFAULT_PATH +) + +if(TORCH_OMP_LIBRARY) + message(STATUS "Found PyTorch OpenMP library: ${TORCH_OMP_LIBRARY}") + # Get the directory containing the OpenMP library for rpath + get_filename_component(TORCH_OMP_LIB_DIR ${TORCH_OMP_LIBRARY} DIRECTORY) + message(STATUS "OpenMP library directory: ${TORCH_OMP_LIB_DIR}") +else() + message( + WARNING "PyTorch OpenMP library not found, may cause runtime linking issues" + ) +endif() + +# Link against appropriate backends and standard libraries +target_link_libraries( + metal_backend PUBLIC aoti_common extension_tensor ${CMAKE_DL_LIBS} + ${TORCH_OMP_LIBRARY} +) + +# Set rpath for OpenMP library to avoid runtime linking issues +if(TORCH_OMP_LIBRARY AND TORCH_OMP_LIB_DIR) + # Add the OpenMP library directory to the rpath + set_target_properties( + metal_backend PROPERTIES BUILD_RPATH "${TORCH_OMP_LIB_DIR}" + INSTALL_RPATH "${TORCH_OMP_LIB_DIR}" + ) + # Also try common OpenMP library locations + target_link_options( + metal_backend PUBLIC -Wl,-rpath,${TORCH_OMP_LIB_DIR} + -Wl,-rpath,/usr/local/opt/libomp/lib + -Wl,-rpath,/opt/homebrew/opt/libomp/lib + ) + message(STATUS "Added rpath for OpenMP library: ${TORCH_OMP_LIB_DIR}") +endif() + +executorch_target_link_options_shared_lib(metal_backend) +install( + TARGETS metal_backend + EXPORT ExecuTorchTargets + DESTINATION lib +) diff --git a/backends/apple/metal/runtime/metal_backend.cpp b/backends/apple/metal/runtime/metal_backend.cpp new file mode 100644 index 00000000000..97b273d428f --- /dev/null +++ b/backends/apple/metal/runtime/metal_backend.cpp @@ -0,0 +1,546 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +// Include AOTI common headers (from aoti_common library) +#include +#include + +// Include our Metal-specific shim layer headers +#include +#include +#include +#include +#include + +namespace executorch::backends::metal { + +#define LOAD_SYMBOL(handle, member, name, so_handle) \ + do { \ + handle->member = reinterpret_cast(dlsym(so_handle, #name)); \ + ET_CHECK_OR_RETURN_ERROR( \ + handle->member != nullptr, AccessFailed, "Failed to load " #name); \ + } while (0) + +using namespace std; +using namespace aoti; + +using executorch::aten::ScalarType; +using executorch::runtime::ArrayRef; +using executorch::runtime::Backend; +using executorch::runtime::BackendExecutionContext; +using executorch::runtime::BackendInitContext; +using executorch::runtime::CompileSpec; +using executorch::runtime::DelegateHandle; +using executorch::runtime::Error; +using executorch::runtime::EValue; +using executorch::runtime::FreeableBuffer; +using executorch::runtime::MemoryAllocator; +using executorch::runtime::NamedDataMap; +using executorch::runtime::Result; +using executorch::runtime::Span; +using executorch::runtime::etensor::Tensor; + +class ET_EXPERIMENTAL MetalBackend final + : public ::executorch::runtime::BackendInterface { + private: + Error load_function_pointers_into_handle( + void* so_handle, + AOTIDelegateHandle* handle) const { + ET_LOG( + Debug, + "MetalBackend::load_function_pointers_into_handle - Loading symbols"); + + LOAD_SYMBOL( + handle, + create_with_device, + AOTInductorModelContainerCreateWithDevice, + so_handle); + ET_LOG( + Debug, + "MetalBackend::load_function_pointers_into_handle - Loaded AOTInductorModelContainerCreateWithDevice"); + + LOAD_SYMBOL( + handle, delete_container, AOTInductorModelContainerDelete, so_handle); + ET_LOG( + Debug, + "MetalBackend::load_function_pointers_into_handle - Loaded AOTInductorModelContainerDelete"); + + LOAD_SYMBOL( + handle, + get_num_inputs, + AOTInductorModelContainerGetNumInputs, + so_handle); + ET_LOG( + Debug, + "MetalBackend::load_function_pointers_into_handle - Loaded AOTInductorModelContainerGetNumInputs"); + + LOAD_SYMBOL( + handle, + get_num_outputs, + AOTInductorModelContainerGetNumOutputs, + so_handle); + ET_LOG( + Debug, + "MetalBackend::load_function_pointers_into_handle - Loaded AOTInductorModelContainerGetNumOutputs"); + + LOAD_SYMBOL(handle, run, AOTInductorModelContainerRun, so_handle); + ET_LOG( + Debug, + "MetalBackend::load_function_pointers_into_handle - Loaded AOTInductorModelContainerRun"); + + ET_LOG( + Debug, + "MetalBackend::load_function_pointers_into_handle - All symbols loaded successfully"); + return Error::Ok; + } + + public: + // Once in program + MetalBackend() { + ET_LOG(Debug, "MetalBackend ctor"); + } + + bool is_available() const override { + return 1; + } + + // Once per loaded binary blob + Result init( + BackendInitContext& context, + FreeableBuffer* processed, // This will be a empty buffer + ArrayRef compile_specs // This will be my empty list + ) const override { + ET_LOG(Info, "MetalBackend::init - Starting initialization"); + + std::string method_name; + for (const CompileSpec& spec : compile_specs) { + if (std::strcmp(spec.key, "method_name") == 0) { + method_name.assign( + static_cast(spec.value.buffer), + spec.value.nbytes); // no nullptr guarantee, so pass size + break; + } + } + + std::string so_blob_key = + method_name.empty() ? 
"so_blob" : method_name + "_so_blob"; + ET_LOG(Info, "MetalBackend::init - so_blob_key: %s", so_blob_key.c_str()); + + const NamedDataMap* named_data_map = context.get_named_data_map(); + ET_LOG(Info, "MetalBackend::init - Got named data map: %p", named_data_map); + + ET_LOG( + Info, + "MetalBackend::init - Looking for blob key: %s", + so_blob_key.c_str()); + + auto aoti_metal_buffer = named_data_map->get_data(so_blob_key.c_str()); + ET_CHECK_OR_RETURN_ERROR( + aoti_metal_buffer.ok(), + Internal, + "Failed to get data for key %s: 0x%x", + so_blob_key.c_str(), + static_cast(aoti_metal_buffer.error())); + + ET_LOG( + Info, + "MetalBackend::init - Buffer is OK, size: %zu", + aoti_metal_buffer->size()); + + if (aoti_metal_buffer->data() == nullptr) { + ET_LOG(Error, "MetalBackend::init - Buffer data is null"); + return Error::InvalidArgument; + } + + ET_LOG( + Info, + "MetalBackend::init - Buffer data pointer: %p", + aoti_metal_buffer->data()); + + // Generate dynamic temporary file path + filesystem::path temp_dir = filesystem::temp_directory_path(); + filesystem::path so_path = + temp_dir / (so_blob_key + to_string(getpid()) + ".so"); + + // Create a temporary file + ET_LOG( + Info, "MetalBackend::init - Creating temp file: %s", so_path.c_str()); + ofstream outfile(so_path.c_str(), ios::binary); + + // Write the ELF buffer to the temporary file + ET_LOG( + Info, + "Writing %zu bytes to %s", + aoti_metal_buffer->size(), + so_path.c_str()); + + outfile.write( + static_cast(aoti_metal_buffer->data()), + aoti_metal_buffer->size()); + + ET_CHECK_OR_RETURN_ERROR( + outfile, AccessFailed, "Failed to write to file %s", so_path.c_str()); + + // Finish writing the file to disk + outfile.close(); + ET_LOG(Info, "MetalBackend::init - File closed successfully"); + + // Load the ELF using dlopen + void* so_handle = dlopen(so_path.c_str(), RTLD_LAZY | RTLD_LOCAL); + ET_CHECK_OR_RETURN_ERROR( + so_handle != nullptr, + AccessFailed, + "Failed to load shared library: %s", + dlerror()); + + processed->Free(); + + // Create handle and load function pointers into it + AOTIDelegateHandle* handle = new AOTIDelegateHandle(); + handle->so_handle = so_handle; + handle->so_path = so_path.string(); + + // Load function pointers specific to this handle's shared library + ET_CHECK_OK_OR_RETURN_ERROR( + load_function_pointers_into_handle(so_handle, handle)); + + AOTInductorModelContainerHandle container_handle = nullptr; + ET_LOG( + Info, + "MetalBackend::init - About to create AOTI container with device='mps'"); + + ET_CHECK_OK_OR_RETURN_ERROR( + handle->create_with_device(&container_handle, 1, "mps", nullptr)); + + ET_LOG(Info, "container_handle = %p", container_handle); + + handle->container_handle = container_handle; + + ET_LOG(Info, "MetalBackend::init - Initialization completed successfully"); + return (DelegateHandle*)handle; // Return the handle post-processing + } + + // Once per execution + Error execute( + BackendExecutionContext& context, + DelegateHandle* handle_, + Span args) const override { + ET_LOG(Debug, "MetalBackend execute"); + + AOTIDelegateHandle* handle = (AOTIDelegateHandle*)handle_; + + ET_LOG(Debug, "MetalBackend Handle generated"); + + size_t n_inputs; + handle->get_num_inputs(handle->container_handle, &n_inputs); + + size_t n_outputs; + handle->get_num_outputs(handle->container_handle, &n_outputs); + + ET_LOG(Debug, "MetalBackend n_outputs %zd generated", n_outputs); + + ET_CHECK_OR_RETURN_ERROR( + n_inputs + n_outputs == args.size(), + InvalidArgument, + "number of user input %zd and output 
%zd generated from AOT Inductor does not match ET runner's %zd. Exit.", + n_inputs, + n_outputs, + args.size()) + + ET_LOG( + Debug, + "number of user input %zd and output %zd generated from AOT Inductor matches ET runner's %zd.", + n_inputs, + n_outputs, + args.size()); + + int32_t mps_device_type = aoti_torch_device_type_mps(); // Returns 13 + + // NOTE: ExecutorTorch tensors are always on CPU/host memory + // We need to create GPU copies for Metal kernel execution + std::vector gpu_inputs( + n_inputs); // GPU copies for kernel execution + std::vector gpu_outputs( + n_outputs); // GPU tensors for kernel output + + ET_LOG(Debug, "MetalBackend input/output vectors generated"); + + // Process input tensors: ExecutorTorch provides CPU tensors, create GPU + // copies + for (int i = 0; i < n_inputs; i++) { + ET_LOG(Debug, "Processing input %d from args to inputs vector", i); + ET_LOG( + Debug, "is %d input a tensor input? %d", i, int(args[i]->isTensor())); + + // Get tensor dimensions and properties from ExecutorTorch CPU tensor + auto cpu_tensor = &(args[i]->toTensor()); + auto sizes = cpu_tensor->sizes(); + auto scalar_type = cpu_tensor->scalar_type(); + ET_LOG( + Debug, + "MetalBackend input %d scalar_type=%d", + i, + static_cast(scalar_type)); + + // Create GPU tensor with same shape + std::vector sizes_vec(sizes.begin(), sizes.end()); + + AOTITensorHandle gpu_input_handle; + Error create_err = aoti_torch_empty_strided( + sizes_vec.size(), + sizes_vec.data(), + nullptr, // use default strides + static_cast(scalar_type), + mps_device_type, // device_type = mps + 0, // device_index = 0 + &gpu_input_handle); + + if (create_err != Error::Ok) { + ET_LOG(Error, "Failed to create GPU tensor for input %d", i); + return Error::Internal; + } + + // Log the created GPU tensor scalar type + auto gpu_tensor = reinterpret_cast( + gpu_input_handle); + ET_LOG( + Debug, + "MetalBackend created GPU tensor %d scalar_type=%d", + i, + static_cast(gpu_tensor->scalar_type())); + + gpu_inputs[i] = gpu_input_handle; + + // Log the CPU tensor data before copying to GPU + void* cpu_data = cpu_tensor->mutable_data_ptr(); + if (cpu_data && cpu_tensor->numel() > 0) { + float* cpu_float_data = (float*)cpu_data; + ET_LOG( + Debug, + "CPU input %d data before copy: [%.3f, %.3f, %.3f, ...] 
(numel=%zd)", + i, + cpu_float_data[0], + cpu_float_data[1], + cpu_float_data[2], + cpu_tensor->numel()); + } + + // Copy data from CPU to GPU + Error copy_err = aoti_torch_copy_(gpu_inputs[i], cpu_tensor, 0); + if (copy_err != Error::Ok) { + ET_LOG(Error, "Failed to copy input %d from CPU to GPU", i); + return Error::Internal; + } + + // Log the GPU tensor scalar type after copy + auto gpu_tensor_after = + reinterpret_cast( + gpu_inputs[i]); + ET_LOG( + Debug, + "MetalBackend GPU tensor %d scalar_type after copy=%d", + i, + static_cast(gpu_tensor_after->scalar_type())); + + ET_LOG(Debug, "Successfully copied input %d from CPU to GPU", i); + } + + ET_LOG(Debug, "MetalBackend GPU inputs generated"); + + // Process output tensors: create GPU counterparts for ExecutorTorch CPU + // tensors + for (int i = 0; i < n_outputs; i++) { + // Get output tensor dimensions from ExecutorTorch CPU tensor + auto cpu_output_tensor = &(args[i + n_inputs]->toTensor()); + auto sizes = cpu_output_tensor->sizes(); + auto scalar_type = cpu_output_tensor->scalar_type(); + ET_LOG( + Debug, + "MetalBackend output %d scalar_type=%d", + i, + static_cast(scalar_type)); + + // Create GPU tensor with same shape for kernel output + std::vector sizes_vec(sizes.begin(), sizes.end()); + + AOTITensorHandle gpu_output_handle; + Error create_err = aoti_torch_empty_strided( + sizes_vec.size(), + sizes_vec.data(), + nullptr, // use default strides + static_cast(scalar_type), + mps_device_type, // device_type = mps + 0, // device_index = 0 + &gpu_output_handle); + + if (create_err != Error::Ok) { + ET_LOG(Error, "Failed to create GPU tensor for output %d", i); + return Error::Internal; + } + + gpu_outputs[i] = gpu_output_handle; + ET_LOG(Debug, "Created GPU output tensor %d", i); + } + + ET_LOG(Debug, "MetalBackend output generated"); + + // Log tensor handles before passing to AOTI container + ET_LOG(Debug, "Passing to AOTInductorModelContainerRun:"); + for (int i = 0; i < n_inputs; i++) { + void* gpu_input_data = gpu_inputs[i]->mutable_data_ptr(); + ET_LOG( + Debug, + " gpu_inputs[%d] = %p, data_ptr = %p", + i, + gpu_inputs[i], + gpu_input_data); + } + for (int i = 0; i < n_outputs; i++) { + void* gpu_output_data = gpu_outputs[i]->mutable_data_ptr(); + ET_LOG( + Debug, + " gpu_outputs[%d] = %p, data_ptr = %p", + i, + gpu_outputs[i], + gpu_output_data); + } + + // Run AOTI container with GPU tensors + AOTIRuntimeError error = handle->run( + handle->container_handle, + gpu_inputs.data(), // Use GPU input tensors + n_inputs, + gpu_outputs.data(), // Use GPU output tensors + n_outputs, + nullptr, // Pass the actual Metal stream! + nullptr); // proxy_executor_handle can remain nullptr + + if (error != Error::Ok) { + ET_LOG( + Error, + "AOTInductorModelContainerRun failed with error code %d", + error); + return Error::Internal; + } + + // Ensure all GPU work is completed before reading results + try { + synchronize_metal_stream(); + } catch (const std::exception& e) { + ET_LOG( + Error, + "Failed to synchronize Metal stream after kernel execution: %s", + e.what()); + return Error::Internal; + } catch (...) 
{ + ET_LOG( + Error, + "Failed to synchronize Metal stream after kernel execution: unknown exception"); + return Error::Internal; + } + + ET_LOG(Debug, "MetalBackend running done and synchronized"); + + // Copy GPU output results back to CPU output tensors + for (int i = 0; i < n_outputs; i++) { + auto cpu_output_tensor = &(args[i + n_inputs]->toTensor()); + // For DYNAMIC_BOUND tensors we try to resize + ET_CHECK_OK_OR_RETURN_ERROR( + resize_tensor(*cpu_output_tensor, gpu_outputs[i]->sizes()), + "Error resizing tensor at output index %d", + i); + ET_CHECK_OK_OR_RETURN_ERROR( + aoti_torch_copy_(cpu_output_tensor, gpu_outputs[i], 0), + "Failed to copy GPU output %d back to CPU", + i); + ET_LOG(Debug, "Copied GPU output %d back to CPU", i); + } + + // Clean up GPU tensors that we created (ExecutorTorch tensors are always + // CPU, so all GPU tensors are our copies) + for (int i = 0; i < n_inputs; i++) { + // All GPU input tensors were created by us, delete them + aoti_torch_delete_tensor_object(gpu_inputs[i]); + } + + for (int i = 0; i < n_outputs; i++) { + // All GPU output tensors were created by us, delete them + aoti_torch_delete_tensor_object(gpu_outputs[i]); + } + + ET_LOG(Debug, "MetalBackend execution completed successfully"); + + return Error::Ok; + } + + void destroy(DelegateHandle* handle_) const override { + if (handle_ == nullptr) { + return; + } + AOTIDelegateHandle* handle = (AOTIDelegateHandle*)handle_; + + // NOTE: AOTInductorModelContainerDelete does not work correctly with + // multiple .so files. Deleting one container frees shared resources, + // which causes segmentation faults when attempting to delete other + // containers. As a workaround, we skip explicit container deletion + // and defer cleanup to the OS. + // TODO: Find a proper solution for safe container deletion. + // AOTInductorModelContainerDelete(handle->container_handle); + + // Now close the shared library + if (handle->so_handle != nullptr) { + dlclose(handle->so_handle); + } + + // Remove the temporary shared library file + if (!handle->so_path.empty()) { + std::error_code remove_error; + std::filesystem::remove(handle->so_path, remove_error); + ET_CHECK_OR_LOG_ERROR( + !remove_error, + "Failed to remove temporary shared library %s: %s", + handle->so_path.c_str(), + remove_error.message().c_str()); + if (!remove_error) { + ET_LOG( + Info, + "Removed temporary shared library file: %s", + handle->so_path.c_str()); + } + } + + delete handle; + cleanup_memory(); + executorch::backends::aoti::cleanup_tensor_metadata(); + ET_LOG(Debug, "MetalBackend handle %p destroy", handle_); + } +}; + +} // namespace executorch::backends::metal + +namespace executorch::backends { +namespace { +auto cls = metal::MetalBackend(); +executorch::runtime::Backend backend{"MetalBackend", &cls}; +static executorch::runtime::Error success_with_compiler = + register_backend(backend); +} // namespace +} // namespace executorch::backends diff --git a/backends/apple/metal/runtime/shims/et_metal_ops.h b/backends/apple/metal/runtime/shims/et_metal_ops.h new file mode 100644 index 00000000000..78bdb419ea4 --- /dev/null +++ b/backends/apple/metal/runtime/shims/et_metal_ops.h @@ -0,0 +1,73 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#pragma once + +#include + +namespace executorch { +namespace backends { +namespace metal { + +#ifdef __cplusplus +extern "C" { +#endif + +/** + * ExecutorTorch implementation of aoti_torch_mps_mm_out. + * Performs simple matrix multiplication: out = self @ mat2 + */ +AOTITorchError aoti_torch_mps_mm_out( + AOTITensorHandle out, + AOTITensorHandle self, + AOTITensorHandle mat2); + +/** + * ExecutorTorch implementation of aoti_torch_mps_convolution. + * Performs 2D convolution operation - matches PyTorch AOTI signature + */ +AOTITorchError aoti_torch_mps_convolution( + AOTITensorHandle input, + AOTITensorHandle weight, + AOTITensorHandle* bias, + const int64_t* stride, + int64_t stride_len_, + const int64_t* padding, + int64_t padding_len_, + const int64_t* dilation, + int64_t dilation_len_, + int32_t transposed, + const int64_t* output_padding, + int64_t output_padding_len_, + int64_t groups, + AOTITensorHandle* ret0); + +/** + * ExecutorTorch implementation of + * aoti_torch_mps__scaled_dot_product_attention_math_for_mps. Performs scaled + * dot product attention calculation - matches PyTorch AOTI signature + */ +AOTITorchError aoti_torch_mps__scaled_dot_product_attention_math_for_mps( + AOTITensorHandle query, + AOTITensorHandle key, + AOTITensorHandle value, + AOTITensorHandle* attn_mask, + double dropout_p, + int32_t is_causal, + AOTITensorHandle* dropout_mask, + double* scale, + AOTITensorHandle* ret0, + AOTITensorHandle* ret1); + +#ifdef __cplusplus +} // extern "C" +#endif + +} // namespace metal +} // namespace backends +} // namespace executorch diff --git a/backends/apple/metal/runtime/shims/et_metal_ops.mm b/backends/apple/metal/runtime/shims/et_metal_ops.mm new file mode 100644 index 00000000000..0aa90650a1d --- /dev/null +++ b/backends/apple/metal/runtime/shims/et_metal_ops.mm @@ -0,0 +1,1358 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#import +#import +#import +#import +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace executorch { +namespace backends { +namespace metal { + +using executorch::runtime::etensor::Tensor; + +// Forward declaration of dispatch_sync_with_rethrow from et_metal.mm +void dispatch_sync_with_rethrow(dispatch_queue_t queue, void (^block)()); + +// Declare the global mapping from et_metal.mm +extern std::unordered_map> ptr_to_mtl_buffer; + +namespace { + +// Helper function to get Metal buffer from the global mapping +static id get_mtl_buffer(Tensor* tensor, const char* op_name, const char* tensor_name) { + void* data_ptr = tensor->mutable_data_ptr(); + auto it = ptr_to_mtl_buffer.find(data_ptr); + if (it == ptr_to_mtl_buffer.end()) { + ET_LOG(Error, "%s: %s tensor not found in Metal buffer mapping", op_name, tensor_name); + throw std::runtime_error(std::string(tensor_name) + " tensor not found in Metal buffer mapping"); + } + return it->second; +} + +// Helper function to allocate a Metal buffer and register it in the global mapping. 
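+// Allocation is delegated to aoti_torch_mps_malloc, which is assumed to register the
+// new pointer in ptr_to_mtl_buffer; if either the allocation or the follow-up map
+// lookup fails, a std::runtime_error is thrown for the caller's try/catch to handle.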
+static id allocate_mtl_buffer(void** data_ptr, size_t size_bytes) { + AOTITorchError malloc_err = aoti_torch_mps_malloc(data_ptr, size_bytes); + if (malloc_err != Error::Ok) { + ET_LOG(Error, "allocate_and_register_mtl_buffer: Failed to allocate Metal buffer via aoti_torch_mps_malloc"); + throw std::runtime_error("Failed to allocate output Metal buffer"); + } + + auto it = ptr_to_mtl_buffer.find(*data_ptr); + if (it == ptr_to_mtl_buffer.end()) { + ET_LOG(Error, "allocate_and_register_mtl_buffer: aoti_torch_mps_malloc did not register buffer in map"); + throw std::runtime_error("Failed to look up allocated Metal buffer"); + } + return it->second; +} + +} // namespace + +extern "C" { + +AOTITorchError aoti_torch_mps_mm_out( + AOTITensorHandle out, + AOTITensorHandle self, + AOTITensorHandle mat2) { + ET_LOG(Debug, "aoti_torch_mps_mm_out: Starting with out=%p, self=%p, mat2=%p", + out, self, mat2); + + if (!out || !self || !mat2) { + ET_LOG(Error, "aoti_torch_mps_mm_out: null tensor handles"); + return Error::InvalidArgument; + } + + @autoreleasepool { + try { + // Convert AOTITensorHandle to ExecutorTorch tensors + auto out_tensor = reinterpret_cast(out); + auto self_tensor = reinterpret_cast(self); + auto mat2_tensor = reinterpret_cast(mat2); + + ET_LOG(Debug, "aoti_torch_mps_mm_out: Converted tensor handles to ET tensors"); + + // Validate tensor dimensions + if (self_tensor->dim() != 2 || mat2_tensor->dim() != 2) { + std::string error_msg = "aoti_torch_mps_mm_out: tensors must be 2-D, got " + + std::to_string(self_tensor->dim()) + " and " + + std::to_string(mat2_tensor->dim()); + ET_LOG(Error, "%s", error_msg.c_str()); + throw std::runtime_error(error_msg); + } + + int64_t M = self_tensor->sizes()[0]; // rows of self + int64_t K = self_tensor->sizes()[1]; // cols of self / rows of mat2 + int64_t N = mat2_tensor->sizes()[1]; // cols of mat2 + + // Check matrix multiplication compatibility + if (self_tensor->sizes()[1] != mat2_tensor->sizes()[0]) { + std::string error_msg = "aoti_torch_mps_mm_out: incompatible matrix sizes for mm (" + + std::to_string(M) + "x" + std::to_string(K) + " and " + + std::to_string(mat2_tensor->sizes()[0]) + "x" + std::to_string(N) + ")"; + ET_LOG(Error, "%s", error_msg.c_str()); + throw std::runtime_error(error_msg); + } + + // Log tensor shapes for debugging + ET_LOG(Debug, "aoti_torch_mps_mm_out: self shape: [%d, %d], mat2 shape: [%d, %d], out shape: [%d, %d]", + (int)M, (int)K, (int)mat2_tensor->sizes()[0], (int)N, + out_tensor->dim() > 0 ? (int)out_tensor->sizes()[0] : 0, + out_tensor->dim() > 1 ? 
(int)out_tensor->sizes()[1] : 0); + + // Check if mat2 is transposed (non-contiguous due to transpose) + // A transposed matrix will have stride(-2) == 1 (column-major instead of row-major) + // For a 2D tensor with shape [K, N]: + // - Contiguous (row-major): strides = [N, 1] + // - Transposed (column-major): strides = [1, K] + bool mat2_is_transposed = false; + int64_t mat2_stride_0 = mat2_tensor->strides()[0]; // stride for dimension 0 + int64_t mat2_stride_1 = mat2_tensor->strides()[1]; // stride for dimension 1 + + // Detect transposed layout: stride(-2) == 1 indicates column-major layout + if (mat2_stride_0 == 1 && mat2_stride_1 != 1) { + mat2_is_transposed = true; + ET_LOG(Debug, "aoti_torch_mps_mm_out: mat2 is transposed (strides=[%lld, %lld])", + mat2_stride_0, mat2_stride_1); + } else { + ET_LOG(Debug, "aoti_torch_mps_mm_out: mat2 is contiguous (strides=[%lld, %lld])", + mat2_stride_0, mat2_stride_1); + } + + // Use the same dispatch pattern as other MPS operations for consistent synchronization + ETMetalStream* stream = getCurrentMetalStream(); + if (!stream) { + ET_LOG(Error, "aoti_torch_mps_mm_out: Failed to get current Metal stream"); + return Error::Internal; + } + + // Get Metal device + id device = get_metal_device(); + if (!device) { + ET_LOG(Error, "aoti_torch_mps_mm_out: Failed to get Metal device"); + throw std::runtime_error("Failed to get Metal device"); + } + + // Get Metal buffers for input and output tensors + id self_buffer = get_mtl_buffer(self_tensor, "aoti_torch_mps_mm_out", "self"); + id mat2_buffer = get_mtl_buffer(mat2_tensor, "aoti_torch_mps_mm_out", "mat2"); + id out_buffer = get_mtl_buffer(out_tensor, "aoti_torch_mps_mm_out", "out"); + + ET_LOG(Debug, "aoti_torch_mps_mm_out: Using existing Metal buffers - self=%p, mat2=%p, out=%p", + self_buffer, mat2_buffer, out_buffer); + + // End any existing kernel coalescing to ensure a clean state for MPS + stream->endKernelCoalescing(); + + // Determine data type and element size + int32_t dtype = static_cast(self_tensor->scalar_type()); + MPSDataType mps_dtype; + size_t element_size; + + ET_LOG(Debug, "aoti_torch_mps_mm_out: self_tensor scalar_type=%d, SupportedDTypes::FLOAT32=%d, SupportedDTypes::BFLOAT16=%d", + dtype, static_cast(SupportedDTypes::FLOAT32), static_cast(SupportedDTypes::BFLOAT16)); + + if (dtype == static_cast(SupportedDTypes::FLOAT32)) { + mps_dtype = MPSDataTypeFloat32; + element_size = sizeof(float); + } else if (dtype == static_cast(SupportedDTypes::BFLOAT16)) { + mps_dtype = MPSDataTypeBFloat16; + element_size = sizeof(uint16_t); // bfloat16 is 16 bits + } else { + ET_LOG(Error, "aoti_torch_mps_mm_out: Unsupported data type: %d", dtype); + throw std::runtime_error("Unsupported data type for matrix multiplication"); + } + + ET_LOG(Debug, "aoti_torch_mps_mm_out: dtype=%d, element_size=%zu", dtype, element_size); + ET_LOG(Debug, "aoti_torch_mps_mm_out: M=%lld, K=%lld, N=%lld", M, K, N); + + // Create MPSGraph for matrix multiplication + MPSGraph* mpsGraph = [MPSGraph new]; + ET_LOG(Debug, "aoti_torch_mps_mm_out: Created MPSGraph instance"); + + // Define tensor shapes for placeholders + NSArray* selfShape = @[@(M), @(K)]; + NSArray* outShape = @[@(M), @(N)]; + + // For mat2, we need to handle both contiguous and transposed cases + // If mat2 is transposed, its physical layout in memory is [N, K] (column-major) + // but logically we need [K, N] for the matrix multiplication + NSArray* mat2PhysicalShape; + if (mat2_is_transposed) { + // Physical shape reflects the actual memory layout 
(transposed) + mat2PhysicalShape = @[@(N), @(K)]; + ET_LOG(Debug, "aoti_torch_mps_mm_out: mat2 physical shape (transposed): [%d,%d]", (int)N, (int)K); + } else { + // Physical shape is the logical shape (contiguous) + mat2PhysicalShape = @[@(K), @(N)]; + ET_LOG(Debug, "aoti_torch_mps_mm_out: mat2 physical shape (contiguous): [%d,%d]", (int)K, (int)N); + } + + ET_LOG(Debug, "aoti_torch_mps_mm_out: Creating placeholders with shapes self:[%d,%d] mat2:[%d,%d]", + (int)M, (int)K, + mat2_is_transposed ? (int)N : (int)K, + mat2_is_transposed ? (int)K : (int)N); + + // Create placeholders for input tensors + MPSGraphTensor* selfPlaceholder = [mpsGraph placeholderWithShape:selfShape + dataType:mps_dtype + name:@"self"]; + MPSGraphTensor* mat2Placeholder = [mpsGraph placeholderWithShape:mat2PhysicalShape + dataType:mps_dtype + name:@"mat2_physical"]; + + ET_LOG(Debug, "aoti_torch_mps_mm_out: Created input placeholders"); + + // If mat2 is transposed, apply transpose operation in the graph to get the logical shape + MPSGraphTensor* mat2Logical; + if (mat2_is_transposed) { + // Transpose from physical [N, K] to logical [K, N] + // MPSGraph transposeTensor swaps the last two dimensions for 2D tensors + mat2Logical = [mpsGraph transposeTensor:mat2Placeholder + dimension:-2 + withDimension:-1 + name:@"mat2_transposed"]; + ET_LOG(Debug, "aoti_torch_mps_mm_out: Applied transpose operation to mat2 in graph"); + } else { + // No transpose needed, use placeholder directly + mat2Logical = mat2Placeholder; + ET_LOG(Debug, "aoti_torch_mps_mm_out: Using mat2 placeholder directly (no transpose needed)"); + } + + // Perform matrix multiplication using MPSGraph with the logical mat2 tensor + MPSGraphTensor* mmOutput = [mpsGraph matrixMultiplicationWithPrimaryTensor:selfPlaceholder + secondaryTensor:mat2Logical + name:@"matrix_multiplication"]; + + ET_LOG(Debug, "aoti_torch_mps_mm_out: Successfully created matrix multiplication tensor"); + + // Create feeds dictionary for graph execution + NSMutableDictionary* feeds = [NSMutableDictionary dictionary]; + + // Create MPSGraphTensorData objects for input tensors + // Use physical shapes to match how data is actually laid out in memory + MPSGraphTensorData* selfData = [[MPSGraphTensorData alloc] initWithMTLBuffer:self_buffer + shape:selfShape + dataType:mps_dtype]; + MPSGraphTensorData* mat2Data = [[MPSGraphTensorData alloc] initWithMTLBuffer:mat2_buffer + shape:mat2PhysicalShape + dataType:mps_dtype]; + + feeds[selfPlaceholder] = selfData; + feeds[mat2Placeholder] = mat2Data; + + ET_LOG(Debug, "aoti_torch_mps_mm_out: Created feeds dictionary with physical shapes"); + + // Create results dictionary + MPSGraphTensorData* outputData = [[MPSGraphTensorData alloc] initWithMTLBuffer:out_buffer + shape:outShape + dataType:mps_dtype]; + + NSDictionary* results = @{mmOutput: outputData}; + ET_LOG(Debug, "aoti_torch_mps_mm_out: Created results dictionary"); + + // Execute the MPSGraph + ET_LOG(Debug, "aoti_torch_mps_mm_out: Executing MPSGraph"); + + @try { + // Use stream helper to encode and synchronize correctly + stream->executeMPSGraph(mpsGraph, feeds, results, SyncType::COMMIT_AND_CONTINUE); + } @catch (NSException *exception) { + ET_LOG(Error, "aoti_torch_mps_mm_out: NSException caught during executeMPSGraph: %s - %s", + [[exception name] UTF8String], [[exception reason] UTF8String]); + throw std::runtime_error("MPSGraph execution failed with NSException"); + } + + ET_LOG(Debug, "aoti_torch_mps_mm_out: MPSGraph execution completed successfully"); + + ET_LOG(Debug, 
"aoti_torch_mps_mm_out: Executed successfully"); + return Error::Ok; + + } catch (const std::exception& e) { + ET_LOG(Error, "aoti_torch_mps_mm_out exception: %s", e.what()); + return Error::Internal; + } catch (...) { + ET_LOG(Error, "aoti_torch_mps_mm_out: unknown exception"); + return Error::Internal; + } + } +} + +AOTITorchError aoti_torch_mps_convolution( + AOTITensorHandle input, + AOTITensorHandle weight, + AOTITensorHandle* bias, + const int64_t* stride, + int64_t stride_len_, + const int64_t* padding, + int64_t padding_len_, + const int64_t* dilation, + int64_t dilation_len_, + int32_t transposed, + const int64_t* output_padding, + int64_t output_padding_len_, + int64_t groups, + AOTITensorHandle* ret0) { + ET_LOG(Debug, "aoti_torch_mps_convolution: Starting with input=%p, weight=%p, bias=%p, groups=%lld, transposed=%d", + input, weight, bias, groups, transposed); + + if (!input || !weight || !ret0) { + ET_LOG(Error, "aoti_torch_mps_convolution: null required handles (input, weight, or ret0)"); + return Error::InvalidArgument; + } + + @autoreleasepool { + try { + // Convert AOTITensorHandle to ExecutorTorch tensors + auto input_tensor = reinterpret_cast(input); + auto weight_tensor = reinterpret_cast(weight); + + // bias can be null for convolutions without bias + Tensor* bias_tensor = nullptr; + if (bias && *bias) { + bias_tensor = reinterpret_cast(*bias); + ET_LOG(Debug, "aoti_torch_mps_convolution: Has bias tensor"); + } else { + ET_LOG(Debug, "aoti_torch_mps_convolution: No bias tensor"); + } + + ET_LOG(Debug, "aoti_torch_mps_convolution: Converted tensor handles to ET tensors"); + + // Log tensor shapes for debugging + ET_LOG(Debug, "aoti_torch_mps_convolution: input shape: [%d, %d, %d, %d]", + input_tensor->dim() > 0 ? (int)input_tensor->sizes()[0] : 0, + input_tensor->dim() > 1 ? (int)input_tensor->sizes()[1] : 0, + input_tensor->dim() > 2 ? (int)input_tensor->sizes()[2] : 0, + input_tensor->dim() > 3 ? (int)input_tensor->sizes()[3] : 0); + + ET_LOG(Debug, "aoti_torch_mps_convolution: weight shape: [%d, %d, %d, %d]", + weight_tensor->dim() > 0 ? (int)weight_tensor->sizes()[0] : 0, + weight_tensor->dim() > 1 ? (int)weight_tensor->sizes()[1] : 0, + weight_tensor->dim() > 2 ? (int)weight_tensor->sizes()[2] : 0, + weight_tensor->dim() > 3 ? (int)weight_tensor->sizes()[3] : 0); + + // Log convolution parameters + if (stride && stride_len_ >= 2) { + ET_LOG(Debug, "aoti_torch_mps_convolution: stride: [%lld, %lld]", stride[0], stride[1]); + } + if (padding && padding_len_ >= 2) { + ET_LOG(Debug, "aoti_torch_mps_convolution: padding: [%lld, %lld]", padding[0], padding[1]); + } + if (dilation && dilation_len_ >= 2) { + ET_LOG(Debug, "aoti_torch_mps_convolution: dilation: [%lld, %lld]", dilation[0], dilation[1]); + } + if (output_padding && output_padding_len_ >= 2) { + ET_LOG(Debug, "aoti_torch_mps_convolution: output_padding: [%lld, %lld]", output_padding[0], output_padding[1]); + } + + // Support conv1d and conv2d by inspecting weight rank. 
+ // conv1d: weight dims = [C_out, C_in, K] + // conv2d: weight dims = [C_out, C_in, Kh, Kw] + bool is_conv1d = (weight_tensor->dim() == 3); + + // Accept input ranks: + // conv1d: 2D (C,W) or 3D (N,C,W) + // conv2d: 3D (C,H,W) or 4D (N,C,H,W) + bool has_batch_dim = false; + bool is_input_4d = false; + int64_t N = 1, C_in = 0, H_in = 1, W_in = 0; + if (is_conv1d) { + if (input_tensor->dim() == 2) { + // (C, W) + has_batch_dim = false; + C_in = input_tensor->sizes()[0]; + W_in = input_tensor->sizes()[1]; + H_in = 1; + } else if (input_tensor->dim() == 3) { + // (N, C, W) + has_batch_dim = true; + N = input_tensor->sizes()[0]; + C_in = input_tensor->sizes()[1]; + W_in = input_tensor->sizes()[2]; + H_in = 1; + } else { + ET_LOG(Error, "aoti_torch_mps_convolution: conv1d expects 2D or 3D input, got %d", (int)input_tensor->dim()); + return Error::InvalidArgument; + } + } else { + is_input_4d = (input_tensor->dim() == 4); + if (is_input_4d) { + // (N, C, H, W) + has_batch_dim = true; + N = input_tensor->sizes()[0]; + C_in = input_tensor->sizes()[1]; + H_in = input_tensor->sizes()[2]; + W_in = input_tensor->sizes()[3]; + } else if (input_tensor->dim() == 3) { + // (C, H, W) + has_batch_dim = false; + N = 1; + C_in = input_tensor->sizes()[0]; + H_in = input_tensor->sizes()[1]; + W_in = input_tensor->sizes()[2]; + } else { + ET_LOG(Error, "aoti_torch_mps_convolution: conv2d expects 3D or 4D input, got %d", (int)input_tensor->dim()); + return Error::InvalidArgument; + } + } + + // Get weight dimensions + int64_t C_out = weight_tensor->sizes()[0]; // output channels + int64_t kernel_h = is_conv1d ? 1 : weight_tensor->sizes()[2]; // kernel height + int64_t kernel_w = is_conv1d ? weight_tensor->sizes()[2] : weight_tensor->sizes()[3]; // kernel width + + // Calculate output spatial dimensions + int64_t stride_h = is_conv1d ? 1 : (stride && stride_len_ > 0 ? stride[0] : 1); + int64_t stride_w = is_conv1d ? (stride && stride_len_ > 0 ? stride[0] : 1) + : (stride && stride_len_ > 1 ? stride[1] : 1); + int64_t pad_h = is_conv1d ? 0 : (padding && padding_len_ > 0 ? padding[0] : 0); + int64_t pad_w = is_conv1d ? (padding && padding_len_ > 0 ? padding[0] : 0) + : (padding && padding_len_ > 1 ? padding[1] : 0); + int64_t dil_h = is_conv1d ? 1 : (dilation && dilation_len_ > 0 ? dilation[0] : 1); + int64_t dil_w = is_conv1d ? (dilation && dilation_len_ > 0 ? dilation[0] : 1) + : (dilation && dilation_len_ > 1 ? dilation[1] : 1); + + int64_t H_out, W_out; + if (transposed) { + // For transposed convolution, output size calculation is different + int64_t output_pad_h = is_conv1d ? 0 : (output_padding && output_padding_len_ > 0 ? output_padding[0] : 0); + int64_t output_pad_w = is_conv1d ? (output_padding && output_padding_len_ > 0 ? output_padding[0] : 0) + : (output_padding && output_padding_len_ > 1 ? output_padding[1] : 0); + H_out = is_conv1d ? 1 : ((H_in - 1) * stride_h - 2 * pad_h + dil_h * (kernel_h - 1) + output_pad_h + 1); + W_out = (W_in - 1) * stride_w - 2 * pad_w + dil_w * (kernel_w - 1) + output_pad_w + 1; + } else { + // Regular convolution output size calculation + H_out = is_conv1d ? 
1 : ((H_in + 2 * pad_h - dil_h * (kernel_h - 1) - 1) / stride_h + 1); + W_out = (W_in + 2 * pad_w - dil_w * (kernel_w - 1) - 1) / stride_w + 1; + } + + if (!is_conv1d && is_input_4d) { + ET_LOG(Debug, "aoti_torch_mps_convolution: Calculated 4D output shape: [%lld, %lld, %lld, %lld]", N, C_out, H_out, W_out); + } else if (!is_conv1d) { + ET_LOG(Debug, "aoti_torch_mps_convolution: Calculated 3D output shape: [%lld, %lld, %lld]", C_out, H_out, W_out); + } else if (is_conv1d && has_batch_dim) { + ET_LOG(Debug, "aoti_torch_mps_convolution: Calculated 3D (1D conv) output shape: [%lld, %lld, %lld]", N, C_out, W_out); + } else { + ET_LOG(Debug, "aoti_torch_mps_convolution: Calculated 2D (1D conv) output shape: [%lld, %lld]", C_out, W_out); + } + + // Validate output dimensions are positive + if (N <= 0 || C_out <= 0 || H_out <= 0 || W_out <= 0) { + ET_LOG(Error, "aoti_torch_mps_convolution: Invalid output dimensions N=%lld, C_out=%lld, H_out=%lld, W_out=%lld", + N, C_out, H_out, W_out); + return Error::InvalidArgument; + } + + // Use the same dispatch pattern as other MPS operations for consistent synchronization + ETMetalStream* stream = getCurrentMetalStream(); + if (!stream) { + ET_LOG(Error, "aoti_torch_mps_convolution: Failed to get current Metal stream"); + return Error::Internal; + } + + // Get Metal device + id device = get_metal_device(); + if (!device) { + ET_LOG(Error, "aoti_torch_mps_convolution: Failed to get Metal device"); + throw std::runtime_error("Failed to get Metal device"); + } + + // End any existing kernel coalescing to ensure a clean state for MPS + stream->endKernelCoalescing(); + + // Ensure stream is ready; command buffer handled internally by stream helpers + + // Determine data type and element size + int32_t dtype = static_cast(input_tensor->scalar_type()); + MPSDataType mps_dtype; + size_t element_size; + + if (dtype == static_cast(SupportedDTypes::FLOAT32)) { + mps_dtype = MPSDataTypeFloat32; + element_size = sizeof(float); + } else if (dtype == static_cast(SupportedDTypes::BFLOAT16)) { + mps_dtype = MPSDataTypeBFloat16; + element_size = sizeof(uint16_t); // bfloat16 is 16 bits + } else { + ET_LOG(Error, "aoti_torch_mps_convolution: Unsupported data type: %d", dtype); + throw std::runtime_error("Unsupported data type for convolution"); + } + + ET_LOG(Debug, "aoti_torch_mps_convolution: mps_dtype=%d, element_size=%zu", mps_dtype, element_size); + + // Create MPSGraph for convolution + MPSGraph* mpsGraph = [MPSGraph new]; + ET_LOG(Debug, "aoti_torch_mps_convolution: Created MPSGraph instance"); + + // Define tensor shapes for placeholders (always 4D NCHW for MPSGraph) + NSArray* inputShape = @[@(N), @(C_in), @(H_in), @(W_in)]; + NSArray* weightShape = @[@(C_out), @(C_in), @(kernel_h), @(kernel_w)]; + + ET_LOG(Debug, "aoti_torch_mps_convolution: Creating placeholders with shapes input:[%d,%d,%d,%d] weight:[%d,%d,%d,%d]", + (int)N, (int)C_in, (int)H_in, (int)W_in, + (int)C_out, (int)C_in, (int)kernel_h, (int)kernel_w); + + // Create placeholders for input tensors + MPSGraphTensor* inputPlaceholder = [mpsGraph placeholderWithShape:inputShape + dataType:mps_dtype + name:@"input"]; + MPSGraphTensor* weightPlaceholder = [mpsGraph placeholderWithShape:weightShape + dataType:mps_dtype + name:@"weight"]; + + ET_LOG(Debug, "aoti_torch_mps_convolution: Created input and weight placeholders"); + + // Create convolution descriptor + MPSGraphConvolution2DOpDescriptor* convDesc = [MPSGraphConvolution2DOpDescriptor descriptorWithStrideInX:stride_w + strideInY:stride_h + 
dilationRateInX:dil_w + dilationRateInY:dil_h + groups:groups + paddingLeft:pad_w + paddingRight:pad_w + paddingTop:pad_h + paddingBottom:pad_h + paddingStyle:MPSGraphPaddingStyleExplicit + dataLayout:MPSGraphTensorNamedDataLayoutNCHW + weightsLayout:MPSGraphTensorNamedDataLayoutOIHW]; + + ET_LOG(Debug, "aoti_torch_mps_convolution: Created convolution descriptor with stride=[%lld,%lld], padding=[%lld,%lld], dilation=[%lld,%lld], groups=%lld", + stride_w, stride_h, pad_w, pad_h, dil_w, dil_h, groups); + + // Perform convolution using MPSGraph + MPSGraphTensor* convOutput = nil; + if (transposed) { + ET_LOG(Debug, "aoti_torch_mps_convolution: Using transposed convolution"); + // For transposed convolution, we need to handle output padding + int64_t output_pad_h = output_padding && output_padding_len_ > 0 ? output_padding[0] : 0; + int64_t output_pad_w = output_padding && output_padding_len_ > 1 ? output_padding[1] : 0; + + // For transposed convolution, we need to adjust the padding calculation + // In transposed convolution, the effective padding is typically negative + // and we use output_padding to control the final output size + int64_t transposed_pad_h = pad_h - output_pad_h; + int64_t transposed_pad_w = pad_w - output_pad_w; + + // Create transposed convolution descriptor with adjusted padding + MPSGraphConvolution2DOpDescriptor* transposedConvDesc = [MPSGraphConvolution2DOpDescriptor descriptorWithStrideInX:stride_w + strideInY:stride_h + dilationRateInX:dil_w + dilationRateInY:dil_h + groups:groups + paddingLeft:transposed_pad_w + paddingRight:transposed_pad_w + paddingTop:transposed_pad_h + paddingBottom:transposed_pad_h + paddingStyle:MPSGraphPaddingStyleExplicit + dataLayout:MPSGraphTensorNamedDataLayoutNCHW + weightsLayout:MPSGraphTensorNamedDataLayoutOIHW]; + + convOutput = [mpsGraph convolution2DWithSourceTensor:inputPlaceholder + weightsTensor:weightPlaceholder + descriptor:transposedConvDesc + name:@"transposed_convolution"]; + } else { + ET_LOG(Debug, "aoti_torch_mps_convolution: Using regular convolution"); + convOutput = [mpsGraph convolution2DWithSourceTensor:inputPlaceholder + weightsTensor:weightPlaceholder + descriptor:convDesc + name:@"convolution"]; + } + + ET_LOG(Debug, "aoti_torch_mps_convolution: Successfully created convolution tensor"); + + // Handle bias if provided + MPSGraphTensor* finalOutput = convOutput; + MPSGraphTensor* biasPlaceholder = nil; + if (bias_tensor) { + ET_LOG(Debug, "aoti_torch_mps_convolution: Adding bias to convolution output"); + + // Create bias placeholder + NSArray* biasShape = @[@(C_out)]; + biasPlaceholder = [mpsGraph placeholderWithShape:biasShape + dataType:mps_dtype + name:@"bias"]; + + // Add bias to convolution output + finalOutput = [mpsGraph additionWithPrimaryTensor:convOutput + secondaryTensor:biasPlaceholder + name:@"add_bias"]; + + ET_LOG(Debug, "aoti_torch_mps_convolution: Added bias placeholder to graph"); + } + + // Create feeds dictionary for graph execution + NSMutableDictionary* feeds = [NSMutableDictionary dictionary]; + + // Get Metal buffers from tensors + id input_buffer = get_mtl_buffer(input_tensor, "aoti_torch_mps_convolution", "input"); + id weight_buffer = get_mtl_buffer(weight_tensor, "aoti_torch_mps_convolution", "weight"); + + ET_LOG(Debug, "aoti_torch_mps_convolution: Using existing Metal buffers - input=%p, weight=%p", + input_buffer, weight_buffer); + + // Create MPSGraphTensorData objects for input tensors + MPSGraphTensorData* inputData = [[MPSGraphTensorData alloc] initWithMTLBuffer:input_buffer + 
shape:inputShape + dataType:mps_dtype]; + MPSGraphTensorData* weightData = [[MPSGraphTensorData alloc] initWithMTLBuffer:weight_buffer + shape:weightShape + dataType:mps_dtype]; + + feeds[inputPlaceholder] = inputData; + feeds[weightPlaceholder] = weightData; + + // Add bias data to feeds if provided + if (bias_tensor && biasPlaceholder) { + id bias_buffer = get_mtl_buffer(bias_tensor, "aoti_torch_mps_convolution", "bias"); + + NSArray* biasShape = @[@(C_out)]; + MPSGraphTensorData* biasData = [[MPSGraphTensorData alloc] initWithMTLBuffer:bias_buffer + shape:biasShape + dataType:mps_dtype]; + + feeds[biasPlaceholder] = biasData; + ET_LOG(Debug, "aoti_torch_mps_convolution: Added bias tensor to feeds"); + } + + ET_LOG(Debug, "aoti_torch_mps_convolution: Created feeds dictionary"); + + // Create Metal buffer for output tensor + size_t output_size_bytes = N * C_out * H_out * W_out * element_size; + void* output_contents_ptr = nullptr; + id output_buffer = allocate_mtl_buffer(&output_contents_ptr, output_size_bytes); + + // Create results dictionary (MPSGraph output is 4D) + NSArray* outputShape = @[@(N), @(C_out), @(H_out), @(W_out)]; + MPSGraphTensorData* outputData = [[MPSGraphTensorData alloc] initWithMTLBuffer:output_buffer + shape:outputShape + dataType:mps_dtype]; + + NSDictionary* results = @{finalOutput: outputData}; + ET_LOG(Debug, "aoti_torch_mps_convolution: Created results dictionary"); + + // Execute the MPSGraph + ET_LOG(Debug, "aoti_torch_mps_convolution: Executing MPSGraph"); + + @try { + // Use stream helper to encode and synchronize correctly + stream->executeMPSGraph(mpsGraph, feeds, results, SyncType::COMMIT_AND_CONTINUE); + } @catch (NSException *exception) { + ET_LOG(Error, "aoti_torch_mps_convolution: NSException caught during executeMPSGraph: %s - %s", + [[exception name] UTF8String], [[exception reason] UTF8String]); + throw std::runtime_error("MPSGraph execution failed with NSException"); + } @catch (...) 
{ + ET_LOG(Error, "aoti_torch_mps_convolution: MPSGraph execution failed"); + throw std::runtime_error("MPSGraph execution failed"); + } + + ET_LOG(Debug, "aoti_torch_mps_convolution: MPSGraph execution completed successfully"); + + // Create output tensor handle on device (MPS) that points to GPU buffer + std::vector output_sizes_int64; + std::vector output_strides; + if (!is_conv1d && is_input_4d) { + output_sizes_int64 = {N, C_out, H_out, W_out}; + // Contiguous NCHW strides + output_strides = { + C_out * H_out * W_out, + H_out * W_out, + W_out, + 1 + }; + } else if (!is_conv1d) { + output_sizes_int64 = {C_out, H_out, W_out}; + // Contiguous CHW strides + output_strides = { + H_out * W_out, + W_out, + 1 + }; + } else if (is_conv1d && has_batch_dim) { + output_sizes_int64 = {N, C_out, W_out}; + // Contiguous NCW strides + output_strides = { + C_out * W_out, + W_out, + 1 + }; + } else { + output_sizes_int64 = {C_out, W_out}; + // Contiguous CW strides + output_strides = { + W_out, + 1 + }; + } + + // Use the GPU buffer contents pointer directly for the tensor storage + void* tensor_data = output_contents_ptr; + + AOTITensorHandle output_tensor_handle = nullptr; + + AOTITorchError create_result = aoti_torch_create_tensor_from_blob_v2( + tensor_data, + static_cast(output_sizes_int64.size()), // ndim + output_sizes_int64.data(), + output_strides.data(), + 0, // storage_offset + dtype, // dtype + 13, // device_type (MPS) + 0, // device_index + &output_tensor_handle, + 0, // layout (strided) + nullptr, // opaque_metadata + 0 // opaque_metadata_size + ); + + if (create_result != Error::Ok || !output_tensor_handle) { + ET_LOG(Error, "aoti_torch_mps_convolution: Failed to create output tensor, error code: %d", static_cast(create_result)); + aoti_torch_mps_free(tensor_data); // Free the allocated GPU memory on failure + throw std::runtime_error("Failed to create output tensor"); + } + + // Verify the tensor was created with the correct size + auto* et_tensor = reinterpret_cast(output_tensor_handle); + size_t actual_numel = et_tensor->numel(); + size_t expected_numel = static_cast(N * C_out * H_out * W_out); + + if (actual_numel != expected_numel) { + ET_LOG(Error, "aoti_torch_mps_convolution: Tensor size mismatch. Expected %zu, got %zu", expected_numel, actual_numel); + aoti_torch_mps_free(tensor_data); // Free the allocated GPU memory on failure + throw std::runtime_error("Tensor size mismatch"); + } + + // Store the tensor handle - mark that we own the memory since we manually allocated it with malloc + *ret0 = output_tensor_handle; + is_tensor_own_memory[et_tensor] = true; // We allocated the GPU memory + + ET_LOG(Debug, "aoti_torch_mps_convolution: Created output tensor with %zu elements using MPSGraph", actual_numel); + + ET_LOG(Debug, "aoti_torch_mps_convolution: Executed successfully"); + return Error::Ok; + + } catch (const std::exception& e) { + ET_LOG(Error, "aoti_torch_mps_convolution exception: %s", e.what()); + return Error::Internal; + } catch (...) 
{ + ET_LOG(Error, "aoti_torch_mps_convolution: unknown exception"); + return Error::Internal; + } + } +} + +AOTITorchError aoti_torch_mps__scaled_dot_product_attention_math_for_mps( + AOTITensorHandle query, + AOTITensorHandle key, + AOTITensorHandle value, + AOTITensorHandle* attn_mask, + double dropout_p, + int32_t is_causal, + AOTITensorHandle* dropout_mask, + double* scale, + AOTITensorHandle* ret0, + AOTITensorHandle* ret1) { + + ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Starting with MPSGraph implementation"); + + if (!query || !key || !value || !ret0 || !ret1) { + ET_LOG(Error, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: null required tensor handles"); + return Error::InvalidArgument; + } + + // Use the same dispatch pattern as other MPS operations for consistent synchronization + ETMetalStream* stream = getCurrentMetalStream(); + if (!stream) { + ET_LOG(Error, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Failed to get current Metal stream"); + return Error::Internal; + } + + try { + @autoreleasepool { + // Convert AOTITensorHandle to ExecutorTorch tensors + auto* query_tensor = reinterpret_cast(query); + auto* key_tensor = reinterpret_cast(key); + auto* value_tensor = reinterpret_cast(value); + + ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Converted tensor handles to ET tensors"); + + // Validate tensor dimensions + if (query_tensor->dim() < 3 || key_tensor->dim() < 3 || value_tensor->dim() < 3) { + std::string error_msg = "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: tensors must be at least 3-D, got " + + std::to_string(query_tensor->dim()) + ", " + + std::to_string(key_tensor->dim()) + ", " + + std::to_string(value_tensor->dim()); + ET_LOG(Error, "%s", error_msg.c_str()); + throw std::runtime_error(error_msg); + } + + // Get tensor dimensions (assuming [batch, num_heads, seq_len, head_dim] format) + int64_t batchSize = query_tensor->sizes()[0]; + int64_t num_heads = query_tensor->sizes()[1]; + int64_t qSize = query_tensor->sizes()[2]; + int64_t headSize = query_tensor->sizes()[3]; + int64_t kvSeqLength = key_tensor->sizes()[2]; + + ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: batchSize=%lld, num_heads=%lld, qSize=%lld, headSize=%lld, kvSeqLength=%lld", + batchSize, num_heads, qSize, headSize, kvSeqLength); + + // Detect non-contiguous layouts for query, key, and value tensors + // For a 4D tensor [batch, num_heads, seq_len, head_dim], common non-contiguous patterns: + // - Transposed last 2 dims (dims 2,3): strides[2] == 1 && strides[3] == seq_len (seq_len and head_dim swapped) + // - Transposed internal dims (dims 1,2): strides[1] == head_dim && strides[2] == num_heads*head_dim (num_heads and seq_len swapped) + // - Other permutations may exist depending on upstream operations + + bool query_is_transposed_last2 = false; // transpose of dims -2 and -1 + bool query_is_transposed_internal = false; // transpose of dims 1 and 2 + bool key_is_transposed_last2 = false; + bool key_is_transposed_internal = false; + bool value_is_transposed_last2 = false; + bool value_is_transposed_internal = false; + + // Expected contiguous strides for query [batch, num_heads, qSize, headSize] + int64_t expected_q_stride_3 = 1; + int64_t expected_q_stride_2 = headSize; + int64_t expected_q_stride_1 = qSize * headSize; + int64_t expected_q_stride_0 = num_heads * qSize * headSize; + + // Check query tensor layout + auto q_strides = query_tensor->strides(); + if 
(q_strides[3] != expected_q_stride_3 || q_strides[2] != expected_q_stride_2 || + q_strides[1] != expected_q_stride_1) { + // Check if it's a transpose of the last two dimensions (dims 2 and 3) + if (q_strides[2] == 1 && q_strides[3] == qSize && q_strides[1] == qSize * headSize) { + query_is_transposed_last2 = true; + ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Query tensor has transposed last 2 dims (dims 2,3) (strides=[%lld,%lld,%lld,%lld])", + (int64_t)q_strides[0], (int64_t)q_strides[1], (int64_t)q_strides[2], (int64_t)q_strides[3]); + } + // Check if it's a transpose of the internal dimensions (dims 1 and 2) + else if (q_strides[1] == headSize && q_strides[2] == num_heads * headSize && q_strides[3] == 1) { + query_is_transposed_internal = true; + ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Query tensor has transposed internal dims (dims 1,2) (strides=[%lld,%lld,%lld,%lld])", + (int64_t)q_strides[0], (int64_t)q_strides[1], (int64_t)q_strides[2], (int64_t)q_strides[3]); + } else { + ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Query tensor is non-contiguous with unusual layout (strides=[%lld,%lld,%lld,%lld])", + (int64_t)q_strides[0], (int64_t)q_strides[1], (int64_t)q_strides[2], (int64_t)q_strides[3]); + } + } else { + ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Query tensor is contiguous (strides=[%lld,%lld,%lld,%lld])", + (int64_t)q_strides[0], (int64_t)q_strides[1], (int64_t)q_strides[2], (int64_t)q_strides[3]); + } + + // Expected contiguous strides for key [batch, num_heads, kvSeqLength, headSize] + int64_t expected_k_stride_3 = 1; + int64_t expected_k_stride_2 = headSize; + int64_t expected_k_stride_1 = kvSeqLength * headSize; + int64_t expected_k_stride_0 = num_heads * kvSeqLength * headSize; + + // Check key tensor layout + auto k_strides = key_tensor->strides(); + if (k_strides[3] != expected_k_stride_3 || k_strides[2] != expected_k_stride_2 || + k_strides[1] != expected_k_stride_1) { + // Check if it's a transpose of the last two dimensions (dims 2 and 3) + if (k_strides[2] == 1 && k_strides[3] == kvSeqLength && k_strides[1] == kvSeqLength * headSize) { + key_is_transposed_last2 = true; + ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Key tensor has transposed last 2 dims (dims 2,3) (strides=[%lld,%lld,%lld,%lld])", + (int64_t)k_strides[0], (int64_t)k_strides[1], (int64_t)k_strides[2], (int64_t)k_strides[3]); + } + // Check if it's a transpose of the internal dimensions (dims 1 and 2) + else if (k_strides[1] == headSize && k_strides[2] == num_heads * headSize && k_strides[3] == 1) { + key_is_transposed_internal = true; + ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Key tensor has transposed internal dims (dims 1,2) (strides=[%lld,%lld,%lld,%lld])", + (int64_t)k_strides[0], (int64_t)k_strides[1], (int64_t)k_strides[2], (int64_t)k_strides[3]); + } else { + ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Key tensor is non-contiguous with unusual layout (strides=[%lld,%lld,%lld,%lld])", + (int64_t)k_strides[0], (int64_t)k_strides[1], (int64_t)k_strides[2], (int64_t)k_strides[3]); + } + } else { + ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Key tensor is contiguous (strides=[%lld,%lld,%lld,%lld])", + (int64_t)k_strides[0], (int64_t)k_strides[1], (int64_t)k_strides[2], (int64_t)k_strides[3]); + } + + // Expected contiguous strides for value [batch, 
num_heads, kvSeqLength, headSize] + int64_t expected_v_stride_3 = 1; + int64_t expected_v_stride_2 = headSize; + int64_t expected_v_stride_1 = kvSeqLength * headSize; + int64_t expected_v_stride_0 = num_heads * kvSeqLength * headSize; + + // Check value tensor layout + auto v_strides = value_tensor->strides(); + if (v_strides[3] != expected_v_stride_3 || v_strides[2] != expected_v_stride_2 || + v_strides[1] != expected_v_stride_1) { + // Check if it's a transpose of the last two dimensions (dims 2 and 3) + if (v_strides[2] == 1 && v_strides[3] == kvSeqLength && v_strides[1] == kvSeqLength * headSize) { + value_is_transposed_last2 = true; + ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Value tensor has transposed last 2 dims (dims 2,3) (strides=[%lld,%lld,%lld,%lld])", + (int64_t)v_strides[0], (int64_t)v_strides[1], (int64_t)v_strides[2], (int64_t)v_strides[3]); + } + // Check if it's a transpose of the internal dimensions (dims 1 and 2) + else if (v_strides[1] == headSize && v_strides[2] == num_heads * headSize && v_strides[3] == 1) { + value_is_transposed_internal = true; + ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Value tensor has transposed internal dims (dims 1,2) (strides=[%lld,%lld,%lld,%lld])", + (int64_t)v_strides[0], (int64_t)v_strides[1], (int64_t)v_strides[2], (int64_t)v_strides[3]); + } else { + ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Value tensor is non-contiguous with unusual layout (strides=[%lld,%lld,%lld,%lld])", + (int64_t)v_strides[0], (int64_t)v_strides[1], (int64_t)v_strides[2], (int64_t)v_strides[3]); + } + } else { + ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Value tensor is contiguous (strides=[%lld,%lld,%lld,%lld])", + (int64_t)v_strides[0], (int64_t)v_strides[1], (int64_t)v_strides[2], (int64_t)v_strides[3]); + } + + // Determine data type and element size + int32_t dtype = static_cast<int32_t>(query_tensor->scalar_type()); + MPSDataType mps_dtype; + size_t element_size; + + if (dtype == static_cast<int32_t>(SupportedDTypes::FLOAT32)) { + mps_dtype = MPSDataTypeFloat32; + element_size = sizeof(float); + } else if (dtype == static_cast<int32_t>(SupportedDTypes::BFLOAT16)) { + mps_dtype = MPSDataTypeBFloat16; + element_size = sizeof(uint16_t); // bfloat16 is 16 bits + } else { + ET_LOG(Error, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Unsupported data type: %d", dtype); + throw std::runtime_error("Unsupported data type for scaled dot product attention"); + } + + ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: mps_dtype=%d, element_size=%zu", mps_dtype, element_size); + + // Check that headSize is not zero to avoid division by zero + if (headSize == 0) { + ET_LOG(Error, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: headSize is zero"); + throw std::runtime_error("headSize must be non-zero for scaled dot product attention"); + } + + // Calculate scale factor + double scale_factor = scale ?
*scale : (1.0 / sqrt(static_cast<double>(headSize))); + ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: scale_factor=%f", scale_factor); + + // Get Metal device + id<MTLDevice> device = get_metal_device(); + if (!device) { + ET_LOG(Error, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Failed to get Metal device"); + throw std::runtime_error("Failed to get Metal device"); + } + + // Get Metal buffers for query, key and value tensors + id<MTLBuffer> query_buffer = get_mtl_buffer(query_tensor, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps", "query"); + id<MTLBuffer> key_buffer = get_mtl_buffer(key_tensor, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps", "key"); + id<MTLBuffer> value_buffer = get_mtl_buffer(value_tensor, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps", "value"); + + // Calculate output tensor dimensions + std::vector<int64_t> output_sizes = {batchSize, num_heads, qSize, headSize}; + std::vector<int64_t> attn_sizes = {batchSize, num_heads, qSize, kvSeqLength}; + + // Calculate strides for contiguous tensors + std::vector<int64_t> out_strides = { + num_heads * qSize * headSize, + qSize * headSize, + headSize, + 1 + }; + + std::vector<int64_t> attn_strides = { + num_heads * qSize * kvSeqLength, + qSize * kvSeqLength, + kvSeqLength, + 1 + }; + + // Allocate output Metal buffers via AOTI API to keep GPU residency and reuse + size_t out_size_bytes = batchSize * num_heads * qSize * headSize * element_size; + size_t attn_size_bytes = batchSize * num_heads * qSize * kvSeqLength * element_size; + + void* out_contents_ptr = nullptr; + id<MTLBuffer> out_buffer = allocate_mtl_buffer(&out_contents_ptr, out_size_bytes); + + void* attn_contents_ptr = nullptr; + id<MTLBuffer> attn_weights_buffer = allocate_mtl_buffer(&attn_contents_ptr, attn_size_bytes); + + // End any existing kernel coalescing to ensure a clean state for MPS + stream->endKernelCoalescing(); + + // Method 1: Using MPSGraph scaledDotProductAttention API - with detailed error handling + ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Implementing using MPSGraph scaledDotProductAttention"); + + @try { + // Check if scaledDotProductAttentionWithQueryTensor is available + MPSGraph* testGraph = [MPSGraph new]; + if (![testGraph respondsToSelector:@selector(scaledDotProductAttentionWithQueryTensor:keyTensor:valueTensor:maskTensor:scale:name:)]) { + ET_LOG(Error, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: scaledDotProductAttentionWithQueryTensor API not available on this system"); + throw std::runtime_error("scaledDotProductAttentionWithQueryTensor API not available on this system"); + } + ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: scaledDotProductAttentionWithQueryTensor API is available"); + + // Create MPSGraph for scaled dot product attention + MPSGraph* mpsGraph = [MPSGraph new]; + ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Created MPSGraph instance"); + + // Define physical tensor shapes for placeholders (matching actual memory layout) + // Two transpose patterns supported: + // 1. Last 2 dims transposed (dims 2,3): [batch, num_heads, head_dim, seq_len] + // 2.
Internal dims transposed (dims 1,2): [batch, seq_len, num_heads, head_dim] + NSArray* queryPhysicalShape; + NSArray* keyPhysicalShape; + NSArray* valuePhysicalShape; + + if (query_is_transposed_last2) { + // Physical layout: [batch, num_heads, headSize, qSize] (dims 2,3 swapped) + queryPhysicalShape = @[@(batchSize), @(num_heads), @(headSize), @(qSize)]; + ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Query physical shape (transposed dims 2,3): [%d,%d,%d,%d]", + (int)batchSize, (int)num_heads, (int)headSize, (int)qSize); + } else if (query_is_transposed_internal) { + // Physical layout: [batch, qSize, num_heads, headSize] (dims 1,2 swapped) + queryPhysicalShape = @[@(batchSize), @(qSize), @(num_heads), @(headSize)]; + ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Query physical shape (transposed dims 1,2): [%d,%d,%d,%d]", + (int)batchSize, (int)qSize, (int)num_heads, (int)headSize); + } else { + // Physical layout matches logical layout: [batch, num_heads, qSize, headSize] + queryPhysicalShape = @[@(batchSize), @(num_heads), @(qSize), @(headSize)]; + ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Query physical shape (contiguous): [%d,%d,%d,%d]", + (int)batchSize, (int)num_heads, (int)qSize, (int)headSize); + } + + if (key_is_transposed_last2) { + // Physical layout: [batch, num_heads, headSize, kvSeqLength] (dims 2,3 swapped) + keyPhysicalShape = @[@(batchSize), @(num_heads), @(headSize), @(kvSeqLength)]; + ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Key physical shape (transposed dims 2,3): [%d,%d,%d,%d]", + (int)batchSize, (int)num_heads, (int)headSize, (int)kvSeqLength); + } else if (key_is_transposed_internal) { + // Physical layout: [batch, kvSeqLength, num_heads, headSize] (dims 1,2 swapped) + keyPhysicalShape = @[@(batchSize), @(kvSeqLength), @(num_heads), @(headSize)]; + ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Key physical shape (transposed dims 1,2): [%d,%d,%d,%d]", + (int)batchSize, (int)kvSeqLength, (int)num_heads, (int)headSize); + } else { + // Physical layout matches logical layout: [batch, num_heads, kvSeqLength, headSize] + keyPhysicalShape = @[@(batchSize), @(num_heads), @(kvSeqLength), @(headSize)]; + ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Key physical shape (contiguous): [%d,%d,%d,%d]", + (int)batchSize, (int)num_heads, (int)kvSeqLength, (int)headSize); + } + + if (value_is_transposed_last2) { + // Physical layout: [batch, num_heads, headSize, kvSeqLength] (dims 2,3 swapped) + valuePhysicalShape = @[@(batchSize), @(num_heads), @(headSize), @(kvSeqLength)]; + ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Value physical shape (transposed dims 2,3): [%d,%d,%d,%d]", + (int)batchSize, (int)num_heads, (int)headSize, (int)kvSeqLength); + } else if (value_is_transposed_internal) { + // Physical layout: [batch, kvSeqLength, num_heads, headSize] (dims 1,2 swapped) + valuePhysicalShape = @[@(batchSize), @(kvSeqLength), @(num_heads), @(headSize)]; + ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Value physical shape (transposed dims 1,2): [%d,%d,%d,%d]", + (int)batchSize, (int)kvSeqLength, (int)num_heads, (int)headSize); + } else { + // Physical layout matches logical layout: [batch, num_heads, kvSeqLength, headSize] + valuePhysicalShape = @[@(batchSize), @(num_heads), @(kvSeqLength), @(headSize)]; + ET_LOG(Debug, 
"aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Value physical shape (contiguous): [%d,%d,%d,%d]", + (int)batchSize, (int)num_heads, (int)kvSeqLength, (int)headSize); + } + + // Create placeholders for input tensors with physical shapes + MPSGraphTensor* queryPlaceholder = [mpsGraph placeholderWithShape:queryPhysicalShape + dataType:mps_dtype + name:@"query_physical"]; + ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Created query placeholder"); + + MPSGraphTensor* keyPlaceholder = [mpsGraph placeholderWithShape:keyPhysicalShape + dataType:mps_dtype + name:@"key_physical"]; + ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Created key placeholder"); + + MPSGraphTensor* valuePlaceholder = [mpsGraph placeholderWithShape:valuePhysicalShape + dataType:mps_dtype + name:@"value_physical"]; + ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Created value placeholder"); + + // Apply transpose operations in the graph to convert physical to logical layout + // Logical shapes needed for SDPA: Q[batch, num_heads, qSize, headSize], + // K[batch, num_heads, kvSeqLength, headSize], + // V[batch, num_heads, kvSeqLength, headSize] + MPSGraphTensor* queryLogical; + MPSGraphTensor* keyLogical; + MPSGraphTensor* valueLogical; + + if (query_is_transposed_last2) { + // Transpose dims 2,3: [batch, num_heads, headSize, qSize] → [batch, num_heads, qSize, headSize] + queryLogical = [mpsGraph transposeTensor:queryPlaceholder + dimension:-2 + withDimension:-1 + name:@"query_transposed_last2"]; + ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Applied transpose (dims 2,3) to query tensor in graph"); + } else if (query_is_transposed_internal) { + // Transpose dims 1,2: [batch, qSize, num_heads, headSize] → [batch, num_heads, qSize, headSize] + queryLogical = [mpsGraph transposeTensor:queryPlaceholder + dimension:1 + withDimension:2 + name:@"query_transposed_internal"]; + ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Applied transpose (dims 1,2) to query tensor in graph"); + } else { + queryLogical = queryPlaceholder; + ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Using query placeholder directly (no transpose needed)"); + } + + if (key_is_transposed_last2) { + // Transpose dims 2,3: [batch, num_heads, headSize, kvSeqLength] → [batch, num_heads, kvSeqLength, headSize] + keyLogical = [mpsGraph transposeTensor:keyPlaceholder + dimension:-2 + withDimension:-1 + name:@"key_transposed_last2"]; + ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Applied transpose (dims 2,3) to key tensor in graph"); + } else if (key_is_transposed_internal) { + // Transpose dims 1,2: [batch, kvSeqLength, num_heads, headSize] → [batch, num_heads, kvSeqLength, headSize] + keyLogical = [mpsGraph transposeTensor:keyPlaceholder + dimension:1 + withDimension:2 + name:@"key_transposed_internal"]; + ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Applied transpose (dims 1,2) to key tensor in graph"); + } else { + keyLogical = keyPlaceholder; + ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Using key placeholder directly (no transpose needed)"); + } + + if (value_is_transposed_last2) { + // Transpose dims 2,3: [batch, num_heads, headSize, kvSeqLength] → [batch, num_heads, kvSeqLength, headSize] + valueLogical = [mpsGraph transposeTensor:valuePlaceholder + dimension:-2 + withDimension:-1 + 
name:@"value_transposed_last2"]; + ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Applied transpose (dims 2,3) to value tensor in graph"); + } else if (value_is_transposed_internal) { + // Transpose dims 1,2: [batch, kvSeqLength, num_heads, headSize] → [batch, num_heads, kvSeqLength, headSize] + valueLogical = [mpsGraph transposeTensor:valuePlaceholder + dimension:1 + withDimension:2 + name:@"value_transposed_internal"]; + ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Applied transpose (dims 1,2) to value tensor in graph"); + } else { + valueLogical = valuePlaceholder; + ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Using value placeholder directly (no transpose needed)"); + } + + MPSGraphTensor* maskTensor = nil; + + // Handle causal mask + if (is_causal) { + ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Creating causal mask"); + + // Create a causal mask: lower triangular matrix filled with 0s, upper triangle with -inf + // Shape should be [qSize, kvSeqLength] + NSArray* maskShape = @[@(qSize), @(kvSeqLength)]; + + // Create ones tensor + MPSGraphTensor* onesTensor = [mpsGraph constantWithScalar:1.0f + shape:maskShape + dataType:mps_dtype]; + ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Created ones tensor for causal mask"); + + // Create lower triangular mask (including diagonal) + MPSGraphTensor* causalMask = [mpsGraph bandPartWithTensor:onesTensor + numLower:-1 + numUpper:0 + name:@"causal_mask"]; + ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Created causal mask using bandPartWithTensor"); + + // Convert mask to attention weights format: 0 for allowed positions, -inf for masked + MPSGraphTensor* zerosTensor = [mpsGraph constantWithScalar:0.0f + shape:maskShape + dataType:mps_dtype]; + + MPSGraphTensor* negInfTensor = [mpsGraph constantWithScalar:-1e9f + shape:maskShape + dataType:mps_dtype]; + + // Select: where causal_mask == 1, use 0.0, else use -inf + maskTensor = [mpsGraph selectWithPredicateTensor:causalMask + truePredicateTensor:zerosTensor + falsePredicateTensor:negInfTensor + name:@"causal_mask_final"]; + ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Created final causal mask using selectWithPredicateTensor"); + } + + // Handle explicit attention mask if provided + MPSGraphTensor* explicitMaskPlaceholder = nil; + if (attn_mask && *attn_mask) { + auto* mask_tensor = reinterpret_cast(*attn_mask); + + ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Adding explicit attention mask"); + + // Create mask placeholder + NSMutableArray* maskShapeArray = [NSMutableArray array]; + for (int i = 0; i < mask_tensor->dim(); i++) { + [maskShapeArray addObject:@(mask_tensor->sizes()[i])]; + } + + explicitMaskPlaceholder = [mpsGraph placeholderWithShape:maskShapeArray + dataType:mps_dtype + name:@"attention_mask"]; + + if (maskTensor) { + // Combine causal and explicit masks + maskTensor = [mpsGraph additionWithPrimaryTensor:maskTensor + secondaryTensor:explicitMaskPlaceholder + name:@"combined_mask"]; + } else { + maskTensor = explicitMaskPlaceholder; + } + ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Created explicit mask placeholder"); + } + + // Perform scaled dot product attention using MPSGraph with logical (possibly transposed) tensors + // The logical tensors have the correct shapes for attention computation regardless of input memory 
layout + ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Calling scaledDotProductAttentionWithQueryTensor with scale=%f", scale_factor); + + MPSGraphTensor* outputTensor = [mpsGraph scaledDotProductAttentionWithQueryTensor:queryLogical + keyTensor:keyLogical + valueTensor:valueLogical + maskTensor:maskTensor + scale:scale_factor + name:@"scaled_dot_product_attention"]; + ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Successfully created SDPA tensor"); + + // Create feeds dictionary for graph execution + NSMutableDictionary* feeds = [NSMutableDictionary dictionary]; + ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Created feeds dictionary"); + + // Create MPSGraphTensorData objects for input tensors using physical shapes + // Physical shapes match the actual memory layout of the tensors + MPSGraphTensorData* queryData = [[MPSGraphTensorData alloc] initWithMTLBuffer:query_buffer + shape:queryPhysicalShape + dataType:mps_dtype]; + MPSGraphTensorData* keyData = [[MPSGraphTensorData alloc] initWithMTLBuffer:key_buffer + shape:keyPhysicalShape + dataType:mps_dtype]; + MPSGraphTensorData* valueData = [[MPSGraphTensorData alloc] initWithMTLBuffer:value_buffer + shape:valuePhysicalShape + dataType:mps_dtype]; + ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Created MPSGraphTensorData objects with physical shapes"); + + feeds[queryPlaceholder] = queryData; + feeds[keyPlaceholder] = keyData; + feeds[valuePlaceholder] = valueData; + ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Added input tensors to feeds"); + + // Add explicit mask data to feeds if provided + if (explicitMaskPlaceholder && attn_mask && *attn_mask) { + auto* mask_tensor = reinterpret_cast<Tensor*>(*attn_mask); + // Get Metal buffer for mask + id<MTLBuffer> mask_buffer = get_mtl_buffer(mask_tensor, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps", "mask"); + + NSMutableArray* maskShapeArray = [NSMutableArray array]; + for (int i = 0; i < mask_tensor->dim(); i++) { + [maskShapeArray addObject:@(mask_tensor->sizes()[i])]; + } + + MPSGraphTensorData* maskData = [[MPSGraphTensorData alloc] initWithMTLBuffer:mask_buffer + shape:maskShapeArray + dataType:mps_dtype]; + feeds[explicitMaskPlaceholder] = maskData; + ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Added explicit mask tensor to feeds"); + } + + // Create results dictionary + NSArray* outputShape = @[@(batchSize), @(num_heads), @(qSize), @(headSize)]; + MPSGraphTensorData* outputData = [[MPSGraphTensorData alloc] initWithMTLBuffer:out_buffer + shape:outputShape + dataType:mps_dtype]; + + NSDictionary* results = @{outputTensor: outputData}; + ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Created results dictionary"); + + // Execute via shared stream and keep results on GPU + ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Executing MPSGraph using stream"); + stream->executeMPSGraph(mpsGraph, feeds, results, SyncType::COMMIT_AND_CONTINUE); + ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: MPSGraph execution completed successfully"); + + } @catch (NSException *exception) { + ET_LOG(Error, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: NSException caught: %s - %s", + [[exception name] UTF8String], [[exception reason] UTF8String]); + throw std::runtime_error("MPSGraph operation failed with NSException"); + } + + // For attention weights,
zero-fill the GPU buffer (shared memory allows CPU memset) + std::memset(attn_contents_ptr, 0, attn_size_bytes); + + // Create output tensor handles + AOTITensorHandle out_tensor_handle = nullptr; + AOTITensorHandle attn_tensor_handle = nullptr; + + AOTITorchError create_out_result = aoti_torch_create_tensor_from_blob_v2( + out_contents_ptr, + 4, // ndim + output_sizes.data(), + out_strides.data(), + 0, // storage_offset + dtype, + 13, // device_type (MPS) + 0, // device_index + &out_tensor_handle, + 0, // layout (strided) + nullptr, // opaque_metadata + 0 // opaque_metadata_size + ); + + AOTITorchError create_attn_result = aoti_torch_create_tensor_from_blob_v2( + attn_contents_ptr, + 4, // ndim + attn_sizes.data(), + attn_strides.data(), + 0, // storage_offset + dtype, + 13, // device_type (MPS) + 0, // device_index + &attn_tensor_handle, + 0, // layout (strided) + nullptr, // opaque_metadata + 0 // opaque_metadata_size + ); + + if (create_out_result != Error::Ok || create_attn_result != Error::Ok || + !out_tensor_handle || !attn_tensor_handle) { + ET_LOG(Error, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Failed to create output tensors"); + aoti_torch_mps_free(out_contents_ptr); + aoti_torch_mps_free(attn_contents_ptr); + throw std::runtime_error("Failed to create output tensors"); + } + + // Mark that we own the memory for these tensors + auto* out_et_tensor = reinterpret_cast<Tensor*>(out_tensor_handle); + auto* attn_et_tensor = reinterpret_cast<Tensor*>(attn_tensor_handle); + is_tensor_own_memory[out_et_tensor] = true; + is_tensor_own_memory[attn_et_tensor] = true; + + // Set output tensor handles + *ret0 = out_tensor_handle; + *ret1 = attn_tensor_handle; + + ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: MPSGraph implementation completed successfully"); + } + + ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Executed successfully"); + return Error::Ok; + + } catch (const std::exception& e) { + ET_LOG(Error, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps exception: %s", e.what()); + return Error::Internal; + } catch (...)
{ + ET_LOG(Error, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: unknown exception"); + return Error::Internal; + } +} + +} // extern "C" + +} // namespace metal +} // namespace backends +} // namespace executorch diff --git a/examples/models/voxtral/CMakeLists.txt b/examples/models/voxtral/CMakeLists.txt index 3995f5533e6..866d17160ba 100644 --- a/examples/models/voxtral/CMakeLists.txt +++ b/examples/models/voxtral/CMakeLists.txt @@ -93,6 +93,11 @@ if(EXECUTORCH_BUILD_CUDA) executorch_target_link_options_shared_lib(aoti_cuda) endif() +if(EXECUTORCH_BUILD_METAL) + list(APPEND link_libraries metal_backend) + executorch_target_link_options_shared_lib(metal_backend) +endif() + # Add tokenizers list(APPEND link_libraries tokenizers::tokenizers) diff --git a/tools/cmake/executorch-config.cmake b/tools/cmake/executorch-config.cmake index 78168a12aba..3f97db77ccc 100644 --- a/tools/cmake/executorch-config.cmake +++ b/tools/cmake/executorch-config.cmake @@ -63,6 +63,7 @@ set(optional_lib_list coreml_inmemoryfs coremldelegate mpsdelegate + metal_backend neuron_backend qnn_executorch_backend portable_ops_lib diff --git a/tools/cmake/preset/default.cmake b/tools/cmake/preset/default.cmake index 04e84622589..861e41e4a63 100644 --- a/tools/cmake/preset/default.cmake +++ b/tools/cmake/preset/default.cmake @@ -152,6 +152,9 @@ define_overridable_option( define_overridable_option( EXECUTORCH_BUILD_CUDA "Build the CUDA backend" BOOL OFF ) +define_overridable_option( + EXECUTORCH_BUILD_METAL "Build the Metal backend" BOOL OFF +) define_overridable_option( EXECUTORCH_BUILD_VGF "Build the Arm VGF backend" BOOL OFF ) @@ -389,6 +392,10 @@ check_required_options_on( IF_ON EXECUTORCH_BUILD_CUDA REQUIRES EXECUTORCH_BUILD_EXTENSION_TENSOR ) +check_required_options_on( + IF_ON EXECUTORCH_BUILD_METAL REQUIRES EXECUTORCH_BUILD_EXTENSION_TENSOR +) + if(NOT EXISTS ${EXECUTORCH_PAL_DEFAULT_FILE_PATH}) message( FATAL_ERROR