diff --git a/CMakeLists.txt b/CMakeLists.txt index 7012ec641bf..f023069add6 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -587,6 +587,16 @@ endif() if(EXECUTORCH_BUILD_CORTEX_M) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backends/cortex_m) + list(APPEND _executorch_backends coretex_m_backend) +endif() + +if(EXECUTORCH_BUILD_CUDA) + # Build common AOTI functionality (required for CUDA) + add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backends/aoti) + # Build CUDA-specific AOTI functionality + add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backends/cuda) + # Add aoti_cuda to backends - it already depends on aoti_common + list(APPEND _executorch_backends aoti_cuda) endif() if(EXECUTORCH_BUILD_EXTENSION_APPLE) @@ -1021,6 +1031,11 @@ if(EXECUTORCH_BUILD_EXECUTOR_RUNNER) extension_runner_util gflags executorch_backends ) + # Add flat tensor extension if it's built + if(EXECUTORCH_BUILD_EXTENSION_FLAT_TENSOR) + list(APPEND _executor_runner_libs extension_flat_tensor) + endif() + if(EXECUTORCH_BUILD_KERNELS_OPTIMIZED) list(APPEND _executor_runner_libs optimized_native_cpu_ops_lib) elseif(EXECUTORCH_BUILD_CADENCE) diff --git a/backends/aoti/aoti_model_container.h b/backends/aoti/aoti_model_container.h index 4b20aefc976..f7e42e2e58a 100644 --- a/backends/aoti/aoti_model_container.h +++ b/backends/aoti/aoti_model_container.h @@ -21,6 +21,7 @@ using executorch::runtime::etensor::Tensor; extern "C" { // Type definitions +using AOTITensorHandle = Tensor*; using AOTIRuntimeError = Error; // Forward declarations for AOT Inductor model container @@ -75,6 +76,7 @@ extern AOTInductorModelContainerRunFunc AOTInductorModelContainerRun; struct AOTIDelegateHandle { void* so_handle; AOTInductorModelContainerHandle container_handle; + void* cuda_stream; // cudaStream_t stored as void* to avoid CUDA header dependency }; } // namespace aoti diff --git a/backends/aoti/utils.h b/backends/aoti/utils.h index 1c872e08648..78c07bcea6e 100644 --- a/backends/aoti/utils.h +++ b/backends/aoti/utils.h @@ -34,6 +34,8 @@ inline executorch::aten::ScalarType dtype_to_scalar_type(int32_t dtype) { // Convert based on known PyTorch dtype codes (without CUDA-specific // dependency) switch (dtype) { + case 4: // PyTorch's int64 dtype code + return executorch::aten::ScalarType::Long; case 6: // PyTorch's float32 dtype code return executorch::aten::ScalarType::Float; case 15: // PyTorch's bfloat16 dtype code diff --git a/backends/cuda/CMakeLists.txt b/backends/cuda/CMakeLists.txt new file mode 100644 index 00000000000..30e307bba99 --- /dev/null +++ b/backends/cuda/CMakeLists.txt @@ -0,0 +1,74 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +# +# Build AOTI CUDA backend for runtime. +# +# ### Editing this file ### +# +# This file should be formatted with +# ~~~ +# cmake-format -i CMakeLists.txt +# ~~~ +# It should also be cmake-lint clean. +# + +set(CMAKE_EXPORT_COMPILE_COMMANDS ON) + +# Source root directory for executorch. +if(NOT EXECUTORCH_ROOT) + set(EXECUTORCH_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/../..) 
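+  # Assumes this file lives at backends/cuda/, two levels below the ExecuTorch root.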
+endif() + +find_package(CUDAToolkit REQUIRED) + +# Use ExecutorTorch's standard way to find PyTorch libraries for AOTI +include(${EXECUTORCH_ROOT}/tools/cmake/Utils.cmake) +find_package_torch() + +# CUDA-specific AOTI functionality +set(_aoti_cuda_sources + runtime/cuda_backend.cpp + runtime/guard.cpp + runtime/shims/cuda_guard.cpp + runtime/shims/memory.cpp + runtime/shims/tensor_attribute.cpp +) +add_library(aoti_cuda STATIC ${_aoti_cuda_sources}) +target_include_directories( + aoti_cuda + PUBLIC ${CUDAToolkit_INCLUDE_DIRS} + $ + $ + # PyTorch AOTI headers from ExecutorTorch's torch detection + ${TORCH_INCLUDE_DIRS} +) +target_compile_options(aoti_cuda PUBLIC -fexceptions -frtti -fPIC) +# Ensure symbols are exported properly +target_link_options(aoti_cuda PUBLIC -Wl,--export-dynamic) + +# Link against CUDA::cudart, common AOTI library, and PyTorch CUDA libraries +target_link_libraries( + aoti_cuda + PUBLIC aoti_common CUDA::cudart ${CMAKE_DL_LIBS} + # Link PyTorch libraries for AOTI CUDA functions + ${TORCH_LIBRARIES} +) +# If you need other CUDA libraries, link them similarly: +# target_link_libraries(aoti_cuda PUBLIC CUDA::cublas CUDA::cufft ...) +executorch_target_link_options_shared_lib(aoti_cuda) + +# Add runtime +add_executable(voxtral_runner tests/voxtral_runner.cpp) +target_link_libraries( + voxtral_runner PUBLIC aoti_cuda extension_module_static extension_flat_tensor + portable_ops_lib +) + +install( + TARGETS aoti_cuda + EXPORT ExecuTorchTargets + DESTINATION lib +) diff --git a/backends/cuda/cuda_backend.py b/backends/cuda/cuda_backend.py index 8ed8cdefbb1..a72538d3471 100644 --- a/backends/cuda/cuda_backend.py +++ b/backends/cuda/cuda_backend.py @@ -33,11 +33,9 @@ # required fallback kernels but not supported missing_fallback_kernels: Set[str] = set() - class COMPILE_SPEC_KEYS(Enum): METHOD_NAME = "method_name" - # context manager for non-fallback guarantee # it will raise exception when generating fallback kernels during aoti compile @contextlib.contextmanager diff --git a/backends/cuda/runtime/TARGETS b/backends/cuda/runtime/TARGETS index 1aa38760e5a..0386b5a008d 100644 --- a/backends/cuda/runtime/TARGETS +++ b/backends/cuda/runtime/TARGETS @@ -5,13 +5,17 @@ oncall("executorch") runtime.cxx_library( name = "runtime_shims", srcs = [ + "guard.cpp", + "shims/cuda_guard.cpp", "shims/memory.cpp", "shims/tensor_attribute.cpp", ], headers = [ + "guard.h", + "shims/cuda_guard.h", "shims/memory.h", "shims/tensor_attribute.h", - "shims/utils.h", + "utils.h", ], # @lint-ignore BUCKLINT: Avoid `link_whole=True` (https://fburl.com/avoid-link-whole) link_whole = True, diff --git a/backends/cuda/runtime/cuda_backend.cpp b/backends/cuda/runtime/cuda_backend.cpp new file mode 100644 index 00000000000..680923fa590 --- /dev/null +++ b/backends/cuda/runtime/cuda_backend.cpp @@ -0,0 +1,398 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include
+#include
+#include
+#include
+#include
+
+// Include our shim layer headers
+#include
+#include
+#include
+#include
+
+namespace executorch {
+namespace backends {
+namespace cuda {
+
+using namespace std;
+using namespace aoti;
+
+using executorch::aten::ScalarType;
+using executorch::runtime::ArrayRef;
+using executorch::runtime::Backend;
+using executorch::runtime::BackendExecutionContext;
+using executorch::runtime::BackendInitContext;
+using executorch::runtime::CompileSpec;
+using executorch::runtime::DelegateHandle;
+using executorch::runtime::Error;
+using executorch::runtime::EValue;
+using executorch::runtime::FreeableBuffer;
+using executorch::runtime::MemoryAllocator;
+using executorch::runtime::NamedDataMap;
+using executorch::runtime::Result;
+using executorch::runtime::Span;
+using executorch::runtime::etensor::Tensor;
+
+class CudaBackend final : public ::executorch::runtime::BackendInterface {
+ private:
+  Error register_shared_library_functions(void* so_handle) const {
+    AOTInductorModelContainerCreateWithDevice =
+        reinterpret_cast<AOTInductorModelContainerCreateWithDeviceFunc>(
+            dlsym(so_handle, "AOTInductorModelContainerCreateWithDevice"));
+    if (AOTInductorModelContainerCreateWithDevice == nullptr) {
+      ET_LOG(Error, "Failed to load AOTInductorModelContainerCreateWithDevice");
+      return Error::AccessFailed;
+    }
+
+    AOTInductorModelContainerDelete =
+        reinterpret_cast<AOTInductorModelContainerDeleteFunc>(
+            dlsym(so_handle, "AOTInductorModelContainerDelete"));
+    if (AOTInductorModelContainerDelete == nullptr) {
+      ET_LOG(Error, "Failed to load AOTInductorModelContainerDelete");
+      return Error::AccessFailed;
+    }
+
+    AOTInductorModelContainerGetNumInputs =
+        reinterpret_cast<AOTInductorModelContainerGetNumInputsFunc>(
+            dlsym(so_handle, "AOTInductorModelContainerGetNumInputs"));
+    if (AOTInductorModelContainerGetNumInputs == nullptr) {
+      ET_LOG(Error, "Failed to load AOTInductorModelContainerGetNumInputs");
+      return Error::AccessFailed;
+    }
+
+    AOTInductorModelContainerGetNumOutputs =
+        reinterpret_cast<AOTInductorModelContainerGetNumOutputsFunc>(
+            dlsym(so_handle, "AOTInductorModelContainerGetNumOutputs"));
+    if (AOTInductorModelContainerGetNumOutputs == nullptr) {
+      ET_LOG(Error, "Failed to load AOTInductorModelContainerGetNumOutputs");
+      return Error::AccessFailed;
+    }
+
+    AOTInductorModelContainerRun =
+        reinterpret_cast<AOTInductorModelContainerRunFunc>(
+            dlsym(so_handle, "AOTInductorModelContainerRun"));
+    if (AOTInductorModelContainerRun == nullptr) {
+      ET_LOG(Error, "Failed to load AOTInductorModelContainerRun");
+      return Error::AccessFailed;
+    }
+
+    return Error::Ok;
+  }
+
+ public:
+  bool is_available() const override {
+    return true;
+  }
+
+  // Once per loaded binary blob
+  Result<DelegateHandle*> init(
+      BackendInitContext& context,
+      FreeableBuffer* processed, // This will be an empty buffer
+      ArrayRef<CompileSpec> compile_specs // Carries the method_name compile spec, if any
+  ) const override {
+    std::string method_name;
+    for (const CompileSpec& spec : compile_specs) {
+      if (std::strcmp(spec.key, "method_name") == 0) {
+        method_name.assign(
+            static_cast<const char*>(spec.value.buffer),
+            spec.value.nbytes); // no nullptr guarantee, so pass size
+        break;
+      }
+    }
+
+    std::string so_blob_key =
+        method_name.empty() ?
"so_blob" : method_name + "_so_blob"; + + const NamedDataMap* named_data_map = context.get_named_data_map(); + auto aoti_cuda_buffer = named_data_map->get_data(so_blob_key.c_str()); + if (!aoti_cuda_buffer.ok()) { + ET_LOG( + Error, + "Failed to get data for key %s: 0x%x", + so_blob_key.c_str(), + aoti_cuda_buffer.error()); + return aoti_cuda_buffer.error(); + } + // Generate dynamic temporary file path + filesystem::path temp_dir = filesystem::temp_directory_path(); + filesystem::path so_path = + temp_dir / (so_blob_key + to_string(getpid()) + ".so"); + + // Create a temporary file + ofstream outfile(so_path.c_str(), ios::binary); + + // Write the ELF buffer to the temporary file + ET_LOG( + Info, + "Writing %zu bytes to %s", + aoti_cuda_buffer->size(), + so_path.c_str()); + outfile.write( + static_cast(aoti_cuda_buffer->data()), + aoti_cuda_buffer->size()); + + // Finish writing the file to disk + outfile.close(); + + // Load the ELF using dlopen + void* so_handle = dlopen(so_path.c_str(), RTLD_LAZY | RTLD_LOCAL); + if (so_handle == nullptr) { + ET_LOG(Error, "Failed to load shared library: %s", dlerror()); + return Error::AccessFailed; + } + + processed->Free(); + + // Register all shared library functions + Error reg_err = register_shared_library_functions(so_handle); + if (reg_err != Error::Ok) { + return reg_err; + } + + AOTInductorModelContainerHandle container_handle = nullptr; + + AOTIRuntimeError err = AOTInductorModelContainerCreateWithDevice( + &container_handle, 1, "cuda", nullptr); + if (err != Error::Ok) { + return err; + } + ET_LOG(Info, "container_handle = %p", container_handle); + + AOTIDelegateHandle* handle = new AOTIDelegateHandle(); + handle->so_handle = so_handle; + handle->container_handle = container_handle; + + // Create a CUDA stream for asynchronous execution + cudaStream_t cuda_stream; + ET_CUDA_CHECK_OR_RETURN_ERROR(cudaStreamCreate(&cuda_stream)); + handle->cuda_stream = static_cast(cuda_stream); + ET_LOG(Info, "Created CUDA stream: %p", handle->cuda_stream); + + return (DelegateHandle*)handle; // Return the handle post-processing + } + + // Once per execution + Error execute( + BackendExecutionContext& context, + DelegateHandle* handle_, + Span args) const override { + AOTIDelegateHandle* handle = (AOTIDelegateHandle*)handle_; + + size_t n_inputs; + AOTInductorModelContainerGetNumInputs(handle->container_handle, &n_inputs); + + size_t n_outputs; + AOTInductorModelContainerGetNumOutputs( + handle->container_handle, &n_outputs); + + if (n_inputs + n_outputs != args.size()) { + ET_LOG( + Error, + "number of user input %zd and output %zd generated from AOT Inductor does not match ET runner's %zd. 
Exit.", + n_inputs, + n_outputs, + args.size()); + return Error::InvalidArgument; + } + + // NOTE: ExecutorTorch tensors are always on CPU/host memory + // We need to create GPU copies for CUDA kernel execution + std::vector gpu_inputs( + n_inputs); // GPU copies for kernel execution + std::vector gpu_outputs( + n_outputs); // GPU tensors for kernel output + + // Process input tensors: ExecutorTorch provides CPU tensors, create GPU + // copies + for (int i = 0; i < n_inputs; i++) { + // Get tensor dimensions and properties from ExecutorTorch CPU tensor + auto cpu_tensor = &(args[i]->toTensor()); + auto sizes = cpu_tensor->sizes(); + auto scalar_type = cpu_tensor->scalar_type(); + + // Create GPU tensor with same shape + std::vector sizes_vec(sizes.begin(), sizes.end()); + + AOTITensorHandle gpu_input_handle; + Error create_err = aoti_torch_empty_strided( + sizes_vec.size(), + sizes_vec.data(), + nullptr, // use default strides + static_cast(scalar_type), + 1, // device_type = cuda + 0, // device_index = 0 + &gpu_input_handle); + + if (create_err != Error::Ok) { + ET_LOG(Error, "Failed to create GPU tensor for input %d", i); + return Error::Internal; + } + + gpu_inputs[i] = gpu_input_handle; + + // Copy data from CPU to GPU + Error copy_err = aoti_torch_copy_(gpu_inputs[i], cpu_tensor, 0); + if (copy_err != Error::Ok) { + ET_LOG(Error, "Failed to copy input %d from CPU to GPU", i); + return Error::Internal; + } + } + ET_LOG(Info, "Inputs copied to GPU"); + // Process output tensors: create GPU counterparts for ExecutorTorch CPU + // tensors + for (int i = 0; i < n_outputs; i++) { + // Get output tensor dimensions from ExecutorTorch CPU tensor + auto cpu_output_tensor = &(args[i + n_inputs]->toTensor()); + auto sizes = cpu_output_tensor->sizes(); + auto scalar_type = cpu_output_tensor->scalar_type(); + + // Create GPU tensor with same shape for kernel output + std::vector sizes_vec(sizes.begin(), sizes.end()); + + AOTITensorHandle gpu_output_handle; + Error create_err = aoti_torch_empty_strided( + sizes_vec.size(), + sizes_vec.data(), + nullptr, // use default strides + static_cast(scalar_type), + 1, // device_type = cuda + 0, // device_index = 0 + &gpu_output_handle); + + if (create_err != Error::Ok) { + ET_LOG(Error, "Failed to create GPU tensor for output %d", i); + return Error::Internal; + } + + gpu_outputs[i] = gpu_output_handle; + } + ET_LOG(Info, "Outputs created on GPU"); + // Run AOTI container with GPU tensors + AOTIRuntimeError error = AOTInductorModelContainerRun( + handle->container_handle, + gpu_inputs.data(), // Use GPU input tensors + n_inputs, + gpu_outputs.data(), // Use GPU output tensors + n_outputs, + handle->cuda_stream, // Pass the actual CUDA stream + nullptr); // proxy_executor_handle can remain nullptr + + if (error != Error::Ok) { + ET_LOG( + Error, + "AOTInductorModelContainerRun failed with error code %d", + error); + return Error::Internal; + } + + // Copy GPU output results back to CPU output tensors + for (int i = 0; i < n_outputs; i++) { + auto cpu_output_tensor = &(args[i + n_inputs]->toTensor()); + // For DYNAMIC_BOUND tensors we try to resize + ET_CHECK_OK_OR_RETURN_ERROR( + resize_tensor(*cpu_output_tensor, gpu_outputs[i]->sizes()), + "Error resizing tensor at output index %d", + i); + ET_CHECK_OK_OR_RETURN_ERROR( + aoti_torch_copy_(cpu_output_tensor, gpu_outputs[i], 0), + "Failed to copy GPU output %d back to CPU", + i); + } + + // // Clean up GPU tensors that we created (ExecutorTorch tensors are always + // // CPU, so all GPU tensors are our copies) + 
// for (int i = 0; i < n_inputs; i++) { + // ET_LOG(Info, "Deleting GPU input tensor %d", i); + // // All GPU input tensors were created by us, delete them + // aoti_torch_delete_tensor_object(gpu_inputs[i]); + // } + + for (int i = 0; i < n_outputs; i++) { + ET_LOG(Info, "Deleting GPU output tensor %d", i); + // All GPU output tensors were created by us, delete them + aoti_torch_delete_tensor_object(gpu_outputs[i]); + } + + return Error::Ok; + } + + void destroy(DelegateHandle* handle_) const override { + AOTIDelegateHandle* handle = (AOTIDelegateHandle*)handle_; + + // Destroy the CUDA stream if it exists + if (handle->cuda_stream != nullptr) { + cudaStream_t cuda_stream = static_cast(handle->cuda_stream); + cudaError_t stream_err = cudaStreamDestroy(cuda_stream); + if (stream_err != cudaSuccess) { + ET_LOG( + Error, + "Failed to destroy CUDA stream: %s", + cudaGetErrorString(stream_err)); + } else { + ET_LOG(Info, "Destroyed CUDA stream: %p", handle->cuda_stream); + } + handle->cuda_stream = nullptr; + } + + // Delete the container BEFORE closing the shared library + // if (handle->container_handle != nullptr) { + // ET_LOG(Info, "Deleting container_handle: %p",handle->container_handle); + // AOTIRuntimeError delete_result = + // AOTInductorModelContainerDelete(handle->container_handle); + // if (delete_result != Error::Ok) { + // ET_LOG( + // Error, + // "AOTInductorModelContainerDelete failed with error code %d", + // delete_result); + // } + // handle->container_handle = nullptr; + // } + + ET_LOG(Info, "Deleted container_handle: %p", handle->container_handle); + + // Now close the shared library + if (handle->so_handle != nullptr) { + dlclose(handle->so_handle); + handle->so_handle = nullptr; + } + + ET_LOG(Info, "Deleted so_handle: %p", handle->so_handle); + + free(handle); + + ET_LOG(Info, "Deleted AOTI delegate handle: %p", handle); + + clear_all_tensors(); + + ET_LOG(Info, "Deleted all tensors"); + } +}; + +} // namespace cuda + +namespace { +auto cls = cuda::CudaBackend(); +executorch::runtime::Backend backend{"CudaBackend", &cls}; +static executorch::runtime::Error success_with_compiler = + register_backend(backend); +} // namespace + +} // namespace backends +} // namespace executorch diff --git a/backends/cuda/runtime/guard.cpp b/backends/cuda/runtime/guard.cpp new file mode 100644 index 00000000000..36c541e1770 --- /dev/null +++ b/backends/cuda/runtime/guard.cpp @@ -0,0 +1,146 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */
+
+#include
+#include
+
+namespace executorch {
+namespace backends {
+namespace cuda {
+
+namespace {
+// Thread-local stream storage (private to this file)
+thread_local std::unordered_map<DeviceIndex, cudaStream_t> current_streams_;
+} // namespace
+
+Error setCurrentCUDAStream(cudaStream_t stream, DeviceIndex device_index) {
+  if (device_index == -1) {
+    // Get current device if not specified
+    int current_device;
+    ET_CUDA_CHECK_OR_RETURN_ERROR(cudaGetDevice(&current_device));
+    device_index = current_device;
+  }
+
+  current_streams_[device_index] = stream;
+  return Error::Ok;
+}
+
+Result<cudaStream_t> getCurrentCUDAStream(DeviceIndex device_index) {
+  if (device_index == -1) {
+    int current_device;
+    ET_CUDA_CHECK_OR_RETURN_ERROR(cudaGetDevice(&current_device));
+    device_index = current_device;
+  }
+
+  auto it = current_streams_.find(device_index);
+  if (it != current_streams_.end()) {
+    return it->second;
+  }
+
+  cudaStream_t stream;
+  ET_CUDA_CHECK_OR_RETURN_ERROR(cudaStreamCreate(&stream));
+  setCurrentCUDAStream(stream, device_index);
+  return stream;
+}
+
+CUDAGuard::CUDAGuard(CUDAGuard&& other) noexcept
+    : original_device_index_(other.original_device_index_),
+      current_device_index_(other.current_device_index_) {
+  // Mark the moved-from object as "already restored" so its destructor doesn't
+  // try to restore the device
+  other.original_device_index_ = other.current_device_index_;
+}
+
+CUDAGuard::~CUDAGuard() {
+  if (original_device_index_ != current_device_index_) {
+    cudaError_t err = cudaSetDevice(original_device_index_);
+    if (err != cudaSuccess) {
+      ET_LOG(
+          Error,
+          "~CUDAGuard: Failed to restore device to %d: %s",
+          original_device_index_,
+          cudaGetErrorString(err));
+    }
+  }
+}
+
+Error CUDAGuard::set_index(DeviceIndex device_index) {
+  int orig_index = -1;
+  ET_CUDA_CHECK_OR_RETURN_ERROR(cudaGetDevice(&orig_index));
+
+  original_device_index_ = orig_index;
+  current_device_index_ = device_index;
+
+  if (current_device_index_ != original_device_index_) {
+    ET_CUDA_CHECK_OR_RETURN_ERROR(cudaSetDevice(current_device_index_));
+  }
+
+  return Error::Ok;
+}
+
+Result<CUDAGuard> CUDAGuard::create(DeviceIndex device_index) {
+  CUDAGuard guard; // Default-constructed; set_index() records and switches the device
+  ET_CHECK_OK_OR_RETURN_ERROR(guard.set_index(device_index));
+  return guard;
+}
+
+CUDAStreamGuard::CUDAStreamGuard(CUDAStreamGuard&& other) noexcept
+    : device_guard_(std::move(other.device_guard_)),
+      original_stream_(other.original_stream_),
+      current_stream_(other.current_stream_),
+      device_index_(other.device_index_) {
+  // Mark the moved-from object as "already restored" so its destructor doesn't
+  // try to restore the stream
+  other.original_stream_ = other.current_stream_;
+}
+
+CUDAStreamGuard::~CUDAStreamGuard() {
+  if (original_stream_ != nullptr) {
+    Error err = setCurrentCUDAStream(original_stream_, device_index_);
+    if (err != Error::Ok) {
+      ET_LOG(
+          Error,
+          "~CUDAStreamGuard: Failed to restore stream for device %d",
+          device_index_);
+    }
+  }
+}
+
+Error CUDAStreamGuard::set_stream(
+    cudaStream_t stream,
+    DeviceIndex device_index) {
+  auto result = getCurrentCUDAStream(device_index);
+  if (!result.ok()) {
+    ET_LOG(Error, "Failed to get current stream for device %d", device_index);
+    return result.error();
+  }
+
+  original_stream_ = result.get();
+  current_stream_ = stream;
+  device_index_ = device_index;
+
+  ET_CHECK_OK_OR_RETURN_ERROR(setCurrentCUDAStream(stream, device_index));
+
+  return Error::Ok;
+}
+
+Result<CUDAStreamGuard> CUDAStreamGuard::create(
+    cudaStream_t stream,
+    DeviceIndex device_index) {
+  auto guard_result =
CUDAGuard::create(device_index); + ET_CHECK_OK_OR_RETURN_ERROR(guard_result.error()); + + CUDAStreamGuard stream_guard(std::move(guard_result.get())); + ET_CHECK_OK_OR_RETURN_ERROR(stream_guard.set_stream(stream, device_index)); + + return stream_guard; +} + +} // namespace cuda +} // namespace backends +} // namespace executorch diff --git a/backends/cuda/runtime/guard.h b/backends/cuda/runtime/guard.h new file mode 100644 index 00000000000..d421315ac1d --- /dev/null +++ b/backends/cuda/runtime/guard.h @@ -0,0 +1,200 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include +#include +#include +#include +#include + +namespace executorch { +namespace backends { +namespace cuda { + +using executorch::runtime::Error; +using executorch::runtime::Result; + +// Type alias for device index +using DeviceIndex = int32_t; + +/** + * Set the current CUDA stream for the specified device. + * + * @param stream The CUDA stream to set as current + * @param device_index The device index (-1 to use current device) + * @return Error code indicating success or failure + */ +Error setCurrentCUDAStream(cudaStream_t stream, DeviceIndex device_index = -1); + +/** + * Get the current CUDA stream for the specified device. + * If no stream has been set, creates a new stream and sets it as current. + * + * @param device_index The device index (-1 to use current device) + * @return Result containing the current stream on success, or an error code on + * failure + */ +Result getCurrentCUDAStream(DeviceIndex device_index = -1); + +/** + * RAII guard that sets the current CUDA device and restores it on destruction. + * This ensures that the device is properly restored even if an exception + * occurs. + * + * NOTE: Do not use constructors directly. Use the create() factory method + * instead. + */ +class CUDAGuard { + private: + /** + * Private constructor - use create() factory method instead. + */ + explicit CUDAGuard() + : original_device_index_(-1), current_device_index_(-1){}; + + public: + /** + * Factory method to create a CUDAGuard. + * + * @param device_index The device index to set as current + * @return Result containing the guard on success, or an error code on failure + */ + static Result create(DeviceIndex device_index); + + // Copy is not allowed + CUDAGuard(const CUDAGuard&) = delete; + CUDAGuard& operator=(const CUDAGuard&) = delete; + + // Move constructor and assignment + CUDAGuard(CUDAGuard&& other) noexcept; + CUDAGuard& operator=(CUDAGuard&& other) = delete; + + /** + * Destructor that restores the original device if necessary. + */ + ~CUDAGuard(); + + /** + * Sets the CUDA device to the given device index. + * + * @param device_index The device index to set as current + * @return Error code indicating success or failure + */ + Error set_index(DeviceIndex device_index); + + /** + * Get the original device index before the guard was created. + * + * @return The original device index + */ + DeviceIndex original_device() const { + return original_device_index_; + } + + /** + * Get the current device index. 
+ * + * @return The current device index + */ + DeviceIndex current_device() const { + return current_device_index_; + } + + private: + /// The original device before this guard was created + DeviceIndex original_device_index_; + /// The current device managed by this guard + DeviceIndex current_device_index_; +}; + +/** + * RAII guard that sets the current CUDA device and stream, restoring both on + * destruction. This is useful for temporarily switching to a different device + * and stream. + * + * NOTE: Do not use constructors directly. Use the create() factory method + * instead. + */ +class CUDAStreamGuard { + private: + // Private constructor that takes a CUDAGuard + explicit CUDAStreamGuard(CUDAGuard&& guard) + : device_guard_(std::move(guard)), + original_stream_(nullptr), + current_stream_(nullptr), + device_index_(-1) {} + + public: + /** + * Factory method to create a CUDAStreamGuard. + * + * @param stream The CUDA stream to set as current + * @param device_index The device index for the stream + * @return Result containing the guard on success, or an error code on failure + */ + static Result create( + cudaStream_t stream, + DeviceIndex device_index); + + // Copy is not allowed + CUDAStreamGuard(const CUDAStreamGuard&) = delete; + CUDAStreamGuard& operator=(const CUDAStreamGuard&) = delete; + + // Move constructor and assignment + CUDAStreamGuard(CUDAStreamGuard&& other) noexcept; + CUDAStreamGuard& operator=(CUDAStreamGuard&& other) noexcept = delete; + + /** + * Destructor that restores the original stream and device. + */ + ~CUDAStreamGuard(); + + /** + * Sets the CUDA stream to the given stream on the specified device. + * + * @param stream The CUDA stream to set as current + * @param device_index The device index for the stream + * @return Error code indicating success or failure + */ + Error set_stream(cudaStream_t stream, DeviceIndex device_index); + + /** + * Get the current guarded stream. + * + * @return The current stream + */ + cudaStream_t stream() const { + return current_stream_; + } + + /** + * Get the device index being guarded. + * + * @return The device index + */ + DeviceIndex device_index() const { + return device_index_; + } + + private: + /// The device guard that handles device switching + CUDAGuard device_guard_; + /// The original stream that was current before this guard + cudaStream_t original_stream_ = nullptr; + /// The current stream being guarded + cudaStream_t current_stream_ = nullptr; + /// The device index for this stream guard + DeviceIndex device_index_; +}; + +} // namespace cuda +} // namespace backends +} // namespace executorch diff --git a/backends/cuda/runtime/shims/cuda_guard.cpp b/backends/cuda/runtime/shims/cuda_guard.cpp new file mode 100644 index 00000000000..5740d0bf654 --- /dev/null +++ b/backends/cuda/runtime/shims/cuda_guard.cpp @@ -0,0 +1,109 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#include + +namespace executorch { +namespace backends { +namespace cuda { + +extern "C" { + +AOTITorchError aoti_torch_create_cuda_guard( + int32_t device_index, + CUDAGuardHandle* ret_guard) { + ET_CHECK_OR_RETURN_ERROR( + ret_guard != nullptr, + InvalidArgument, + "aoti_torch_create_cuda_guard failed: ret_guard is null"); + + auto result = CUDAGuard::create(device_index); + if (!result.ok()) { + return result.error(); + } + *ret_guard = new CUDAGuard(std::move(result.get())); + return Error::Ok; +} + +AOTITorchError aoti_torch_delete_cuda_guard(CUDAGuardHandle guard) { + ET_CHECK_OR_RETURN_ERROR( + guard != nullptr, + InvalidArgument, + "aoti_torch_delete_cuda_guard failed: guard is null"); + + delete guard; + return Error::Ok; +} + +AOTITorchError aoti_torch_cuda_guard_set_index( + CUDAGuardHandle guard, + int32_t device_index) { + ET_CHECK_OR_RETURN_ERROR( + guard != nullptr, + InvalidArgument, + "aoti_torch_cuda_guard_set_index failed: guard is null"); + + ET_CHECK_OK_OR_RETURN_ERROR(guard->set_index(device_index)); + return Error::Ok; +} + +AOTITorchError aoti_torch_create_cuda_stream_guard( + void* stream, + int32_t device_index, + CUDAStreamGuardHandle* ret_guard) { + ET_CHECK_OR_RETURN_ERROR( + ret_guard != nullptr, + InvalidArgument, + "aoti_torch_create_cuda_stream_guard failed: ret_guard is null"); + + ET_CHECK_OR_RETURN_ERROR( + stream != nullptr, + InvalidArgument, + "aoti_torch_create_cuda_stream_guard failed: stream is null"); + + auto result = + CUDAStreamGuard::create(static_cast(stream), device_index); + if (!result.ok()) { + return result.error(); + } + *ret_guard = new CUDAStreamGuard(std::move(result.get())); + return Error::Ok; +} + +AOTITorchError aoti_torch_delete_cuda_stream_guard( + CUDAStreamGuardHandle guard) { + ET_CHECK_OR_RETURN_ERROR( + guard != nullptr, + InvalidArgument, + "aoti_torch_delete_cuda_stream_guard failed: guard is null"); + + delete guard; + return Error::Ok; +} + +AOTITorchError aoti_torch_get_current_cuda_stream( + int32_t device_index, + void** ret_stream) { + ET_CHECK_OR_RETURN_ERROR( + ret_stream != nullptr, + InvalidArgument, + "aoti_torch_get_current_cuda_stream failed: ret_stream is null"); + + auto result = getCurrentCUDAStream(device_index); + if (!result.ok()) { + return result.error(); + } + *ret_stream = static_cast(result.get()); + return Error::Ok; +} + +} // extern "C" + +} // namespace cuda +} // namespace backends +} // namespace executorch diff --git a/backends/cuda/runtime/shims/cuda_guard.h b/backends/cuda/runtime/shims/cuda_guard.h new file mode 100644 index 00000000000..6da869064a7 --- /dev/null +++ b/backends/cuda/runtime/shims/cuda_guard.h @@ -0,0 +1,104 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include +#include +#include + +namespace executorch { +namespace backends { +namespace cuda { + +using executorch::backends::aoti::AOTITorchError; + +extern "C" { + +// Handle types for CUDA guards +using CUDAGuardHandle = CUDAGuard*; +using CUDAStreamGuardHandle = CUDAStreamGuard*; + +/** + * Creates a CUDA device guard that sets the current device and restores it + * upon destruction. 
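+ * Usage sketch (hypothetical caller, error handling elided):
+ *   CUDAGuardHandle guard = nullptr;
+ *   aoti_torch_create_cuda_guard(0, &guard);   // switch to device 0
+ *   ...                                        // run work on device 0
+ *   aoti_torch_delete_cuda_guard(guard);       // restore the previous device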
+ * + * @param device_index The device index to set as current + * @param ret_guard Output parameter for the created guard handle (must not be + * null) + * @return AOTITorchError error code (Error::Ok on success, or an error code on + * failure) + */ +AOTITorchError aoti_torch_create_cuda_guard( + int32_t device_index, + CUDAGuardHandle* ret_guard); + +/** + * Deletes a CUDA device guard and frees its associated resources. + * + * @param guard Handle to the guard to be deleted + * @return AOTITorchError error code (Error::Ok on success, or an error code on + * failure) + */ +AOTITorchError aoti_torch_delete_cuda_guard(CUDAGuardHandle guard); + +/** + * Sets the CUDA device to a new index for an existing guard. + * + * @param guard Handle to the guard + * @param device_index The device index to set as current + * @return AOTITorchError error code (Error::Ok on success, or an error code on + * failure) + */ +AOTITorchError aoti_torch_cuda_guard_set_index( + CUDAGuardHandle guard, + int32_t device_index); + +/** + * Creates a CUDA stream guard that sets the current device and stream, + * restoring both upon destruction. + * + * @param stream The CUDA stream to set as current + * @param device_index The device index for the stream + * @param ret_guard Output parameter for the created guard handle (must not be + * null) + * @return AOTITorchError error code (Error::Ok on success, or an error code on + * failure) + */ +AOTITorchError aoti_torch_create_cuda_stream_guard( + void* stream, + int32_t device_index, + CUDAStreamGuardHandle* ret_guard); + +/** + * Deletes a CUDA stream guard and frees its associated resources. + * + * @param guard Handle to the stream guard to be deleted + * @return AOTITorchError error code (Error::Ok on success, or an error code on + * failure) + */ +AOTITorchError aoti_torch_delete_cuda_stream_guard(CUDAStreamGuardHandle guard); + +/** + * Gets the current CUDA stream for a specified device. 
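+ * Usage sketch (hypothetical caller, error handling elided):
+ *   void* raw_stream = nullptr;
+ *   aoti_torch_get_current_cuda_stream(0, &raw_stream);
+ *   cudaStream_t stream = static_cast<cudaStream_t>(raw_stream);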
+ * + * @param device_index The device index (-1 to use current device) + * @param ret_stream Output parameter for the current stream (must not be null) + * @return AOTITorchError error code (Error::Ok on success, or an error code on + * failure) + */ +AOTITorchError aoti_torch_get_current_cuda_stream( + int32_t device_index, + void** ret_stream); + +} // extern "C" + +} // namespace cuda +} // namespace backends +} // namespace executorch diff --git a/backends/cuda/runtime/shims/memory.cpp b/backends/cuda/runtime/shims/memory.cpp index 2b32d820301..4350cac0ff8 100644 --- a/backends/cuda/runtime/shims/memory.cpp +++ b/backends/cuda/runtime/shims/memory.cpp @@ -10,7 +10,7 @@ #include #include #include -#include +#include #include #include #include // For posix_memalign @@ -271,10 +271,16 @@ void clear_all_tensors() { // Use aoti_torch_delete_tensor_object to properly delete each tensor // Note: We need to collect tensor pointers first since deletion modifies the // set - auto old_tensors = - std::move(tensors); // tensors is now empty and no need to copy - for (const auto& tensor_shared : old_tensors) { - aoti_torch_delete_tensor_object(tensor_shared.get()); + ET_LOG(Info, "Clearing all tensors..."); + std::vector tensor_ptrs; + tensor_ptrs.reserve(tensors.size()); + for (const auto& tensor_shared : tensors) { + tensor_ptrs.push_back(tensor_shared.get()); + } + + // Now delete each tensor - this will modify the global tensors set + for (Tensor* tensor_ptr : tensor_ptrs) { + aoti_torch_delete_tensor_object(tensor_ptr); } // tensors set should now be empty, but ensure it's cleared @@ -308,42 +314,48 @@ AOTITorchError aoti_torch_delete_tensor_object(Tensor* tensor) { // Find the reference count for this memory address auto memory_it = memory_to_n_tensor.find(data_ptr); - if (memory_it != memory_to_n_tensor.end()) { - int32_t ref_count = memory_it->second; - - if (ref_count == NOT_OWN) { - // Tensor never owned the memory, skip freeing - // Just remove tensor from tracking - tensors.erase(it); - return Error::Ok; - } else if (ref_count == 1) { - // Only current tensor using this memory, free it - // Determine if it's GPU memory - cudaPointerAttributes attributes{}; - ET_CUDA_CHECK_OR_RETURN_ERROR( - cudaPointerGetAttributes(&attributes, data_ptr)); - - if (attributes.type == cudaMemoryTypeManaged) { - // This is CUDA managed memory - free with proper synchronization - ET_CUDA_CHECK_OR_RETURN_ERROR(cudaDeviceSynchronize()); - ET_CUDA_CHECK_OR_RETURN_ERROR(cudaFree(data_ptr)); - } else { - // This is CPU memory - free immediately - free(data_ptr); - data_ptr = nullptr; - } - - // Remove from memory tracking - memory_to_n_tensor.erase(memory_it); - } else if (ref_count > 1) { - // Other tensors still using this memory, just decrement count - memory_to_n_tensor[data_ptr] = ref_count - 1; + + ET_CHECK_OR_RETURN_ERROR( + memory_it != memory_to_n_tensor.end(), + Internal, + "Internal error: memory not found during deletion"); + + int32_t ref_count = memory_it->second; + + ET_CHECK_OR_RETURN_ERROR( + ref_count >= 0 || ref_count == NOT_OWN, + Internal, + "Internal error: invalid ref count %d", + ref_count) + + if (ref_count == NOT_OWN) { + // Tensor never owned the memory, skip freeing + // Just remove tensor from tracking + tensors.erase(it); + return Error::Ok; + } else if (ref_count == 1) { + // Only current tensor using this memory, free it + // Determine if it's GPU memory + cudaPointerAttributes attributes{}; + ET_CUDA_CHECK_OR_RETURN_ERROR( + cudaPointerGetAttributes(&attributes, data_ptr)); + + 
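+      // cudaPointerGetAttributes reports cudaMemoryTypeManaged for
+      // cudaMallocManaged allocations; any other type is treated as host
+      // memory below and released with free().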
if (attributes.type == cudaMemoryTypeManaged) { + // This is CUDA managed memory - free with proper synchronization + ET_CUDA_CHECK_OR_RETURN_ERROR(cudaDeviceSynchronize()); + ET_CUDA_CHECK_OR_RETURN_ERROR(cudaFree(data_ptr)); + } else { + // This is CPU memory - free immediately + free(data_ptr); + data_ptr = nullptr; } + + // Remove from memory tracking + memory_to_n_tensor.erase(memory_it); } else { - ET_CHECK_OR_RETURN_ERROR( - false, - Internal, - "Internal error: memory not found during deletion"); + // ref_count > 1 + // Other tensors still using this memory, just decrement count + memory_to_n_tensor[data_ptr] = ref_count - 1; } // Remove tensor from set (this will call the destructor if it's the last @@ -379,7 +391,6 @@ aoti_torch_copy_(Tensor* self, Tensor* src, int32_t non_blocking) { aoti_torch_get_dtype(src, &src_dtype); ET_CHECK_OK_OR_RETURN_ERROR(validate_dtype(self_dtype)); - ET_CHECK_OK_OR_RETURN_ERROR(validate_dtype(src_dtype)); // Check dtype compatibility - both tensors must have the same dtype diff --git a/backends/cuda/runtime/shims/tests/test_aoti_torch__reinterpret_tensor.cpp b/backends/cuda/runtime/shims/tests/test_aoti_torch__reinterpret_tensor.cpp index ef00ecff656..e18bf142b5c 100644 --- a/backends/cuda/runtime/shims/tests/test_aoti_torch__reinterpret_tensor.cpp +++ b/backends/cuda/runtime/shims/tests/test_aoti_torch__reinterpret_tensor.cpp @@ -10,7 +10,7 @@ #include #include #include -#include +#include #include #include #include diff --git a/backends/cuda/runtime/shims/tests/test_aoti_torch_copy_.cpp b/backends/cuda/runtime/shims/tests/test_aoti_torch_copy_.cpp index 7579eaef039..9fca0f92cf8 100644 --- a/backends/cuda/runtime/shims/tests/test_aoti_torch_copy_.cpp +++ b/backends/cuda/runtime/shims/tests/test_aoti_torch_copy_.cpp @@ -10,7 +10,7 @@ #include #include #include -#include +#include #include #include #include diff --git a/backends/cuda/runtime/shims/tests/test_aoti_torch_create_tensor_from_blob_v2.cpp b/backends/cuda/runtime/shims/tests/test_aoti_torch_create_tensor_from_blob_v2.cpp index 2cb12719782..d9b785a5a78 100644 --- a/backends/cuda/runtime/shims/tests/test_aoti_torch_create_tensor_from_blob_v2.cpp +++ b/backends/cuda/runtime/shims/tests/test_aoti_torch_create_tensor_from_blob_v2.cpp @@ -10,7 +10,7 @@ #include #include #include -#include +#include #include #include #include diff --git a/backends/cuda/runtime/shims/tests/test_aoti_torch_delete_tensor_object.cpp b/backends/cuda/runtime/shims/tests/test_aoti_torch_delete_tensor_object.cpp index eceb141e9ca..10c8d8c1a31 100644 --- a/backends/cuda/runtime/shims/tests/test_aoti_torch_delete_tensor_object.cpp +++ b/backends/cuda/runtime/shims/tests/test_aoti_torch_delete_tensor_object.cpp @@ -10,7 +10,7 @@ #include #include #include -#include +#include #include #include #include diff --git a/backends/cuda/runtime/shims/tests/test_aoti_torch_empty_strided.cpp b/backends/cuda/runtime/shims/tests/test_aoti_torch_empty_strided.cpp index 8e6998f457c..da65129f18a 100644 --- a/backends/cuda/runtime/shims/tests/test_aoti_torch_empty_strided.cpp +++ b/backends/cuda/runtime/shims/tests/test_aoti_torch_empty_strided.cpp @@ -10,7 +10,7 @@ #include #include #include -#include +#include #include #include #include diff --git a/backends/cuda/runtime/shims/utils.h b/backends/cuda/runtime/utils.h similarity index 94% rename from backends/cuda/runtime/shims/utils.h rename to backends/cuda/runtime/utils.h index 99d2bc102f5..02c3abfc83f 100644 --- a/backends/cuda/runtime/shims/utils.h +++ b/backends/cuda/runtime/utils.h @@ 
-40,6 +40,7 @@ namespace cuda { // Enum for supported data types in et-cuda backend enum class SupportedDTypes : int32_t { + INT64 = 4, // PyTorch's int64 dtype code FLOAT32 = 6, // PyTorch's float32 dtype code BFLOAT16 = 15, // PyTorch's bfloat16 dtype code }; @@ -100,6 +101,7 @@ using AOTITorchError = Error; // Helper function to check if a dtype is supported in ET CUDA backend inline bool is_dtype_supported_in_et_cuda(int32_t dtype) { switch (dtype) { + case static_cast(SupportedDTypes::INT64): case static_cast(SupportedDTypes::FLOAT32): case static_cast(SupportedDTypes::BFLOAT16): return true; @@ -113,8 +115,9 @@ inline AOTITorchError validate_dtype(int32_t dtype) { ET_CHECK_OR_RETURN_ERROR( is_dtype_supported_in_et_cuda(dtype), InvalidArgument, - "Unsupported dtype: %d. Supported dtypes: %d (float32), %d (bfloat16)", + "Unsupported dtype: %d. Supported dtypes: %d (int64), %d (float32), %d (bfloat16)", dtype, + static_cast(SupportedDTypes::INT64), static_cast(SupportedDTypes::FLOAT32), static_cast(SupportedDTypes::BFLOAT16)); diff --git a/backends/cuda/tests/voxtral_runner.cpp b/backends/cuda/tests/voxtral_runner.cpp new file mode 100644 index 00000000000..feed458e1f5 --- /dev/null +++ b/backends/cuda/tests/voxtral_runner.cpp @@ -0,0 +1,264 @@ +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +namespace { + +using executorch::aten::ScalarType; +using executorch::aten::Tensor; +using executorch::extension::make_tensor_ptr; +using executorch::extension::TensorPtr; +using executorch::extension::module::Module; +using executorch::runtime::Error; +using executorch::runtime::EValue; +using executorch::runtime::Result; +using Clock = std::chrono::steady_clock; +using DurationMs = std::chrono::duration; + +std::vector to_sizes( + std::initializer_list dims) { + return std::vector(dims.begin(), dims.end()); +} + +std::string format_shape(const Tensor& tensor) { + std::ostringstream oss; + oss << "["; + const auto& sizes = tensor.sizes(); + for (size_t i = 0; i < sizes.size(); ++i) { + if (i > 0) { + oss << ", "; + } + oss << sizes[i]; + } + oss << "]"; + return oss.str(); +} + +void print_tensor_summary(const std::string& label, const Tensor& tensor) { + std::cout << " " << label + << ": dtype=" << executorch::runtime::toString(tensor.scalar_type()) + << ", shape=" << format_shape(tensor) + << ", numel=" << tensor.numel() << std::endl; +} + +TensorPtr create_audio_input() { + const auto sizes = to_sizes({3, 128, 3000}); + const size_t numel = 3ull * 128ull * 3000ull; + std::vector data(numel, 0.5f); + return make_tensor_ptr( + sizes, std::move(data), {}, {}, ScalarType::BFloat16); +} + +TensorPtr create_token_ids_input() { + const auto sizes = to_sizes({1, 1138}); + std::vector data(static_cast(1) * 1138, 0); + return make_tensor_ptr(sizes, std::move(data)); +} + +TensorPtr create_positions_input() { + const auto sizes = to_sizes({1138}); + std::vector data(static_cast(1138), 0); + return make_tensor_ptr(sizes, std::move(data)); +} + +TensorPtr create_fallback_text_embedding() { + const auto sizes = to_sizes({1, 1138, 3072}); + const size_t numel = 1ull * 1138ull * 3072ull; + std::vector data(numel, 0.0f); + return make_tensor_ptr( + sizes, std::move(data), {}, {}, ScalarType::BFloat16); +} + +struct MethodTiming { + double load_ms{0.0}; + double run_ms{0.0}; +}; + +} // namespace + +int main(int argc, char** argv) { + if (argc != 3) { + std::cerr << "Usage: " << argv[0] + << " " + << std::endl; + return 1; + } 
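+  // argv[1]: path to the exported program (typically a .pte file);
+  // argv[2]: path to the named-data file carrying the AOTI shared-object blobs.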
+ + const std::string program_path = argv[1]; + const std::string data_map_path = argv[2]; + + try { + Module module(program_path, data_map_path); + + const auto program_load_start = Clock::now(); + const Error program_load_error = module.load(); + const auto program_load_end = Clock::now(); + if (program_load_error != Error::Ok) { + std::cerr << "Failed to load ExecuTorch program: error code " + << static_cast(program_load_error) << std::endl; + return 1; + } + const DurationMs program_load_latency = + program_load_end - program_load_start; + + MethodTiming audio_timing; + MethodTiming token_timing; + MethodTiming text_timing; + + auto measure_method_load = + [&](const std::string& name) -> std::pair { + const auto start = Clock::now(); + const Error err = module.load_method(name); + const auto end = Clock::now(); + return {err, DurationMs(end - start).count()}; + }; + + // audio_encoder + { + const auto [err, load_ms] = measure_method_load("audio_encoder"); + if (err != Error::Ok) { + std::cerr << "Failed to load method audio_encoder: error code " + << static_cast(err) << std::endl; + return 1; + } + audio_timing.load_ms = load_ms; + + const TensorPtr audio_input = create_audio_input(); + std::vector inputs; + std::vector owned_inputs; + owned_inputs.emplace_back(audio_input); + inputs.emplace_back(*audio_input); + + const auto run_start = Clock::now(); + Result> output_result = + module.execute("audio_encoder", inputs); + const auto run_end = Clock::now(); + audio_timing.run_ms = DurationMs(run_end - run_start).count(); + + if (output_result.error() != Error::Ok) { + std::cerr << "audio_encoder execution failed: error code " + << static_cast(output_result.error()) << std::endl; + return 1; + } + + const auto& outputs = output_result.get(); + if (!outputs.empty() && outputs[0].isTensor()) { + print_tensor_summary("audio_encoder output", outputs[0].toTensor()); + } + } + + EValue token_output; + bool token_executed = false; + + // token_embedding + { + const auto [err, load_ms] = measure_method_load("token_embedding"); + if (err != Error::Ok) { + std::cerr << "Failed to load method token_embedding: error code " + << static_cast(err) << std::endl; + return 1; + } + token_timing.load_ms = load_ms; + + const TensorPtr token_ids = create_token_ids_input(); + std::vector inputs; + std::vector owned_inputs; + owned_inputs.emplace_back(token_ids); + inputs.emplace_back(*token_ids); + + const auto run_start = Clock::now(); + auto token_output_result = module.execute("token_embedding", inputs); + const auto run_end = Clock::now(); + token_timing.run_ms = DurationMs(run_end - run_start).count(); + + if (token_output_result.error() != Error::Ok) { + std::cerr << "token_embedding execution failed: error code " + << static_cast(token_output_result.error()) << std::endl; + return 1; + } + + token_executed = true; + const auto& outputs = token_output_result.get(); + if (!outputs.empty() && outputs[0].isTensor()) { + print_tensor_summary("token_embedding output", outputs[0].toTensor()); + token_output = outputs[0]; + } + } + + // text_decoder + { + const auto [err, load_ms] = measure_method_load("text_decoder"); + if (err != Error::Ok) { + std::cerr << "Failed to load method text_decoder: error code " + << static_cast(err) << std::endl; + return 1; + } + text_timing.load_ms = load_ms; + + std::vector inputs; + std::vector owned_inputs; + if (token_executed) { + if (token_output.isTensor()) { + inputs.emplace_back(token_output); + } + } + + if (inputs.empty()) { + auto fallback_embedding = 
create_fallback_text_embedding(); + owned_inputs.emplace_back(fallback_embedding); + inputs.emplace_back(*fallback_embedding); + } + + auto positions = create_positions_input(); + owned_inputs.emplace_back(positions); + inputs.emplace_back(*positions); + + const auto run_start = Clock::now(); + Result> output_result = + module.execute("text_decoder", inputs); + const auto run_end = Clock::now(); + text_timing.run_ms = DurationMs(run_end - run_start).count(); + + if (output_result.error() != Error::Ok) { + std::cerr << "text_decoder execution failed: error code " + << static_cast(output_result.error()) << std::endl; + return 1; + } + + const auto& outputs = output_result.get(); + if (!outputs.empty() && outputs[0].isTensor()) { + print_tensor_summary("text_decoder output", outputs[0].toTensor()); + } + } + + std::cout << std::fixed << std::setprecision(3); + std::cout << "Program load latency (ms): " << program_load_latency.count() + << std::endl; + + std::cout << "Method load latency (ms):" << std::endl; + std::cout << " audio_encoder: " << audio_timing.load_ms << std::endl; + std::cout << " token_embedding: " << token_timing.load_ms << std::endl; + std::cout << " text_decoder: " << text_timing.load_ms << std::endl; + + std::cout << "Run latency (ms):" << std::endl; + std::cout << " audio_encoder: " << audio_timing.run_ms << std::endl; + std::cout << " token_embedding: " << token_timing.run_ms << std::endl; + std::cout << " text_decoder: " << text_timing.run_ms << std::endl; + + return 0; + } catch (const std::exception& ex) { + std::cerr << "Unhandled exception: " << ex.what() << std::endl; + return 1; + } +} diff --git a/extension/llm/runner/pybindings.cpp b/extension/llm/runner/pybindings.cpp index bcc6aba0f8e..08051515d8d 100644 --- a/extension/llm/runner/pybindings.cpp +++ b/extension/llm/runner/pybindings.cpp @@ -644,4 +644,4 @@ PYBIND11_MODULE(_llm_runner, m) { .def("__repr__", [](const PyMultimodalRunner& runner) { return ""; }); -} \ No newline at end of file +} diff --git a/tools/cmake/preset/default.cmake b/tools/cmake/preset/default.cmake index fb0dc0a4ade..32043d4d427 100644 --- a/tools/cmake/preset/default.cmake +++ b/tools/cmake/preset/default.cmake @@ -129,6 +129,9 @@ define_overridable_option( define_overridable_option( EXECUTORCH_BUILD_SIZE_TEST "Build the size test" BOOL OFF ) +define_overridable_option( + EXECUTORCH_BUILD_CUDA "Build the CUDA backend" BOOL OFF +) define_overridable_option( EXECUTORCH_BUILD_XNNPACK "Build the XNNPACK backend" BOOL OFF )