diff --git a/.ci/docker/ci_commit_pins/pytorch.txt b/.ci/docker/ci_commit_pins/pytorch.txt index 8c9330d6f2c..e3a53c8bcb5 100644 --- a/.ci/docker/ci_commit_pins/pytorch.txt +++ b/.ci/docker/ci_commit_pins/pytorch.txt @@ -1 +1 @@ -4d4abec80f03cd8fdefe1d9cb3a60d3690cd777e +53a2908a10f414a2f85caa06703a26a40e873869 diff --git a/.ci/scripts/test-cuda-build.sh b/.ci/scripts/test-cuda-build.sh new file mode 100755 index 00000000000..a9f8e7ec14f --- /dev/null +++ b/.ci/scripts/test-cuda-build.sh @@ -0,0 +1,98 @@ +#!/bin/bash +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +set -exu + +# shellcheck source=/dev/null +source "$(dirname "${BASH_SOURCE[0]}")/utils.sh" + +CUDA_VERSION=${1:-"12.6"} + +echo "=== Testing ExecutorTorch CUDA ${CUDA_VERSION} Build ===" + +# Function to build and test ExecutorTorch with CUDA support +test_executorch_cuda_build() { + local cuda_version=$1 + + echo "Building ExecutorTorch with CUDA ${cuda_version} support..." + echo "ExecutorTorch will automatically detect CUDA and install appropriate PyTorch wheel" + + # Check available resources before starting + echo "=== System Information ===" + echo "Available memory: $(free -h | grep Mem | awk '{print $2}')" + echo "Available disk space: $(df -h . | tail -1 | awk '{print $4}')" + echo "CPU cores: $(nproc)" + echo "CUDA version check:" + nvcc --version || echo "nvcc not found" + nvidia-smi || echo "nvidia-smi not found" + + # Set CMAKE_ARGS to enable CUDA build - ExecutorTorch will handle PyTorch installation automatically + export CMAKE_ARGS="-DEXECUTORCH_BUILD_CUDA=ON" + + echo "=== Starting ExecutorTorch Installation ===" + # Install ExecutorTorch with CUDA support with timeout and error handling + timeout 5400 ./install_executorch.sh || { + local exit_code=$? 
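+        # timeout(1) exits with status 124 when the 5400 s (90 min) limit is hit; any other code comes from install_executorch.sh itself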
+ echo "ERROR: install_executorch.sh failed with exit code: $exit_code" + if [ $exit_code -eq 124 ]; then + echo "ERROR: Installation timed out after 90 minutes" + fi + exit $exit_code + } + + echo "SUCCESS: ExecutorTorch CUDA build completed" + + # Verify the installation + echo "=== Verifying ExecutorTorch CUDA Installation ===" + + # Test that ExecutorTorch was built successfully + python -c " +import executorch +print('SUCCESS: ExecutorTorch imported successfully') +" + + # Test CUDA availability and show details + python -c " +try: + import torch + print('INFO: PyTorch version:', torch.__version__) + print('INFO: CUDA available:', torch.cuda.is_available()) + + if torch.cuda.is_available(): + print('SUCCESS: CUDA is available for ExecutorTorch') + print('INFO: CUDA version:', torch.version.cuda) + print('INFO: GPU device count:', torch.cuda.device_count()) + print('INFO: Current GPU device:', torch.cuda.current_device()) + print('INFO: GPU device name:', torch.cuda.get_device_name()) + + # Test basic CUDA tensor operation + device = torch.device('cuda') + x = torch.randn(10, 10).to(device) + y = torch.randn(10, 10).to(device) + z = torch.mm(x, y) + print('SUCCESS: CUDA tensor operation completed on device:', z.device) + print('INFO: Result tensor shape:', z.shape) + + print('SUCCESS: ExecutorTorch CUDA integration verified') + else: + print('WARNING: CUDA not detected, but ExecutorTorch built successfully') + exit(1) +except Exception as e: + print('ERROR: ExecutorTorch CUDA test failed:', e) + exit(1) +" + + echo "SUCCESS: ExecutorTorch CUDA ${cuda_version} build and verification completed successfully" +} + +# Main execution +echo "Current working directory: $(pwd)" +echo "Directory contents:" +ls -la + +# Run the CUDA build test +test_executorch_cuda_build "${CUDA_VERSION}" diff --git a/.ci/scripts/test-cuda-export-aoti.sh b/.ci/scripts/test-cuda-export-aoti.sh new file mode 100755 index 00000000000..6ea701b8f4b --- /dev/null +++ b/.ci/scripts/test-cuda-export-aoti.sh @@ -0,0 +1,105 @@ +#!/bin/bash +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +set -exu + +# shellcheck source=/dev/null +source "$(dirname "${BASH_SOURCE[0]}")/utils.sh" + +CUDA_VERSION=${1:-"12.6"} + +echo "=== Testing ExecutorTorch CUDA AOTI Export ${CUDA_VERSION} ===" + +# Function to test CUDA AOTI export functionality +test_cuda_aoti_export() { + local cuda_version=$1 + + echo "Testing CUDA AOTI export with CUDA ${cuda_version} support..." + + # Check available resources before starting + echo "=== System Information ===" + echo "Available memory: $(free -h | grep Mem | awk '{print $2}')" + echo "Available disk space: $(df -h . | tail -1 | awk '{print $4}')" + echo "CPU cores: $(nproc)" + echo "CUDA version check:" + nvcc --version || echo "nvcc not found" + nvidia-smi || echo "nvidia-smi not found" + + # Set up environment for CUDA builds + export CMAKE_ARGS="-DEXECUTORCH_BUILD_CUDA=ON" + + echo "=== Installing ExecutorTorch with CUDA support ===" + # Install ExecutorTorch with CUDA support with timeout and error handling + timeout 5400 ./install_executorch.sh || { + local exit_code=$? 
+ echo "ERROR: install_executorch.sh failed with exit code: $exit_code" + if [ $exit_code -eq 124 ]; then + echo "ERROR: Installation timed out after 90 minutes" + fi + exit $exit_code + } + + echo "SUCCESS: ExecutorTorch CUDA installation completed" + + # Verify the installation + echo "=== Verifying ExecutorTorch CUDA Installation ===" + + # Test that ExecutorTorch was built successfully + python -c " +import executorch +print('SUCCESS: ExecutorTorch imported successfully') +" + + # Test CUDA availability and show details + python -c " +try: + import torch + print('INFO: PyTorch version:', torch.__version__) + print('INFO: CUDA available:', torch.cuda.is_available()) + + if torch.cuda.is_available(): + print('SUCCESS: CUDA is available for ExecutorTorch') + print('INFO: CUDA version:', torch.version.cuda) + print('INFO: GPU device count:', torch.cuda.device_count()) + print('INFO: Current GPU device:', torch.cuda.current_device()) + print('INFO: GPU device name:', torch.cuda.get_device_name()) + + # Test basic CUDA tensor operation + device = torch.device('cuda') + x = torch.randn(10, 10).to(device) + y = torch.randn(10, 10).to(device) + z = torch.mm(x, y) + print('SUCCESS: CUDA tensor operation completed on device:', z.device) + print('INFO: Result tensor shape:', z.shape) + + print('SUCCESS: ExecutorTorch CUDA integration verified') + else: + print('WARNING: CUDA not detected, but ExecutorTorch built successfully') + exit(1) +except Exception as e: + print('ERROR: ExecutorTorch CUDA test failed:', e) + exit(1) +" + + echo "=== Running CUDA AOTI Export Tests ===" + # Run the CUDA AOTI export tests using the Python script + python .ci/scripts/test_cuda_export_aoti.py \ + --models linear conv2d add resnet18 \ + --export-mode export_aoti_only \ + --timeout 600 \ + --cleanup + + echo "SUCCESS: ExecutorTorch CUDA AOTI export ${cuda_version} tests completed successfully" +} + +# Main execution +echo "Current working directory: $(pwd)" +echo "Directory contents:" +ls -la + +# Run the CUDA AOTI export test +test_cuda_aoti_export "${CUDA_VERSION}" diff --git a/.ci/scripts/test_cuda_export_aoti.py b/.ci/scripts/test_cuda_export_aoti.py new file mode 100755 index 00000000000..3748dc5fe33 --- /dev/null +++ b/.ci/scripts/test_cuda_export_aoti.py @@ -0,0 +1,228 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Test script for CUDA AOTI export functionality. +This script tests basic CUDA export functionality for a subset of models: +linear, conv2d, add, and resnet18. 
+""" + +import argparse +import os +import subprocess +import sys +from typing import List, Optional + + +def run_command( + cmd: List[str], cwd: Optional[str] = None, timeout: int = 300 +) -> subprocess.CompletedProcess: + """Run a command with proper error handling and timeout.""" + print(f"Running command: {' '.join(cmd)}") + if cwd: + print(f"Working directory: {cwd}") + + try: + result = subprocess.run( + cmd, + cwd=cwd, + capture_output=True, + text=True, + timeout=timeout, + check=False, # We'll handle the return code ourselves + ) + + if result.stdout: + print("STDOUT:") + print(result.stdout) + if result.stderr: + print("STDERR:") + print(result.stderr) + + return result + except subprocess.TimeoutExpired as e: + print(f"ERROR: Command timed out after {timeout} seconds") + raise e + except Exception as e: + print(f"ERROR: Failed to run command: {e}") + raise e + + +def test_cuda_export( + model_name: str, export_mode: str = "export_aoti_only", timeout: int = 300 +) -> bool: + """Test CUDA export for a specific model.""" + print(f"\n{'='*60}") + print(f"Testing CUDA export for model: {model_name}") + print(f"Export mode: {export_mode}") + print(f"{'='*60}") + + try: + # Run the export using export_aoti.py + cmd = ["python", "export_aoti.py", model_name] + if export_mode == "export_aoti_only": + cmd.append("--aoti_only") + + result = run_command(cmd, timeout=timeout) + + if result.returncode == 0: + print(f"SUCCESS: {model_name} export completed successfully") + return True + else: + print( + f"ERROR: {model_name} export failed with return code {result.returncode}" + ) + return False + + except subprocess.TimeoutExpired: + print(f"ERROR: {model_name} export timed out after {timeout} seconds") + return False + except Exception as e: + print(f"ERROR: {model_name} export failed with exception: {e}") + return False + + +def cleanup_temp_files(): + """Clean up temporary files generated during export.""" + print("Cleaning up temporary files...") + + # List of file patterns to clean up + cleanup_patterns = [ + "*.cubin", + "*.pte", + "*.so", + "*kernel_metadata.json", + "*kernel.cpp", + "*wrapper_metadata.json", + "*wrapper.cpp", + "*wrapper.json", + "aoti_intermediate_output.txt", + ] + + # Remove files matching patterns + for pattern in cleanup_patterns: + try: + import glob + + files = glob.glob(pattern) + for file in files: + if os.path.isfile(file): + os.remove(file) + print(f"Removed file: {file}") + except Exception as e: + print(f"Warning: Failed to remove {pattern}: {e}") + + # Remove temporary directories created by wrappers + try: + import glob + + for wrapper_file in glob.glob("*wrapper.cpp"): + basename = wrapper_file.replace("wrapper.cpp", "") + if os.path.isdir(basename): + import shutil + + shutil.rmtree(basename) + print(f"Removed directory: {basename}") + except Exception as e: + print(f"Warning: Failed to remove wrapper directories: {e}") + + print("Cleanup completed.") + + +def main(): + """Main function to test CUDA export for specified models.""" + parser = argparse.ArgumentParser( + description="Test CUDA AOTI export functionality", + formatter_class=argparse.RawDescriptionHelpFormatter, + ) + + parser.add_argument( + "--models", + nargs="+", + default=["linear", "conv2d", "add", "resnet18"], + help="List of models to test (default: linear, conv2d, add, resnet18)", + ) + + parser.add_argument( + "--export-mode", + choices=["export_aoti_only", "full"], + default="export_aoti_only", + help="Export mode: export_aoti_only (AOTI only) or full (full pipeline)", + ) + + 
parser.add_argument( + "--timeout", + type=int, + default=300, + help="Timeout for each model export in seconds (default: 300)", + ) + + parser.add_argument( + "--cleanup", + action="store_true", + default=True, + help="Clean up temporary files after testing (default: True)", + ) + + args = parser.parse_args() + + print("CUDA AOTI Export Test") + print("=" * 60) + print(f"Models to test: {args.models}") + print(f"Export mode: {args.export_mode}") + print(f"Timeout per model: {args.timeout} seconds") + print(f"Cleanup enabled: {args.cleanup}") + print("=" * 60) + + # Check if we're in the correct directory (should have export_aoti.py) + if not os.path.exists("export_aoti.py"): + print("ERROR: export_aoti.py not found in current directory") + print("Please run this script from the executorch root directory") + sys.exit(1) + + # Test each model + successful_models = [] + failed_models = [] + + for model in args.models: + # Clean up before each test + if args.cleanup: + cleanup_temp_files() + + success = test_cuda_export(model, args.export_mode, args.timeout) + + if success: + successful_models.append(model) + else: + failed_models.append(model) + + # Final cleanup + if args.cleanup: + cleanup_temp_files() + + # Print summary + print("\n" + "=" * 60) + print("CUDA AOTI Export Test Summary") + print("=" * 60) + print(f"Total models tested: {len(args.models)}") + print(f"Successful exports: {len(successful_models)}") + print(f"Failed exports: {len(failed_models)}") + + if successful_models: + print(f"Successful models: {', '.join(successful_models)}") + + if failed_models: + print(f"Failed models: {', '.join(failed_models)}") + print("\nERROR: One or more model exports failed!") + sys.exit(1) + else: + print("\nSUCCESS: All model exports completed successfully!") + sys.exit(0) + + +if __name__ == "__main__": + main() diff --git a/.github/workflows/test-backend-cuda.yml b/.github/workflows/test-backend-cuda.yml new file mode 100644 index 00000000000..bc8063b5b73 --- /dev/null +++ b/.github/workflows/test-backend-cuda.yml @@ -0,0 +1,68 @@ +# Test ExecutorTorch CUDA AOTI Export Functionality +# This workflow tests whether ExecutorTorch can successfully export models using CUDA AOTI +# across different CUDA versions (12.6, 12.8, 12.9) for a subset of models: +# linear, conv2d, add, and resnet18 +# +# The test focuses on export-only functionality and verifies that no errors are raised +# during the AOTI export process. 
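+# Each CUDA version runs as an independent matrix job; the check-all-cuda-aoti-exports job below gates on the combined result.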
+ +name: Test CUDA AOTI Export + +on: + pull_request: + push: + branches: + - main + - release/* + +concurrency: + group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }}-${{ github.event_name == 'workflow_dispatch' }}-${{ github.event_name == 'schedule' }} + cancel-in-progress: false + +jobs: + test-cuda-aoti-export: + strategy: + fail-fast: false + matrix: + cuda-version: ["12.6", "12.8", "12.9"] + + name: test-executorch-cuda-aoti-export-${{ matrix.cuda-version }} + uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main + permissions: + id-token: write + contents: read + with: + timeout: 120 + runner: linux.g5.4xlarge.nvidia.gpu + gpu-arch-type: cuda + gpu-arch-version: ${{ matrix.cuda-version }} + submodules: recursive + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} + script: | + set -eux + + # The generic Linux job chooses to use base env, not the one setup by the image + CONDA_ENV=$(conda env list --json | jq -r ".envs | .[-1]") + if [ -n "$CONDA_ENV" ]; then + conda activate "${CONDA_ENV}" + fi + + # Test ExecutorTorch CUDA AOTI export - ExecutorTorch will automatically detect CUDA version + # and install the appropriate PyTorch wheel when CMAKE_ARGS="-DEXECUTORCH_BUILD_CUDA=ON" + PYTHON_EXECUTABLE=python bash .ci/scripts/test-cuda-export-aoti.sh "${{ matrix.cuda-version }}" + + # This job will fail if any of the CUDA AOTI export tests fail + check-all-cuda-aoti-exports: + needs: test-cuda-aoti-export + runs-on: ubuntu-latest + if: always() + steps: + - name: Check if all CUDA AOTI export tests succeeded + run: | + if [[ "${{ needs.test-cuda-aoti-export.result }}" != "success" ]]; then + echo "ERROR: One or more ExecutorTorch CUDA AOTI export tests failed!" + echo "CUDA AOTI export test results: ${{ needs.test-cuda-aoti-export.result }}" + exit 1 + else + echo "SUCCESS: All ExecutorTorch CUDA AOTI export tests (12.6, 12.8, 12.9) completed successfully!" + fi \ No newline at end of file diff --git a/.github/workflows/test-cuda-builds.yml b/.github/workflows/test-cuda-builds.yml new file mode 100644 index 00000000000..eff26e72c67 --- /dev/null +++ b/.github/workflows/test-cuda-builds.yml @@ -0,0 +1,68 @@ +# Test ExecutorTorch CUDA Build Compatibility +# This workflow tests whether ExecutorTorch can be successfully built with CUDA support +# across different CUDA versions (12.6, 12.8, 12.9) using the command: +# CMAKE_ARGS="-DEXECUTORCH_BUILD_CUDA=ON" ./install_executorch.sh +# +# Note: ExecutorTorch automatically detects the system CUDA version using nvcc and +# installs the appropriate PyTorch wheel. No manual CUDA/PyTorch installation needed. 
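+# Note: the 90-minute job timeout below matches the 5400 s install timeout used in .ci/scripts/test-cuda-build.sh.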
+ +name: Test CUDA Builds + +on: + pull_request: + push: + branches: + - main + - release/* + +concurrency: + group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }}-${{ github.event_name == 'workflow_dispatch' }}-${{ github.event_name == 'schedule' }} + cancel-in-progress: false + +jobs: + test-cuda-builds: + strategy: + fail-fast: false + matrix: + cuda-version: ["12.6", "12.8", "12.9"] + + name: test-executorch-cuda-build-${{ matrix.cuda-version }} + uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main + permissions: + id-token: write + contents: read + with: + timeout: 90 + runner: linux.g5.4xlarge.nvidia.gpu + gpu-arch-type: cuda + gpu-arch-version: ${{ matrix.cuda-version }} + submodules: recursive + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} + script: | + set -eux + + # The generic Linux job chooses to use base env, not the one setup by the image + CONDA_ENV=$(conda env list --json | jq -r ".envs | .[-1]") + if [ -n "$CONDA_ENV" ]; then + conda activate "${CONDA_ENV}" + fi + + # Test ExecutorTorch CUDA build - ExecutorTorch will automatically detect CUDA version + # and install the appropriate PyTorch wheel when CMAKE_ARGS="-DEXECUTORCH_BUILD_CUDA=ON" + PYTHON_EXECUTABLE=python bash .ci/scripts/test-cuda-build.sh "${{ matrix.cuda-version }}" + + # This job will fail if any of the CUDA versions fail + check-all-cuda-builds: + needs: test-cuda-builds + runs-on: ubuntu-latest + if: always() + steps: + - name: Check if all CUDA builds succeeded + run: | + if [[ "${{ needs.test-cuda-builds.result }}" != "success" ]]; then + echo "ERROR: One or more ExecutorTorch CUDA builds failed!" + echo "CUDA build results: ${{ needs.test-cuda-builds.result }}" + exit 1 + else + echo "SUCCESS: All ExecutorTorch CUDA builds (12.6, 12.8, 12.9) completed successfully!" + fi diff --git a/.gitignore b/.gitignore index b166f8c9512..295c352adbc 100644 --- a/.gitignore +++ b/.gitignore @@ -34,6 +34,16 @@ tokenizer.json !test_bpe_tokenizer.bin !test_tiktoken_tokenizer.model +# AOTI temporary files +*.cubin +*kernel_metadata.json +*kernel.cpp +*wrapper_metadata.json +*wrapper.cpp +*wrapper.json + +aoti_debug_data* + # Editor temporaries *.idea *.sw[a-z] diff --git a/CMakeLists.txt b/CMakeLists.txt index fc427d517a9..586f1b1128f 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -105,6 +105,7 @@ set(CMAKE_EXPORT_COMPILE_COMMANDS ON) set(CMAKE_SKIP_BUILD_RPATH OFF) # Don't use the install-rpath during the build phase set(CMAKE_BUILD_WITH_INSTALL_RPATH ON) + # Automatically add all linked folders that are NOT in the build directory to # the rpath (per library?) 
# @@ -587,6 +588,16 @@ endif() if(EXECUTORCH_BUILD_CORTEX_M) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backends/cortex_m) + list(APPEND _executorch_backends coretex_m_backend) +endif() + +if(EXECUTORCH_BUILD_CUDA) + # Build common AOTI functionality (required for CUDA) + add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backends/aoti) + # Build CUDA-specific AOTI functionality + add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backends/cuda) + # Add aoti_cuda to backends - it already depends on aoti_common + list(APPEND _executorch_backends aoti_cuda) endif() if(EXECUTORCH_BUILD_EXTENSION_APPLE) @@ -1006,6 +1017,11 @@ if(EXECUTORCH_BUILD_EXECUTOR_RUNNER) extension_runner_util gflags executorch_backends ) + # Add flat tensor extension if it's built + if(EXECUTORCH_BUILD_EXTENSION_FLAT_TENSOR) + list(APPEND _executor_runner_libs extension_flat_tensor) + endif() + if(EXECUTORCH_BUILD_KERNELS_OPTIMIZED) list(APPEND _executor_runner_libs optimized_native_cpu_ops_lib) elseif(EXECUTORCH_BUILD_CADENCE) diff --git a/backends/aoti/CMakeLists.txt b/backends/aoti/CMakeLists.txt new file mode 100644 index 00000000000..ab3ac80e57a --- /dev/null +++ b/backends/aoti/CMakeLists.txt @@ -0,0 +1,54 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +# +# Build AOTI backend for runtime. +# +# ### Editing this file ### +# +# This file should be formatted with +# ~~~ +# cmake-format -i CMakeLists.txt +# ~~~ +# It should also be cmake-lint clean. +# +set(CMAKE_EXPORT_COMPILE_COMMANDS ON) + +# Source root directory for executorch. +if(NOT EXECUTORCH_ROOT) + set(EXECUTORCH_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/../..) +endif() + +# Use ExecutorTorch's standard way to find PyTorch libraries for AOTI +include(${EXECUTORCH_ROOT}/tools/cmake/Utils.cmake) +find_package_torch() + +# Common AOTI functionality (non-CUDA) +set(_aoti_common_sources aoti_model_container.cpp common_shims.cpp utils.cpp) +add_library(aoti_common STATIC ${_aoti_common_sources}) +target_include_directories( + aoti_common + PUBLIC $ $ + # PyTorch AOTI headers from ExecutorTorch's torch detection + ${TORCH_INCLUDE_DIRS} +) +target_compile_options(aoti_common PUBLIC -fexceptions -frtti -fPIC) +# Ensure symbols are exported properly +target_link_options(aoti_common PUBLIC -Wl,--export-dynamic) + +# Link against PyTorch libraries and standard libraries +target_link_libraries( + aoti_common + PUBLIC extension_tensor ${CMAKE_DL_LIBS} + # Link PyTorch libraries for AOTI functions + ${TORCH_LIBRARIES} +) +executorch_target_link_options_shared_lib(aoti_common) + +install( + TARGETS aoti_common + EXPORT ExecuTorchTargets + DESTINATION lib +) diff --git a/backends/aoti/TARGETS b/backends/aoti/TARGETS new file mode 100644 index 00000000000..77871de4469 --- /dev/null +++ b/backends/aoti/TARGETS @@ -0,0 +1,3 @@ +load("targets.bzl", "define_common_targets") + +define_common_targets() diff --git a/backends/aoti/aoti_model_container.cpp b/backends/aoti/aoti_model_container.cpp new file mode 100644 index 00000000000..f9d66ed82e4 --- /dev/null +++ b/backends/aoti/aoti_model_container.cpp @@ -0,0 +1,36 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#include "aoti_model_container.h" + +namespace executorch { +namespace backends { +namespace aoti { + +extern "C" { + +// Global function pointers for AOT Inductor model container operations +// These will be loaded dynamically from the shared library +AOTInductorModelContainerCreateWithDeviceFunc + AOTInductorModelContainerCreateWithDevice = nullptr; +AOTInductorModelContainerDeleteFunc AOTInductorModelContainerDelete = nullptr; +AOTInductorModelContainerGetNumInputsFunc + AOTInductorModelContainerGetNumInputs = nullptr; +AOTInductorModelContainerGetInputNameFunc + AOTInductorModelContainerGetInputName = nullptr; +AOTInductorModelContainerGetNumConstantsFunc + AOTInductorModelContainerGetNumConstants = nullptr; +AOTInductorModelContainerGetNumOutputsFunc + AOTInductorModelContainerGetNumOutputs = nullptr; +AOTInductorModelContainerRunFunc AOTInductorModelContainerRun = nullptr; + +} // extern "C" + +} // namespace aoti +} // namespace backends +} // namespace executorch diff --git a/backends/aoti/aoti_model_container.h b/backends/aoti/aoti_model_container.h new file mode 100644 index 00000000000..e8bc253d9c0 --- /dev/null +++ b/backends/aoti/aoti_model_container.h @@ -0,0 +1,98 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include +#include + +namespace executorch { +namespace backends { +namespace aoti { + +using executorch::runtime::Error; +using executorch::runtime::etensor::Tensor; + +extern "C" { + +// Type definitions +using AOTITensorHandle = Tensor*; +using AOTIRuntimeError = Error; + +// Forward declarations for AOT Inductor model container +struct AOTInductorModelContainerOpaque; +using AOTInductorModelContainerHandle = AOTInductorModelContainerOpaque*; +using AOTInductorStreamHandle = void*; +using AOTIProxyExecutorHandle = void*; + +// Function pointer types for AOT Inductor model container operations +using AOTInductorModelContainerCreateWithDeviceFunc = AOTIRuntimeError (*)( + AOTInductorModelContainerHandle* container_handle, + size_t num_models, + const char* device_str, + const char* cubin_dir); + +using AOTInductorModelContainerDeleteFunc = + AOTIRuntimeError (*)(AOTInductorModelContainerHandle container_handle); + +using AOTInductorModelContainerGetNumInputsFunc = AOTIRuntimeError (*)( + AOTInductorModelContainerHandle container_handle, + size_t* num_constants); + +using AOTInductorModelContainerGetInputNameFunc = AOTIRuntimeError (*)( + AOTInductorModelContainerHandle container_handle, + size_t input_idx, + const char** input_name); + +using AOTInductorModelContainerGetNumConstantsFunc = AOTIRuntimeError (*)( + AOTInductorModelContainerHandle container_handle, + size_t* num_constants); + +using AOTInductorModelContainerGetNumOutputsFunc = AOTIRuntimeError (*)( + AOTInductorModelContainerHandle container_handle, + size_t* num_constants); + +using AOTInductorModelContainerRunFunc = AOTIRuntimeError (*)( + AOTInductorModelContainerHandle container_handle, + AOTITensorHandle* input_handles, // array of input AOTITensorHandle; handles + // are stolen; the array itself is borrowed + size_t num_inputs, + AOTITensorHandle* + output_handles, // array for writing output AOTITensorHandle; handles + // will be stolen by the caller; the array itself is + // borrowed + size_t n_outputs, + AOTInductorStreamHandle stream_handle, + AOTIProxyExecutorHandle 
proxy_executor_handle); + +// Global function pointers (will be loaded dynamically) +extern AOTInductorModelContainerCreateWithDeviceFunc + AOTInductorModelContainerCreateWithDevice; +extern AOTInductorModelContainerDeleteFunc AOTInductorModelContainerDelete; +extern AOTInductorModelContainerGetNumInputsFunc + AOTInductorModelContainerGetNumInputs; +extern AOTInductorModelContainerGetInputNameFunc + AOTInductorModelContainerGetInputName; +extern AOTInductorModelContainerGetNumConstantsFunc + AOTInductorModelContainerGetNumConstants; +extern AOTInductorModelContainerGetNumOutputsFunc + AOTInductorModelContainerGetNumOutputs; +extern AOTInductorModelContainerRunFunc AOTInductorModelContainerRun; + +} // extern "C" + +// AOTI Delegate Handle structure +struct AOTIDelegateHandle { + void* so_handle; + AOTInductorModelContainerHandle container_handle; +}; + +} // namespace aoti +} // namespace backends +} // namespace executorch diff --git a/backends/aoti/common_shims.cpp b/backends/aoti/common_shims.cpp new file mode 100644 index 00000000000..97a0478ba52 --- /dev/null +++ b/backends/aoti/common_shims.cpp @@ -0,0 +1,155 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include "common_shims.h" +#include +#include +#include +#include +#include +#include + +namespace executorch { +namespace backends { +namespace aoti { + +namespace internal { +// Constants for file operations +const char* const TENSOR_OUTPUT_FILENAME = + "/home/gasoonjia/executorch/aoti_intermediate_output.txt"; +} // namespace internal + +// Global storage for tensor metadata +std::unordered_map> tensor_to_sizes; +std::unordered_map> tensor_to_strides; + +extern "C" { + +// Autograd mode functions +int32_t aoti_torch_grad_mode_is_enabled() { + // No autograd ever + return false; +} + +void aoti_torch_grad_mode_set_enabled(bool enabled) { + if (enabled) { + throw std::runtime_error("Cannot enable autograd"); + } +} + +// Tensor attribute operations +AOTITorchError aoti_torch_get_data_ptr( + AOTITensorHandle tensor, + void** ret_data_ptr) { + *ret_data_ptr = tensor->mutable_data_ptr(); + return Error::Ok; +} + +AOTITorchError aoti_torch_get_storage_offset( + AOTITensorHandle tensor, + int64_t* ret_storage_offset) { + // Storage offset is always 0 in ET + *ret_storage_offset = 0; + + return Error::Ok; +} + +AOTITorchError aoti_torch_get_strides( + AOTITensorHandle tensor, + int64_t** ret_strides) { + auto it = tensor_to_strides.find(tensor); + if (it == tensor_to_strides.end()) { + std::vector strides(tensor->dim()); + auto tensor_strides = tensor->strides(); + for (int i = 0; i < tensor->dim(); i++) { + strides[i] = tensor_strides[i]; + } + it = tensor_to_strides.emplace(tensor, std::move(strides)).first; + } + *ret_strides = it->second.data(); + + return Error::Ok; +} + +AOTITorchError aoti_torch_get_dtype( + AOTITensorHandle tensor, + int32_t* ret_dtype) { + *ret_dtype = static_cast(tensor->scalar_type()); + + return Error::Ok; +} + +AOTITorchError aoti_torch_get_sizes( + AOTITensorHandle tensor, + int64_t** ret_sizes) { + auto it = tensor_to_sizes.find(tensor); + if (it == tensor_to_sizes.end()) { + std::vector sizes(tensor->dim()); + auto tensor_sizes = tensor->sizes(); + for (int i = 0; i < tensor->dim(); i++) { + sizes[i] = tensor_sizes[i]; + } + it = tensor_to_sizes.emplace(tensor, std::move(sizes)).first; + } + *ret_sizes = it->second.data(); 
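+  // The int64_t copy is cached in the global map so the pointer handed back
+  // to the AOTI runtime remains valid after this function returns.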
+ return Error::Ok; +} + +AOTITorchError aoti_torch_get_storage_size( + AOTITensorHandle tensor, + int64_t* ret_size) { + throw std::runtime_error("Cannot get storage size on ETensor"); +} + +AOTITorchError aoti_torch_get_device_index( + AOTITensorHandle tensor, + int32_t* ret_device_index) { + // Let's assume all tensors AOTI using are on CUDA:0 + *ret_device_index = 0; + return Error::Ok; +} + +AOTITorchError aoti_torch_get_dim(AOTITensorHandle tensor, int64_t* ret_dim) { + *ret_dim = static_cast(tensor->dim()); + return Error::Ok; +} + +// Device and layout utility functions +int32_t aoti_torch_device_type_cpu() { + // Let's say cpu is 0 for ET as well + return 0; +} + +__attribute__((__visibility__("default"))) int32_t aoti_torch_layout_strided() { + // ET only support strided layout, the return value will always be 0, a.k.a + // at::Layout::Strided; + return 0; +} + +// Dtype constants - these return the PyTorch dtype codes +// Currently only float32 is supported, but using robust enum-based approach +__attribute__((__visibility__("default"))) int32_t aoti_torch_dtype_float32() { + return 6; // PyTorch's float32 dtype code +} + +// Cleanup functions +void cleanup_tensor_metadata() { + tensor_to_sizes.clear(); + tensor_to_strides.clear(); +} + +void cleanup_aoti_tensor_output() { + // Clean up any tensor output related resources + // For now this is a no-op, but can be extended if needed +} + +} // extern "C" + +} // namespace aoti +} // namespace backends +} // namespace executorch diff --git a/backends/aoti/common_shims.h b/backends/aoti/common_shims.h new file mode 100644 index 00000000000..260a7661c6b --- /dev/null +++ b/backends/aoti/common_shims.h @@ -0,0 +1,88 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace executorch { +namespace backends { +namespace aoti { + +// Common using declarations for ExecutorTorch types +using executorch::runtime::Error; +using executorch::runtime::etensor::Tensor; + +extern "C" { + +// Common AOTI type aliases +// Note: AOTITensorHandle is aliased to Tensor* for ExecutorTorch compatibility +using AOTITensorHandle = Tensor*; +using AOTIRuntimeError = Error; +using AOTITorchError = Error; + +// Global storage for tensor metadata +extern std::unordered_map> tensor_to_sizes; +extern std::unordered_map> tensor_to_strides; + +// Attribute-related operations (memory-irrelevant) +AOTITorchError aoti_torch_get_data_ptr( + AOTITensorHandle tensor, + void** ret_data_ptr); + +AOTITorchError aoti_torch_get_storage_offset( + AOTITensorHandle tensor, + int64_t* ret_storage_offset); + +AOTITorchError aoti_torch_get_strides( + AOTITensorHandle tensor, + int64_t** ret_strides); + +AOTITorchError aoti_torch_get_dtype( + AOTITensorHandle tensor, + int32_t* ret_dtype); + +AOTITorchError aoti_torch_get_sizes( + AOTITensorHandle tensor, + int64_t** ret_sizes); + +AOTITorchError aoti_torch_get_storage_size( + AOTITensorHandle tensor, + int64_t* ret_size); + +AOTITorchError aoti_torch_get_device_index( + AOTITensorHandle tensor, + int32_t* ret_device_index); + +AOTITorchError aoti_torch_get_dim(AOTITensorHandle tensor, int64_t* ret_dim); + +// Utility functions for device and layout information +int32_t aoti_torch_device_type_cpu(); +int32_t aoti_torch_layout_strided(); +int32_t aoti_torch_dtype_float32(); + +// Autograd mode functions +int32_t aoti_torch_grad_mode_is_enabled(); +void aoti_torch_grad_mode_set_enabled(bool enabled); + +// Cleanup functions for clearing global state +void cleanup_tensor_metadata(); +void cleanup_aoti_tensor_output(); + +} // extern "C" + +} // namespace aoti +} // namespace backends +} // namespace executorch diff --git a/backends/aoti/targets.bzl b/backends/aoti/targets.bzl new file mode 100644 index 00000000000..bd46550d81e --- /dev/null +++ b/backends/aoti/targets.bzl @@ -0,0 +1,28 @@ +load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime") + +def define_common_targets(): + # Common AOTI functionality (non-CUDA) + runtime.cxx_library( + name = "aoti_common", + srcs = [ + "aoti_model_container.cpp", + "common_shims.cpp", + "utils.cpp", + ], + headers = [ + "aoti_model_container.h", + "common_shims.h", + "utils.h", + ], + # @lint-ignore BUCKLINT: Avoid `link_whole=True` (https://fburl.com/avoid-link-whole) + link_whole = True, + supports_python_dlopen = True, + # Constructor needed for backend registration. + compiler_flags = ["-Wno-global-constructors"], + visibility = ["@EXECUTORCH_CLIENTS"], + deps = [ + "//executorch/runtime/backend:interface", + "//executorch/runtime/core:core", + "//caffe2/torch/csrc/inductor:aoti_torch", + ], + ) diff --git a/backends/aoti/utils.cpp b/backends/aoti/utils.cpp new file mode 100644 index 00000000000..68c28eed265 --- /dev/null +++ b/backends/aoti/utils.cpp @@ -0,0 +1,84 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#include "utils.h" +#include +#include + +namespace executorch { +namespace backends { +namespace aoti { + +extern "C" { + +// Map int32_t dtype to number of bytes per element (reusing ExecutorTorch's +// elementSize function) +size_t dtype_to_element_size(int32_t dtype) { + // First convert int32_t dtype to ExecutorTorch ScalarType, then use existing + // elementSize function + executorch::aten::ScalarType scalar_type = dtype_to_scalar_type(dtype); + if (scalar_type == executorch::aten::ScalarType::Undefined) { + ET_LOG(Error, "Unsupported dtype: %d for element size calculation", dtype); + return 0; // Return 0 to indicate error + } + + // Reuse ExecutorTorch's existing elementSize function from scalar_type_util.h + return executorch::runtime::elementSize(scalar_type); +} + +// Map int32_t dtype to ExecutorTorch ScalarType (robust version of hardcoded +// ScalarType::Float) +executorch::aten::ScalarType dtype_to_scalar_type(int32_t dtype) { + // Convert based on known PyTorch dtype codes (without CUDA-specific + // dependency) + switch (dtype) { + case 6: // PyTorch's float32 dtype code + return executorch::aten::ScalarType::Float; + // Future support for additional dtypes can be added here + // case 11: // PyTorch's bool dtype code + // return executorch::aten::ScalarType::Bool; + // case 1: // PyTorch's uint8 dtype code + // return executorch::aten::ScalarType::Byte; + // case 2: // PyTorch's int8 dtype code + // return executorch::aten::ScalarType::Char; + // case 3: // PyTorch's int16 dtype code + // return executorch::aten::ScalarType::Short; + // case 4: // PyTorch's int32 dtype code + // return executorch::aten::ScalarType::Int; + // case 5: // PyTorch's int64 dtype code + // return executorch::aten::ScalarType::Long; + // case 7: // PyTorch's float16 dtype code + // return executorch::aten::ScalarType::Half; + // case 8: // PyTorch's float64 dtype code + // return executorch::aten::ScalarType::Double; + // case 15: // PyTorch's bfloat16 dtype code + // return executorch::aten::ScalarType::BFloat16; + default: + ET_LOG(Error, "Unsupported dtype: %d for ScalarType conversion", dtype); + return executorch::aten::ScalarType::Undefined; + } +} + +// Storage offset validation utility function +AOTITorchError validate_storage_offset(int64_t storage_offset) { + // Storage offset must always be 0 + if (storage_offset != 0) { + ET_LOG( + Error, + "Storage offset must be 0. Got storage_offset: %ld", + storage_offset); + return Error::InvalidArgument; + } + return Error::Ok; +} + +} // extern "C" + +} // namespace aoti +} // namespace backends +} // namespace executorch \ No newline at end of file diff --git a/backends/aoti/utils.h b/backends/aoti/utils.h new file mode 100644 index 00000000000..828f15ee1a4 --- /dev/null +++ b/backends/aoti/utils.h @@ -0,0 +1,43 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#pragma once + +#include +#include +#include +#include + +namespace executorch { +namespace backends { +namespace aoti { + +// Common using declarations for ExecutorTorch types +using executorch::runtime::Error; + +extern "C" { + +// Common AOTI type aliases +using AOTITorchError = Error; + +// Map int32_t dtype to number of bytes per element (reusing ExecutorTorch's +// elementSize function) +size_t dtype_to_element_size(int32_t dtype); + +// Map int32_t dtype to ExecutorTorch ScalarType (robust version of hardcoded +// ScalarType::Float) +executorch::aten::ScalarType dtype_to_scalar_type(int32_t dtype); + +// Storage offset validation utility function +AOTITorchError validate_storage_offset(int64_t storage_offset); + +} // extern "C" + +} // namespace aoti +} // namespace backends +} // namespace executorch \ No newline at end of file diff --git a/backends/cuda/CMakeLists.txt b/backends/cuda/CMakeLists.txt new file mode 100644 index 00000000000..ef6a4ddb8bd --- /dev/null +++ b/backends/cuda/CMakeLists.txt @@ -0,0 +1,70 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +# +# Build AOTI CUDA backend for runtime. +# +# ### Editing this file ### +# +# This file should be formatted with +# ~~~ +# cmake-format -i CMakeLists.txt +# ~~~ +# It should also be cmake-lint clean. +# + +set(CMAKE_EXPORT_COMPILE_COMMANDS ON) + +# Source root directory for executorch. +if(NOT EXECUTORCH_ROOT) + set(EXECUTORCH_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/../..) +endif() + +find_package(CUDAToolkit REQUIRED) + +# Use ExecutorTorch's standard way to find PyTorch libraries for AOTI +include(${EXECUTORCH_ROOT}/tools/cmake/Utils.cmake) +find_package_torch() + +# CUDA-specific AOTI functionality +set(_aoti_cuda_sources + runtime/cuda_backend.cpp + runtime/shims/memory.cpp + runtime/shims/tensor_attribute.cpp + runtime/utils.cpp) +add_library(aoti_cuda STATIC ${_aoti_cuda_sources}) +target_include_directories( + aoti_cuda + PUBLIC + ${CUDAToolkit_INCLUDE_DIRS} + $ + $ + # PyTorch AOTI headers from ExecutorTorch's torch detection + ${TORCH_INCLUDE_DIRS} +) +target_compile_options(aoti_cuda PUBLIC -fexceptions -frtti -fPIC) +# Ensure symbols are exported properly +target_link_options(aoti_cuda PUBLIC -Wl,--export-dynamic) + +# Link against CUDA::cudart, common AOTI library, and PyTorch CUDA libraries +target_link_libraries( + aoti_cuda + PUBLIC + aoti_common + CUDA::cudart + ${CMAKE_DL_LIBS} + # Link PyTorch libraries for AOTI CUDA functions + ${TORCH_LIBRARIES} +) +# If you need other CUDA libraries, link them similarly: +# target_link_libraries(aoti_cuda PUBLIC CUDA::cublas CUDA::cufft ...) +executorch_target_link_options_shared_lib(aoti_cuda) + + +install( + TARGETS aoti_cuda + EXPORT ExecuTorchTargets + DESTINATION lib +) diff --git a/backends/cuda/TARGETS b/backends/cuda/TARGETS new file mode 100644 index 00000000000..77871de4469 --- /dev/null +++ b/backends/cuda/TARGETS @@ -0,0 +1,3 @@ +load("targets.bzl", "define_common_targets") + +define_common_targets() diff --git a/backends/cuda/__init__.py b/backends/cuda/__init__.py new file mode 100644 index 00000000000..2e41cd717f6 --- /dev/null +++ b/backends/cuda/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
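The two Python entry points added below (`CudaBackend` in `backends/cuda/cuda_backend.py` and `CudaPartitioner` in `backends/cuda/cuda_partitioner.py`) are meant to be driven through the usual ExecuTorch ahead-of-time flow. A minimal sketch of that flow is shown here for orientation; the toy module, the output file name, and the use of the `to_edge`/`to_backend` entry points are illustrative assumptions, not part of this patch.

```python
# Illustrative sketch (not from this patch): lower a toy module to the new CUDA delegate.
import torch
from executorch.backends.cuda.cuda_partitioner import CudaPartitioner
from executorch.exir import to_edge


class TinyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(16, 8)

    def forward(self, x):
        return self.linear(x)


# Export on CPU; CudaBackend.preprocess moves the program to CUDA before AOTI compile.
example_inputs = (torch.randn(2, 16),)
exported = torch.export.export(TinyModel().eval(), example_inputs)

# Tag the whole graph for the CUDA delegate and serialize the resulting program.
edge = to_edge(exported)
edge = edge.to_backend(CudaPartitioner([]))  # no compile specs needed yet
et_program = edge.to_executorch()

with open("tiny_model_cuda.pte", "wb") as f:
    f.write(et_program.buffer)
```

At runtime, the `so_blob` entry that `CudaBackend.preprocess` writes into the named data store is what the C++ `CudaBackend::init` further down loads with dlopen() and runs through the AOTI model container.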
diff --git a/backends/cuda/cuda_backend.py b/backends/cuda/cuda_backend.py new file mode 100644 index 00000000000..99599de6b6c --- /dev/null +++ b/backends/cuda/cuda_backend.py @@ -0,0 +1,116 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import contextlib +import copy +import os +import typing + +from subprocess import check_call +from typing import Any, Dict, final, List, Optional, Set + +import torch +from executorch.exir._serialize._named_data_store import NamedDataStore +from executorch.exir.backend.backend_details import ( + BackendDetails, + ExportedProgram, + PreprocessResult, +) +from executorch.exir.backend.compile_spec_schema import CompileSpec +from torch._inductor.codegen.cpp_wrapper_cpu import CppWrapperCpu +from torch.export.passes import move_to_device_pass + + +# exist fallback operators in et namespace; +supported_fallback_kernels: Dict[str, Any] = {} + +# required fallback kernels but not supported +missing_fallback_kernels: Set[str] = set() + + +# context manager for non-fallback guarantee +# it will raise exception when generating fallback kernels during aoti compile +@contextlib.contextmanager +def collect_unsupported_fallback_kernels(): + original_generate_c_shim_extern_kernel_call = ( + CppWrapperCpu.generate_c_shim_extern_kernel_call + ) + + def generate_c_shim_extern_kernel_call_and_collect_unsupported_kernels( + self, + kernel: str, + args: list[str], + device: str, + *, + debug_args: Optional[list[str]] = None, + ): + if kernel not in supported_fallback_kernels: + missing_fallback_kernels.add(kernel) + + original_generate_c_shim_extern_kernel_call( + self, kernel, args, device, debug_args=debug_args + ) + + CppWrapperCpu.generate_c_shim_extern_kernel_call = ( + generate_c_shim_extern_kernel_call_and_collect_unsupported_kernels + ) + try: + yield + finally: + CppWrapperCpu.generate_c_shim_extern_kernel_call = ( + original_generate_c_shim_extern_kernel_call + ) + + +@final +class CudaBackend(BackendDetails): + @staticmethod + def preprocess( + edge_program: ExportedProgram, + compile_specs: List[CompileSpec], + ) -> PreprocessResult: + + named_data_store = NamedDataStore() + + # Move the edge_program from CPU to CUDA for aoti compile + cuda_edge_program = move_to_device_pass(edge_program, "cuda") + + edge_program_module = cuda_edge_program.module() + args, kwargs = cuda_edge_program.example_inputs + + output_path = os.path.join(os.getcwd(), "aoti.so") + + options: dict[str, typing.Any] = { + "aot_inductor.embed_kernel_binary": True, + "aot_inductor.link_libtorch": False, + "aot_inductor.package_constants_in_so": True, + "aot_inductor.output_path": output_path, + "aot_inductor.debug_compile": True, + "aot_inductor.force_mmap_weights": False, + "max_autotune": True, + "max_autotune_gemm_backends": "TRITON", + "max_autotune_conv_backends": "TRITON", + } + + with collect_unsupported_fallback_kernels(): + so_path = torch._inductor.aot_compile(edge_program_module, args, kwargs, options=options) # type: ignore[arg-type] + if len(missing_fallback_kernels) > 0: + formatted_kernels = "\n - ".join(sorted(missing_fallback_kernels)) + raise RuntimeError( + f"Missing fallback kernels ({len(missing_fallback_kernels)} total):\n - {formatted_kernels}\n" + "Please add them to the AOTI backend." 
+ ) + + with open(so_path, "rb") as f: + so_data = f.read() + + named_data_store.add_named_data("so_blob", so_data, 1, "aoti_cuda_blob") + + return PreprocessResult( + processed_bytes=b"", + debug_handle_map={}, + data_store_output=named_data_store.get_named_data_store_output(), + ) diff --git a/backends/cuda/cuda_partitioner.py b/backends/cuda/cuda_partitioner.py new file mode 100644 index 00000000000..227d13ba093 --- /dev/null +++ b/backends/cuda/cuda_partitioner.py @@ -0,0 +1,61 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-unsafe + +from typing import Callable, Dict, final, List, Optional, Tuple + +import torch +from executorch.backends.cuda.cuda_backend import CudaBackend # usort: skip +from executorch.exir.backend.compile_spec_schema import CompileSpec +from executorch.exir.backend.partitioner import ( + DelegationSpec, + Partitioner, + PartitionResult, +) +from executorch.exir.backend.utils import tag_constant_data +from torch.export.exported_program import ExportedProgram + + +@final +class CudaPartitioner(Partitioner): + def __init__(self, compile_spec: List[CompileSpec]) -> None: + self.delegation_spec = DelegationSpec(CudaBackend.__name__, compile_spec) + + def partition(self, exported_program: ExportedProgram) -> PartitionResult: + """ + Fully delegate the graph to AOTInductor by tagging all nodes as a single partition. + """ + + partition_tags: Dict[str, DelegationSpec] = {} + for node in exported_program.graph.nodes: + if node.op != "call_function": + continue + tag = f"tag0" + node.meta["delegation_tag"] = tag + partition_tags[tag] = self.delegation_spec + + tag_constant_data(exported_program) + + return PartitionResult( + tagged_exported_program=exported_program, partition_tags=partition_tags + ) + + def ops_to_not_decompose( + self, ep: ExportedProgram + ) -> Tuple[List[torch._ops.OpOverload], Optional[Callable[[torch.fx.Node], bool]]]: + """ + Return a list of operations that should not be decomposed and let the AOT compiler handle them. + Currently we skip decomposing all ops and let the AOT compiler handle them. + """ + do_not_decompose = set() + + for node in ep.graph.nodes: + if node.op == "call_function" and isinstance( + node.target, torch._ops.OpOverload + ): + do_not_decompose.add(node.target) + return list(do_not_decompose), None diff --git a/backends/cuda/runtime/cuda_backend.cpp b/backends/cuda/runtime/cuda_backend.cpp new file mode 100644 index 00000000000..6cd20537e80 --- /dev/null +++ b/backends/cuda/runtime/cuda_backend.cpp @@ -0,0 +1,375 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +// Include our shim layer headers +#include +#include + +namespace executorch { +namespace backends { +namespace aoti { + +using namespace std; + +using executorch::aten::ScalarType; +using executorch::runtime::ArrayRef; +using executorch::runtime::Backend; +using executorch::runtime::BackendExecutionContext; +using executorch::runtime::BackendInitContext; +using executorch::runtime::CompileSpec; +using executorch::runtime::DelegateHandle; +using executorch::runtime::Error; +using executorch::runtime::EValue; +using executorch::runtime::FreeableBuffer; +using executorch::runtime::MemoryAllocator; +using executorch::runtime::NamedDataMap; +using executorch::runtime::Result; +using executorch::runtime::Span; +using executorch::runtime::etensor::Tensor; + +class CudaBackend final : public ::executorch::runtime::BackendInterface { + public: + // Once in program + CudaBackend() { + ET_LOG(Info, "CudaBackend ctor"); + } + + bool is_available() const override { + return 1; + } + + // Once per loaded binary blob + Result init( + BackendInitContext& context, + FreeableBuffer* processed, // This will be a empty buffer + ArrayRef compile_specs // This will be my empty list + ) const override { + const NamedDataMap* named_data_map = context.get_named_data_map(); + + // std::string so_path = "/home/gasoonjia/executorch/aoti.so"; + + std::string so_path = "/tmp/test.so"; + std::string so_blob_key = "so_blob"; + + Result aoti_cuda_buffer = + named_data_map->get_data(so_blob_key.c_str()); + + // Create a temporary file + std::ofstream outfile(so_path.c_str(), std::ios::binary); + + // Write the ELF buffer to the temporary file + outfile.write( + (char*)aoti_cuda_buffer->data(), + sizeof(void*) * aoti_cuda_buffer->size()); + + // Finish writing the file to disk + outfile.close(); + + // Load the ELF using dlopen + void* so_handle = dlopen(so_path.c_str(), RTLD_LAZY | RTLD_LOCAL); + if (so_handle == nullptr) { + std::cout << dlerror() << std::endl; + return Error::AccessFailed; + } + + processed->Free(); + + AOTInductorModelContainerCreateWithDevice = + reinterpret_cast( + dlsym(so_handle, "AOTInductorModelContainerCreateWithDevice")); + if (AOTInductorModelContainerCreateWithDevice == nullptr) { + perror("dlsym1"); + return Error::AccessFailed; + } + AOTInductorModelContainerDelete = + reinterpret_cast( + dlsym(so_handle, "AOTInductorModelContainerDelete")); + if (AOTInductorModelContainerDelete == nullptr) { + perror("dlsym2"); + return Error::AccessFailed; + } + AOTInductorModelContainerGetNumInputs = + reinterpret_cast( + dlsym(so_handle, "AOTInductorModelContainerGetNumInputs")); + if (AOTInductorModelContainerGetNumInputs == nullptr) { + perror("dlsym3"); + return Error::AccessFailed; + } + + AOTInductorModelContainerGetNumConstants = + reinterpret_cast( + dlsym(so_handle, "AOTInductorModelContainerGetNumConstants")); + if (AOTInductorModelContainerGetNumConstants == nullptr) { + perror("dlsym AOTInductorModelContainerGetNumConstants"); + return Error::AccessFailed; + } + + AOTInductorModelContainerGetInputName = + reinterpret_cast( + dlsym(so_handle, "AOTInductorModelContainerGetInputName")); + if (AOTInductorModelContainerGetInputName == nullptr) { + perror("dlsym AOTInductorModelContainerGetInputName"); + return Error::AccessFailed; + } + + AOTInductorModelContainerGetNumOutputs = + reinterpret_cast( + dlsym(so_handle, 
"AOTInductorModelContainerGetNumOutputs")); + if (AOTInductorModelContainerGetNumOutputs == nullptr) { + perror("dlsym4"); + return Error::AccessFailed; + } + AOTInductorModelContainerRun = + reinterpret_cast( + dlsym(so_handle, "AOTInductorModelContainerRun")); + if (AOTInductorModelContainerRun == nullptr) { + perror("dlsym5"); + return Error::AccessFailed; + } + + AOTInductorModelContainerHandle container_handle = nullptr; + + AOTIRuntimeError err = AOTInductorModelContainerCreateWithDevice( + &container_handle, 1, "cuda", nullptr); + if (err != Error::Ok) { + return err; + } + printf("container_handle = %p\n", container_handle); + + AOTIDelegateHandle* handle = new AOTIDelegateHandle(); + handle->so_handle = so_handle; + handle->container_handle = container_handle; + return (DelegateHandle*)handle; // Return the handle post-processing + } + + // Once per execution + Error execute( + BackendExecutionContext& context, + DelegateHandle* handle_, + Span args) const override { + ET_LOG(Debug, "CudaBackend execute"); + + AOTIDelegateHandle* handle = (AOTIDelegateHandle*)handle_; + + ET_LOG(Debug, "CudaBackend Handle generated"); + + size_t n_inputs; + AOTInductorModelContainerGetNumInputs(handle->container_handle, &n_inputs); + + size_t n_outputs; + AOTInductorModelContainerGetNumOutputs( + handle->container_handle, &n_outputs); + + ET_LOG(Debug, "CudaBackend n_outputs %zd generated", n_outputs); + + if (n_inputs + n_outputs != args.size()) { + ET_LOG( + Error, + "number of user input %zd and output %zd generated from AOT Inductor does not match ET runner's %zd. Exit.", + n_inputs, + n_outputs, + args.size()); + return Error::InvalidArgument; + } + + ET_LOG( + Debug, + "number of user input %zd and output %zd generated from AOT Inductor matches ET runner's %zd.", + n_inputs, + n_outputs, + args.size()); + + // NOTE: ExecutorTorch tensors are always on CPU/host memory + // We need to create GPU copies for CUDA kernel execution + std::vector gpu_inputs( + n_inputs); // GPU copies for kernel execution + std::vector gpu_outputs( + n_outputs); // GPU tensors for kernel output + + ET_LOG(Debug, "CudaBackend input/output vectors generated"); + + // Process input tensors: ExecutorTorch provides CPU tensors, create GPU + // copies + for (int i = 0; i < n_inputs; i++) { + ET_LOG(Debug, "Processing input %d from args to inputs vector", i); + ET_LOG( + Debug, "is %d input a tensor input? 
%d", i, int(args[i]->isTensor())); + + // Get tensor dimensions and properties from ExecutorTorch CPU tensor + auto cpu_tensor = &(args[i]->toTensor()); + auto sizes = cpu_tensor->sizes(); + auto scalar_type = cpu_tensor->scalar_type(); + + // Create GPU tensor with same shape + std::vector sizes_vec(sizes.begin(), sizes.end()); + + AOTITensorHandle gpu_input_handle; + Error create_err = aoti_torch_empty_strided( + sizes_vec.size(), + sizes_vec.data(), + nullptr, // use default strides + static_cast(scalar_type), + 1, // device_type = cuda + 0, // device_index = 0 + &gpu_input_handle); + + if (create_err != Error::Ok) { + ET_LOG(Error, "Failed to create GPU tensor for input %d", i); + return Error::Internal; + } + + gpu_inputs[i] = gpu_input_handle; + + // Copy data from CPU to GPU + Error copy_err = aoti_torch_copy_(gpu_inputs[i], cpu_tensor, 0); + if (copy_err != Error::Ok) { + ET_LOG(Error, "Failed to copy input %d from CPU to GPU", i); + return Error::Internal; + } + + ET_LOG(Debug, "Successfully copied input %d from CPU to GPU", i); + } + + ET_LOG(Debug, "CudaBackend GPU inputs generated"); + + // Process output tensors: create GPU counterparts for ExecutorTorch CPU + // tensors + for (int i = 0; i < n_outputs; i++) { + // Get output tensor dimensions from ExecutorTorch CPU tensor + auto cpu_output_tensor = &(args[i + n_inputs]->toTensor()); + auto sizes = cpu_output_tensor->sizes(); + auto scalar_type = cpu_output_tensor->scalar_type(); + + // Create GPU tensor with same shape for kernel output + std::vector sizes_vec(sizes.begin(), sizes.end()); + + AOTITensorHandle gpu_output_handle; + Error create_err = aoti_torch_empty_strided( + sizes_vec.size(), + sizes_vec.data(), + nullptr, // use default strides + static_cast(scalar_type), + 1, // device_type = cuda + 0, // device_index = 0 + &gpu_output_handle); + + if (create_err != Error::Ok) { + ET_LOG(Error, "Failed to create GPU tensor for output %d", i); + return Error::Internal; + } + + gpu_outputs[i] = gpu_output_handle; + ET_LOG(Debug, "Created GPU output tensor %d", i); + } + + ET_LOG(Debug, "CudaBackend output generated"); + + // Run AOTI container with GPU tensors + AOTIRuntimeError error = AOTInductorModelContainerRun( + handle->container_handle, + gpu_inputs.data(), // Use GPU input tensors + n_inputs, + gpu_outputs.data(), // Use GPU output tensors + n_outputs, + nullptr, // Pass the actual CUDA stream! 
+ nullptr); // proxy_executor_handle can remain nullptr + + if (error != Error::Ok) { + ET_LOG( + Error, + "AOTInductorModelContainerRun failed with error code %d", + error); + return Error::Internal; + } + + ET_LOG(Debug, "CudaBackend running done"); + + // Copy GPU output results back to CPU output tensors + for (int i = 0; i < n_outputs; i++) { + auto cpu_output_tensor = &(args[i + n_inputs]->toTensor()); + Error copy_err = aoti_torch_copy_(cpu_output_tensor, gpu_outputs[i], 0); + if (copy_err != Error::Ok) { + ET_LOG(Error, "Failed to copy GPU output %d back to CPU", i); + return Error::Internal; + } + ET_LOG(Debug, "Copied GPU output %d back to CPU", i); + } + + // Clean up GPU tensors that we created (ExecutorTorch tensors are always + // CPU, so all GPU tensors are our copies) + for (int i = 0; i < n_inputs; i++) { + // All GPU input tensors were created by us, delete them + aoti_torch_delete_tensor_object(gpu_inputs[i]); + } + + for (int i = 0; i < n_outputs; i++) { + // All GPU output tensors were created by us, delete them + aoti_torch_delete_tensor_object(gpu_outputs[i]); + } + + ET_LOG(Debug, "CudaBackend execution completed successfully"); + + return Error::Ok; + } + + void destroy(DelegateHandle* handle_) const override { + AOTIDelegateHandle* handle = (AOTIDelegateHandle*)handle_; + + // Delete the container BEFORE closing the shared library + if (handle->container_handle != nullptr) { + AOTIRuntimeError delete_result = + AOTInductorModelContainerDelete(handle->container_handle); + if (delete_result != Error::Ok) { + ET_LOG( + Error, + "AOTInductorModelContainerDelete failed with error code %d", + delete_result); + } + } + + // Now close the shared library + if (handle->so_handle != nullptr) { + dlclose(handle->so_handle); + } + + free(handle); + cleanup_memory(); + cleanup_tensor_metadata(); + ET_LOG(Debug, "CudaBackend handle %p destroy", handle_); + } +}; + +} // namespace aoti + +namespace { +auto cls = aoti::CudaBackend(); +executorch::runtime::Backend backend{"CudaBackend", &cls}; +static executorch::runtime::Error success_with_compiler = + register_backend(backend); +} // namespace + +} // namespace backends +} // namespace executorch diff --git a/backends/cuda/runtime/shims/memory.cpp b/backends/cuda/runtime/shims/memory.cpp new file mode 100644 index 00000000000..4518b359646 --- /dev/null +++ b/backends/cuda/runtime/shims/memory.cpp @@ -0,0 +1,672 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include // For posix_memalign +#include +#include +#include +#include +#include +#include + +namespace executorch { +namespace backends { +namespace aoti { + +namespace { // Internal namespace for utility functions + +// Utility function to log array values as error msg in format [val1, val2, ...] 
+// For use with pointer-based arrays (e.g., int64_t* strides, int64_t* sizes) +void et_error_log_array_values( + const int64_t* values, + int64_t count, + const std::string& name = "values") { + if (count <= 0) { + ET_LOG(Error, "%s: empty array", name.c_str()); + return; + } + + // Build array string representation + std::string array_str = "["; + for (int64_t i = 0; i < count; i++) { + array_str += std::to_string(values[i]); + if (i < count - 1) { + array_str += ", "; + } + } + array_str += "]"; + + ET_LOG(Error, "%s: %s", name.c_str(), array_str.c_str()); +} + +// Check if tensor is in contiguous memory format (NCHW for 4D tensors) +// Contiguous format means strides decrease from left to right: +// For NCHW: strides = [C*H*W, H*W, W, 1] +bool is_tensor_contiguous( + int64_t ndim, + const int64_t* sizes, + const int64_t* strides) { + int64_t expected_stride = 1; + for (int i = ndim - 1; i >= 0; i--) { + if (strides[i] != expected_stride) { + return false; + } + expected_stride *= sizes[i]; + } + return true; +} + +// Check if tensor is in channels-last format (NHWC for 4D tensors) +// Channels-last format for 4D: strides = [H*W*C, 1, W*C, C] +bool is_tensor_channels_last( + int64_t ndim, + const int64_t* sizes, + const int64_t* strides) { + if (ndim != 4) { + return false; // Channels-last only defined for 4D tensors + } + + int64_t N = sizes[0], C = sizes[1], H = sizes[2], W = sizes[3]; + + // Check NHWC format: strides = [H*W*C, 1, W*C, C] + // Handle edge cases where dimensions might be 1 + return (strides[0] == H * W * C || N <= 1) && (strides[1] == 1 || C <= 1) && + (strides[2] == W * C || H <= 1) && (strides[3] == C || W <= 1); +} + +} // anonymous namespace + +// Global storage for tensors and their metadata +std::unordered_set> tensors; +std::unordered_map is_tensor_own_memory; + +extern "C" { + +AOTITorchError aoti_torch_create_tensor_from_blob_v2( + void* data, + int64_t ndim, + const int64_t* sizes_ptr, + const int64_t* strides_ptr, + int64_t storage_offset, + int32_t dtype, + int32_t device_type, + int32_t device_index, + AOTITensorHandle* ret_new_tensor, + int32_t layout, + const uint8_t* opaque_metadata, + int64_t opaque_metadata_size) { + // Only float32 tensors are supported + AOTITorchError dtype_error = validate_dtype(dtype); + if (dtype_error != Error::Ok) { + return dtype_error; + } + + // Storage offset must always be 0 + AOTITorchError storage_offset_error = validate_storage_offset(storage_offset); + if (storage_offset_error != Error::Ok) { + return storage_offset_error; + } + + // Convert sizes to the format expected by ExecutorTorch + std::vector sizes(ndim); + for (int i = 0; i < ndim; i++) { + sizes[i] = static_cast(sizes_ptr[i]); + } + + // check the tensor format + // Only support contiguous format for now + if (!is_tensor_contiguous(ndim, sizes_ptr, strides_ptr)) { + ET_LOG( + Error, + "aoti_torch_create_tensor_from_blob_v2 failed since input stride is not in contiguous format"); + return Error::InvalidArgument; + } + + // Create ExecutorTorch tensor that wraps the existing memory + // Note: We're NOT copying the data, just wrapping it + auto tensor = executorch::extension::make_tensor_ptr( + sizes, // tensor dimensions + data, // existing memory (don't copy!) 
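+ // validate_dtype above currently accepts only float32 (dtype code 6),
+ // so this mapping is effectively float32 -> ScalarType::Float for now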
+ dtype_to_scalar_type(dtype) // map int32_t dtype to ScalarType + ); + + if (!tensor) { + ET_LOG(Error, "Failed to create tensor from blob"); + return Error::InvalidArgument; + } + + // Store the tensor so it doesn't get destroyed + tensors.insert(tensor); + + *ret_new_tensor = tensor.get(); + is_tensor_own_memory[tensor.get()] = false; + + return Error::Ok; +} + +AOTITorchError aoti_torch_empty_strided( + int64_t ndim, + const int64_t* sizes_ptr, + const int64_t* strides_ptr, + int32_t dtype, + int32_t device_type, + int32_t device_index, + AOTITensorHandle* ret_new_tensor) { + // This requires us to reserve CUDA memory and put it into a ETensor + void* ptr; + int64_t numel = 1; + for (int i = 0; i < ndim; i++) { + numel *= sizes_ptr[i]; + } + + AOTITorchError dtype_error = validate_dtype(dtype); + if (dtype_error != Error::Ok) { + return dtype_error; + } + + size_t element_size = dtype_to_element_size(dtype); + if (element_size == 0) { + ET_LOG(Error, "Invalid element size for dtype: %d", dtype); + return Error::InvalidArgument; + } + int64_t nbytes = numel * element_size; + + if (device_type == 1) { // cuda + cudaError_t err = cudaMalloc(&ptr, nbytes); + if (err != cudaSuccess) { + ET_LOG( + Error, + "failed to allocate %ld bytes: %s", + nbytes, + cudaGetErrorString(err)); + return Error::MemoryAllocationFailed; + } + } else if (device_type == 0) { // cpu + // Ensure 16-byte alignment for CPU memory to match CUDA requirements + // do we need to do this in cuda backend? + int result = posix_memalign(&ptr, 16, nbytes); + if (result != 0) { + ET_LOG(Error, "Failed to allocate aligned CPU memory"); + return Error::MemoryAllocationFailed; + } + if (ptr == nullptr) { + ET_LOG(Error, "Failed to call posix_memalign"); + return Error::MemoryAllocationFailed; + } + } else { + ET_LOG( + Error, + "Need to implement empty_strided for non-CUDA non-CPU device type %d", + device_type); + return Error::NotImplemented; + } + + // ETensor sizes + std::vector sizes(ndim); + for (int i = 0; i < ndim; i++) { + sizes[i] = sizes_ptr[i]; + } + + // ETensor strides + std::vector strides(ndim); + if (strides_ptr != nullptr) { + // Use provided strides + for (int i = 0; i < ndim; i++) { + strides[i] = strides_ptr[i]; + } + } else { + // Calculate strides from sizes, assume it is in contiguous memory format + strides[ndim - 1] = 1; // Last dimension has stride 1 + for (int i = ndim - 2; i >= 0; i--) { + strides[i] = strides[i + 1] * sizes_ptr[i + 1]; + } + } + + // ETensor creation + auto tensor = executorch::extension::from_blob(ptr, sizes, strides); + + // Store the tensor so it doesn't get destroyed + tensors.insert(tensor); + *ret_new_tensor = tensor.get(); + is_tensor_own_memory[tensor.get()] = true; + + return Error::Ok; +} + +AOTITorchError aoti_torch_delete_tensor_object(AOTITensorHandle tensor) { + // Check ownership before cleaning up metadata + auto ownership_it = is_tensor_own_memory.find(tensor); + bool owns_memory = (ownership_it != is_tensor_own_memory.end()) + ? 
ownership_it->second + : false; + + // Clean up ALL metadata maps immediately to prevent use-after-free + tensor_to_sizes.erase(tensor); + tensor_to_strides.erase(tensor); + is_tensor_own_memory.erase(tensor); + + if (!owns_memory) { + // Don't free memory since the tensor doesn't own it + return Error::Ok; + } + + for (auto it = tensors.begin(); it != tensors.end(); ++it) { + if (it->get() == tensor) { + // Get the tensor before erasing + auto tensor_ptr = *it; + + void* data_ptr = tensor_ptr->mutable_data_ptr(); + + // Determine if it's GPU memory + cudaPointerAttributes attributes; + cudaError_t err = cudaPointerGetAttributes(&attributes, data_ptr); + + // et tensor does not own data; need to free them manually. + if (err == cudaSuccess && attributes.type == cudaMemoryTypeDevice) { + // This is GPU memory - free with proper synchronization + cudaDeviceSynchronize(); // Wait for all operations to complete BEFORE + // freeing + cudaFree(data_ptr); + } else { + // This is CPU memory - free immediately + free(data_ptr); + } + // Remove from set (this will call the destructor if it's the last + // reference) + tensors.erase(it); + return Error::Ok; + } + } + ET_LOG(Error, "Didn't find tensor %p", tensor); + return Error::InvalidArgument; +} + +AOTITorchError checkCudaError(cudaError_t err, const char* msg) { + if (err != cudaSuccess) { + ET_LOG(Error, "%s (%s)", msg, cudaGetErrorString(err)); + return Error::Internal; + } + return Error::Ok; +} + +AOTITorchError aoti_torch_copy_( + AOTITensorHandle self, + AOTITensorHandle src, + int32_t non_blocking) { + // assert same dim for now + if (self->dim() != src->dim()) { + ET_LOG( + Error, + "dimension mismatch. self.dim()=%d, src.dim()=%d", + self->dim(), + src->dim()); + return Error::InvalidArgument; + } + + // only support float32 for now + int32_t self_dtype, src_dtype; + aoti_torch_get_dtype(self, &self_dtype); + aoti_torch_get_dtype(src, &src_dtype); + + AOTITorchError self_dtype_error = validate_dtype(self_dtype); + if (self_dtype_error != Error::Ok) { + return self_dtype_error; + } + + AOTITorchError src_dtype_error = validate_dtype(src_dtype); + if (src_dtype_error != Error::Ok) { + return src_dtype_error; + } + + // Get stride information for layout validation + int64_t* self_strides; + int64_t* src_strides; + aoti_torch_get_strides(self, &self_strides); + aoti_torch_get_strides(src, &src_strides); + + int64_t* self_sizes; + int64_t* src_sizes; + aoti_torch_get_sizes(self, &self_sizes); + aoti_torch_get_sizes(src, &src_sizes); + + // Check if tensors have the same tensor schema (sizes, strides, dtype) + bool same_schema = true; + + // Check schema match + for (int i = 0; i < self->dim(); i++) { + if (self_sizes[i] != src_sizes[i] || self_strides[i] != src_strides[i]) { + same_schema = false; + break; + } + } + + // Declare layout variables for both cases + bool self_is_contiguous = true; + bool src_is_contiguous = true; + bool self_is_channels_last = false; + bool src_is_channels_last = false; + + // For same schema, we don't need to check memory formats - just use direct + // copy + if (!same_schema) { + // Different strides: check memory format and only support contiguous <-> + // channels-last conversion + + // Check if contiguous (strides decrease from left to right) + self_is_contiguous = + is_tensor_contiguous(self->dim(), self_sizes, self_strides); + + src_is_contiguous = + is_tensor_contiguous(src->dim(), src_sizes, src_strides); + + // Check if channels-last (4D: NHWC format) + if (!self_is_contiguous) { + 
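+ // Example: for sizes [N, C, H, W] = [1, 3, 4, 5], contiguous (NCHW)
+ // strides are [60, 20, 5, 1] while channels-last (NHWC) strides are
+ // [60, 1, 15, 3]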
self_is_channels_last = + is_tensor_channels_last(self->dim(), self_sizes, self_strides); + } + + if (!src_is_contiguous) { + src_is_channels_last = + is_tensor_channels_last(src->dim(), src_sizes, src_strides); + } + + // Validate layout assumptions only when schemas differ + if (!self_is_contiguous && !self_is_channels_last) { + ET_LOG( + Error, + "self tensor must be contiguous or channels-last for stride conversion"); + et_error_log_array_values(self_strides, self->dim(), "self strides"); + et_error_log_array_values(self_sizes, self->dim(), "self_sizes"); + return Error::InvalidArgument; + } + + if (!src_is_contiguous && !src_is_channels_last) { + ET_LOG( + Error, + "src tensor must be contiguous or channels-last for stride conversion"); + et_error_log_array_values(src_strides, src->dim(), "self strides"); + et_error_log_array_values(src_sizes, src->dim(), "src_sizes"); + return Error::InvalidArgument; + } + } + + // Determine device locations + cudaPointerAttributes srcAttributes, dstAttributes; + cudaError_t err; + + err = cudaPointerGetAttributes(&srcAttributes, src->data_ptr()); + AOTITorchError cuda_err = + checkCudaError(err, "Failed to get source pointer attributes"); + if (cuda_err != Error::Ok) { + return cuda_err; + } + + err = cudaPointerGetAttributes(&dstAttributes, self->data_ptr()); + cuda_err = + checkCudaError(err, "Failed to get destination pointer attributes"); + if (cuda_err != Error::Ok) { + return cuda_err; + } + + bool srcIsDevice = srcAttributes.type == cudaMemoryTypeDevice; + bool dstIsDevice = dstAttributes.type == cudaMemoryTypeDevice; + + size_t total_bytes = src->nbytes(); + + if (same_schema) { + // Simple copy since layouts match + if (srcIsDevice && dstIsDevice) { + err = cudaMemcpy( + self->mutable_data_ptr(), + src->data_ptr(), + total_bytes, + cudaMemcpyDeviceToDevice); + cuda_err = checkCudaError(err, "Failed to copy from device to device"); + if (cuda_err != Error::Ok) { + return cuda_err; + } + } else if (srcIsDevice && !dstIsDevice) { + err = cudaMemcpy( + self->mutable_data_ptr(), + src->data_ptr(), + total_bytes, + cudaMemcpyDeviceToHost); + cuda_err = checkCudaError(err, "Failed to copy from device to host"); + if (cuda_err != Error::Ok) { + return cuda_err; + } + } else if (!srcIsDevice && dstIsDevice) { + err = cudaMemcpy( + self->mutable_data_ptr(), + src->data_ptr(), + total_bytes, + cudaMemcpyHostToDevice); + cuda_err = checkCudaError(err, "Failed to copy from host to device"); + if (cuda_err != Error::Ok) { + return cuda_err; + } + } else { + std::memcpy(self->mutable_data_ptr(), src->data_ptr(), total_bytes); + } + } else { + // Layout conversion needed (contiguous <-> channels-last) + + if (self->dim() != 4) { + ET_LOG(Error, "Layout conversion only supported for 4D tensors"); + return Error::NotImplemented; + } + + // Get data to host for processing + size_t total_elements = total_bytes / sizeof(float); + float* src_host_data = nullptr; + float* dst_host_data = nullptr; + bool need_free_src = false; + bool need_free_dst = false; + + if (srcIsDevice) { + src_host_data = new float[total_elements]; + err = cudaMemcpy( + src_host_data, src->data_ptr(), total_bytes, cudaMemcpyDeviceToHost); + cuda_err = checkCudaError(err, "Failed to copy src to host"); + if (cuda_err != Error::Ok) { + delete[] src_host_data; + return cuda_err; + } + need_free_src = true; + } else { + src_host_data = static_cast(src->data_ptr()); + } + + if (dstIsDevice) { + dst_host_data = new float[total_elements]; + need_free_dst = true; + } else { + dst_host_data = 
static_cast(self->mutable_data_ptr()); + } + + // Perform layout conversion (4D NCHW <-> NHWC) + int64_t N = self_sizes[0], C = self_sizes[1], H = self_sizes[2], + W = self_sizes[3]; + + for (int64_t n = 0; n < N; n++) { + for (int64_t c = 0; c < C; c++) { + for (int64_t h = 0; h < H; h++) { + for (int64_t w = 0; w < W; w++) { + size_t src_offset, dst_offset; + + if (src_is_contiguous) { + // Source is NCHW + src_offset = n * C * H * W + c * H * W + h * W + w; + } else { + // Source is NHWC + src_offset = n * H * W * C + h * W * C + w * C + c; + } + + if (self_is_contiguous) { + // Destination is NCHW + dst_offset = n * C * H * W + c * H * W + h * W + w; + } else { + // Destination is NHWC + dst_offset = n * H * W * C + h * W * C + w * C + c; + } + + dst_host_data[dst_offset] = src_host_data[src_offset]; + } + } + } + } + + // Copy result back to device if needed + if (dstIsDevice) { + err = cudaMemcpy( + self->mutable_data_ptr(), + dst_host_data, + total_bytes, + cudaMemcpyHostToDevice); + cuda_err = checkCudaError(err, "Failed to copy result to device"); + if (cuda_err != Error::Ok) { + // Clean up temporary buffers before returning + if (need_free_src) + delete[] src_host_data; + if (need_free_dst) + delete[] dst_host_data; + return cuda_err; + } + } + + // Clean up temporary buffers + if (need_free_src) + delete[] src_host_data; + if (need_free_dst) + delete[] dst_host_data; + } + + // Verify the copy by checking first element + float src_first, dst_first; + if (srcIsDevice) { + err = cudaMemcpy( + &src_first, src->data_ptr(), sizeof(float), cudaMemcpyDeviceToHost); + cuda_err = checkCudaError(err, "Failed to copy first src element"); + if (cuda_err != Error::Ok) { + return cuda_err; + } + } else { + src_first = static_cast(src->data_ptr())[0]; + } + + if (dstIsDevice) { + err = cudaMemcpy( + &dst_first, self->data_ptr(), sizeof(float), cudaMemcpyDeviceToHost); + cuda_err = checkCudaError(err, "Failed to copy first dst element"); + if (cuda_err != Error::Ok) { + return cuda_err; + } + } else { + dst_first = static_cast(self->data_ptr())[0]; + } + + return Error::Ok; +} + +AOTITorchError aoti_torch__reinterpret_tensor( + AOTITensorHandle self, + int64_t ndim, + const int64_t* sizes_ptr, + const int64_t* strides_ptr, + int64_t storage_offset, + AOTITensorHandle* ret_new_tensor) { + // Check if storage_offset is not 0 - return error if not + AOTITorchError storage_offset_error = validate_storage_offset(storage_offset); + if (storage_offset_error != Error::Ok) { + return storage_offset_error; + } + + // Check if dimensions match + if (self->dim() != ndim) { + ET_LOG( + Error, + "tensor dimension mismatch. 
self->dim(): %d, provided ndim: %ld", + self->dim(), + ndim); + return Error::InvalidArgument; + } + + // Get tensor properties from the input tensor + int32_t dtype; + AOTITorchError dtype_err = aoti_torch_get_dtype(self, &dtype); + if (dtype_err != Error::Ok) { + ET_LOG(Error, "failed to get dtype from input tensor"); + return dtype_err; + } + + int32_t device_type; + AOTITorchError device_type_err = + aoti_torch_get_device_type(self, &device_type); + if (device_type_err != Error::Ok) { + ET_LOG(Error, "failed to get device_type from input tensor"); + return device_type_err; + } + + int32_t device_index; + AOTITorchError device_index_err = + aoti_torch_get_device_index(self, &device_index); + if (device_index_err != Error::Ok) { + ET_LOG(Error, "failed to get device_index from input tensor"); + return device_index_err; + } + + // Create new tensor with the provided sizes and strides using + // aoti_torch_empty_strided + AOTITorchError create_err = aoti_torch_empty_strided( + ndim, + sizes_ptr, + strides_ptr, + dtype, + device_type, + device_index, + ret_new_tensor); + + if (create_err != Error::Ok) { + ET_LOG(Error, "failed to create new tensor with empty_strided"); + return create_err; + } + + // Copy data from source tensor to new tensor + AOTITorchError copy_err = aoti_torch_copy_(*ret_new_tensor, self, 0); + if (copy_err != Error::Ok) { + ET_LOG(Error, "failed to copy data from source tensor to new tensor"); + // Clean up the created tensor on failure + aoti_torch_delete_tensor_object(*ret_new_tensor); + *ret_new_tensor = nullptr; + return copy_err; + } + + return Error::Ok; +} + +// Cleanup function for clearing global state +void cleanup_memory() { + is_tensor_own_memory.clear(); + if (!tensors.empty()) { + ET_LOG(Error, "Warning: tensors not empty during cleanup"); + } +} + +} // extern "C" + +} // namespace aoti +} // namespace backends +} // namespace executorch diff --git a/backends/cuda/runtime/shims/memory.h b/backends/cuda/runtime/shims/memory.h new file mode 100644 index 00000000000..41c03a1f552 --- /dev/null +++ b/backends/cuda/runtime/shims/memory.h @@ -0,0 +1,76 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +namespace executorch { +namespace backends { +namespace aoti { + +extern "C" { + +// Global storage declarations +extern std::unordered_map is_tensor_own_memory; +extern std::unordered_set> tensors; + +// Memory-related operations +AOTITorchError aoti_torch_create_tensor_from_blob_v2( + void* data, + int64_t ndim, + const int64_t* sizes_ptr, + const int64_t* strides_ptr, + int64_t storage_offset, + int32_t dtype, + int32_t device_type, + int32_t device_index, + AOTITensorHandle* ret_new_tensor, + int32_t layout, + const uint8_t* opaque_metadata, + int64_t opaque_metadata_size); + +AOTITorchError aoti_torch_empty_strided( + int64_t ndim, + const int64_t* sizes_ptr, + const int64_t* strides_ptr, + int32_t dtype, + int32_t device_type, + int32_t device_index, + AOTITensorHandle* ret_new_tensor); + +AOTITorchError aoti_torch_delete_tensor_object(AOTITensorHandle tensor); + +AOTITorchError aoti_torch_copy_( + AOTITensorHandle self, + AOTITensorHandle src, + int32_t non_blocking); + +AOTITorchError aoti_torch__reinterpret_tensor( + AOTITensorHandle self, + int64_t ndim, + const int64_t* sizes_ptr, + const int64_t* strides_ptr, + int64_t storage_offset, + AOTITensorHandle* ret_new_tensor); + +// Utility functions +AOTITorchError checkCudaError(cudaError_t err, const char* msg); +void cleanup_memory(); + +} // extern "C" + +} // namespace aoti +} // namespace backends +} // namespace executorch diff --git a/backends/cuda/runtime/shims/tensor_attribute.cpp b/backends/cuda/runtime/shims/tensor_attribute.cpp new file mode 100644 index 00000000000..789c16d7555 --- /dev/null +++ b/backends/cuda/runtime/shims/tensor_attribute.cpp @@ -0,0 +1,37 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +namespace executorch { +namespace backends { +namespace aoti { + +extern "C" { + +// Device type functions for tensor attributes +AOTITorchError aoti_torch_get_device_type( + AOTITensorHandle tensor, + int32_t* ret_device_type) { + // All tensors in aoti-cuda delegate are on CUDA + *ret_device_type = aoti_torch_device_type_cuda(); + return Error::Ok; +} + +// Device type constants +__attribute__((__visibility__("default"))) int32_t +aoti_torch_device_type_cuda() { + // Let's say cuda is 1 for ET as well + return 1; +} + +} // extern "C" + +} // namespace aoti +} // namespace backends +} // namespace executorch diff --git a/backends/cuda/runtime/shims/tensor_attribute.h b/backends/cuda/runtime/shims/tensor_attribute.h new file mode 100644 index 00000000000..d8866c19f24 --- /dev/null +++ b/backends/cuda/runtime/shims/tensor_attribute.h @@ -0,0 +1,41 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#pragma once + +#include +#include +#include + +namespace executorch { +namespace backends { +namespace aoti { + +// Common using declarations for ExecutorTorch types +using executorch::runtime::Error; +using executorch::runtime::etensor::Tensor; + +extern "C" { + +// Common AOTI type aliases +using AOTITensorHandle = Tensor*; +using AOTITorchError = Error; + +// Device type functions for tensor attributes +AOTITorchError aoti_torch_get_device_type( + AOTITensorHandle tensor, + int32_t* ret_device_type); + +// Device type constants +int32_t aoti_torch_device_type_cuda(); + +} // extern "C" + +} // namespace aoti +} // namespace backends +} // namespace executorch \ No newline at end of file diff --git a/backends/cuda/runtime/utils.cpp b/backends/cuda/runtime/utils.cpp new file mode 100644 index 00000000000..aee585f3a2e --- /dev/null +++ b/backends/cuda/runtime/utils.cpp @@ -0,0 +1,71 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include "utils.h" +#include + +namespace executorch { +namespace backends { +namespace aoti { + +// Enum for supported data types in et-cuda backend +enum class SupportedDTypes : int32_t { + FLOAT32 = 6, // PyTorch's float32 dtype code + + // BOOL = 11, // PyTorch's bool dtype code + // UINT8 = 1, // PyTorch's uint8 dtype code + // INT8 = 2, // PyTorch's int8 dtype code + // INT16 = 3, // PyTorch's int16 dtype code + // INT32 = 4, // PyTorch's int32 dtype code + // INT64 = 5, // PyTorch's int64 dtype code + // FLOAT16 = 7, // PyTorch's float16 dtype code + // FLOAT64 = 8, // PyTorch's float64 dtype code + // BFLOAT16 = 15 // PyTorch's bfloat16 dtype code +}; + +extern "C" { + +// Helper function to check if a dtype is supported in ET CUDA backend +bool is_dtype_supported_in_et_cuda(int32_t dtype) { + switch (dtype) { + case static_cast(SupportedDTypes::FLOAT32): + return true; + // case static_cast(SupportedDTypes::BOOL): + // case static_cast(SupportedDTypes::UINT8): + // case static_cast(SupportedDTypes::INT8): + // case static_cast(SupportedDTypes::INT16): + // case static_cast(SupportedDTypes::INT32): + // case static_cast(SupportedDTypes::INT64): + // case static_cast(SupportedDTypes::FLOAT16): + // case static_cast(SupportedDTypes::FLOAT64): + // case static_cast(SupportedDTypes::BFLOAT16): + // return true; + default: + return false; + } +} + +// Dtype validation utility function +AOTITorchError validate_dtype(int32_t dtype) { + if (is_dtype_supported_in_et_cuda(dtype)) { + return Error::Ok; + } + + ET_LOG( + Error, + "Unsupported dtype: %d. Supported dtypes: %d (float32)", + dtype, + static_cast(SupportedDTypes::FLOAT32)); + return Error::InvalidArgument; +} + +} // extern "C" + +} // namespace aoti +} // namespace backends +} // namespace executorch diff --git a/backends/cuda/runtime/utils.h b/backends/cuda/runtime/utils.h new file mode 100644 index 00000000000..c941917577c --- /dev/null +++ b/backends/cuda/runtime/utils.h @@ -0,0 +1,36 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#pragma once + +#include +#include + +namespace executorch { +namespace backends { +namespace aoti { + +// Common using declarations for ExecutorTorch types +using executorch::runtime::Error; + +extern "C" { + +// Common AOTI type aliases +using AOTITorchError = Error; + +// Helper function to check if a dtype is supported in ET CUDA backend +bool is_dtype_supported_in_et_cuda(int32_t dtype); + +// Dtype validation utility function +AOTITorchError validate_dtype(int32_t dtype); + +} // extern "C" + +} // namespace aoti +} // namespace backends +} // namespace executorch diff --git a/backends/cuda/targets.bzl b/backends/cuda/targets.bzl new file mode 100644 index 00000000000..be692cbb5a2 --- /dev/null +++ b/backends/cuda/targets.bzl @@ -0,0 +1,28 @@ +load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime") + +def define_common_targets(): + # CUDA-specific AOTI functionality + runtime.cxx_library( + name = "aoti_cuda", + srcs = [ + "runtime/cuda_backend.cpp", + "runtime/shims/memory.cpp", + "runtime/shims/tensor_attribute.cpp", + "runtime/utils.cpp", + ], + headers = [ + "runtime/shims/memory.h", + "runtime/shims/tensor_attribute.h", + "runtime/utils.h", + ], + # @lint-ignore BUCKLINT: Avoid `link_whole=True` (https://fburl.com/avoid-link-whole) + link_whole = True, + supports_python_dlopen = True, + # Constructor needed for backend registration. + compiler_flags = ["-Wno-global-constructors"], + visibility = ["@EXECUTORCH_CLIENTS"], + deps = [ + "//executorch/backends/aoti:aoti_common", + "//caffe2/torch/csrc/inductor:aoti_torch_cuda", + ], + ) diff --git a/compare_outputs.py b/compare_outputs.py new file mode 100755 index 00000000000..e83b701f73a --- /dev/null +++ b/compare_outputs.py @@ -0,0 +1,154 @@ +#!/usr/bin/env python3 +""" +Comparison script to calculate max absolute tolerance (atol) and max relative tolerance (rtol) +between runtime outputs and label outputs. 
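+
+The max relative tolerance is computed per element as |a - b| / max(|a|, |b|, eps)
+with eps = 1e-8 (see calculate_tolerances below).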
+""" + +import os +import sys + +import numpy as np + + +def read_csv_file(filepath): + """Read a comma-separated values file and return as numpy array.""" + try: + with open(filepath, "r") as f: + content = f.read().strip() + if not content: + print(f"Warning: {filepath} is empty") + return np.array([]) + + # Split by comma and convert to float + values = [float(x.strip()) for x in content.split(",") if x.strip()] + return np.array(values) + except FileNotFoundError: + print(f"Error: {filepath} not found") + return None + except ValueError as e: + print(f"Error parsing {filepath}: {e}") + return None + + +def calculate_tolerances(runtime_outputs, label_outputs): + """Calculate max absolute and relative tolerances.""" + if runtime_outputs is None or label_outputs is None: + return None, None + + if len(runtime_outputs) == 0 or len(label_outputs) == 0: + print("Warning: One of the output arrays is empty") + return None, None + + if len(runtime_outputs) != len(label_outputs): + print( + f"Warning: Array lengths don't match: runtime={len(runtime_outputs)}, label={len(label_outputs)}" + ) + # Pad shorter array with zeros or truncate longer array + min_len = min(len(runtime_outputs), len(label_outputs)) + runtime_outputs = runtime_outputs[:min_len] + label_outputs = label_outputs[:min_len] + + # Calculate absolute differences + abs_diff = np.abs(runtime_outputs - label_outputs) + max_atol = np.max(abs_diff) + + # Calculate relative differences (avoid division by zero) + # rel_diff = |a - b| / max(|a|, |b|, eps) where eps is a small number + eps = 1e-8 + denominator = np.maximum( + np.maximum(np.abs(runtime_outputs), np.abs(label_outputs)), eps + ) + rel_diff = abs_diff / denominator + max_rtol = np.max(rel_diff) + + return max_atol, max_rtol + + +def main(): + """Main function to compare outputs and print tolerances.""" + # File paths + runtime_file = "aoti_debug_data/final_runtime_output.txt" + label_file = "aoti_debug_data/label_output.txt" + + print("=" * 60) + print("AOTI Runtime vs Label Output Comparison") + print("=" * 60) + + # Check if files exist + if not os.path.exists(runtime_file): + print(f"Error: {runtime_file} not found") + sys.exit(1) + + if not os.path.exists(label_file): + print(f"Error: {label_file} not found") + sys.exit(1) + + # Read the files + print(f"Reading runtime outputs from: {runtime_file}") + runtime_outputs = read_csv_file(runtime_file) + + print(f"Reading label outputs from: {label_file}") + label_outputs = read_csv_file(label_file) + + if runtime_outputs is None or label_outputs is None: + print("Failed to read one or both files") + sys.exit(1) + + print(f"Runtime outputs shape: {runtime_outputs.shape}") + print(f"Label outputs shape: {label_outputs.shape}") + + if runtime_outputs.shape != label_outputs.shape: + print("Error: Output shapes don't match") + sys.exit(1) + + # Calculate tolerances + max_atol, max_rtol = calculate_tolerances(runtime_outputs, label_outputs) + + if max_atol is None or max_rtol is None: + print("Failed to calculate tolerances") + sys.exit(1) + + # Print results + print("-" * 60) + print("COMPARISON RESULTS:") + print(f"Max Absolute Tolerance (atol): {max_atol:.10f}") + print(f"Max Relative Tolerance (rtol): {max_rtol:.10f}") + print("-" * 60) + + # Print some statistics + print("ADDITIONAL STATISTICS:") + print(f"Total elements compared: {len(runtime_outputs)}") + print( + f"Runtime output range: [{np.min(runtime_outputs):.6f}, {np.max(runtime_outputs):.6f}]" + ) + print( + f"Label output range: [{np.min(label_outputs):.6f}, 
{np.max(label_outputs):.6f}]" + ) + + # Calculate mean absolute difference + abs_diff = np.abs(runtime_outputs - label_outputs) + mean_atol = np.mean(abs_diff) + print(f"Mean Absolute Tolerance: {mean_atol:.10f}") + + # Check if outputs are close within common tolerances + is_close_1e5 = np.allclose( + runtime_outputs, + label_outputs, + atol=1e-5, + rtol=1e-5, + ) + is_close_1e6 = np.allclose( + runtime_outputs, + label_outputs, + atol=1e-6, + rtol=1e-6, + ) + + print(f"Close within atol=1e-5, rtol=1e-5: {is_close_1e5}") + print(f"Close within atol=1e-6, rtol=1e-6: {is_close_1e6}") + + print("=" * 60) + + +if __name__ == "__main__": + main() diff --git a/examples/portable/executor_runner/executor_runner.cpp b/examples/portable/executor_runner/executor_runner.cpp index 5ce872eec8e..37029d150b8 100644 --- a/examples/portable/executor_runner/executor_runner.cpp +++ b/examples/portable/executor_runner/executor_runner.cpp @@ -1,7 +1,7 @@ /* * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. * Copyright 2024-2025 Arm Limited and/or its affiliates. + * All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. @@ -26,6 +26,7 @@ #include #include +#include #include #include #include @@ -50,16 +51,10 @@ DEFINE_string( model_path, "model.pte", "Model serialized in flatbuffer format."); -DEFINE_string(inputs, "", "Comma-separated list of input files"); DEFINE_string( - output_file, + data_path, "", - "Base name of output file. If not empty output will be written to the file(s)."); - -DEFINE_bool( - print_all_output, - false, - "Prints all output. By default only first and last 100 elements are printed."); + "Path to external tensor data file (.ptd format). Optional."); DEFINE_uint32(num_executions, 1, "Number of times to run the model."); #ifdef ET_EVENT_TRACER_ENABLED DEFINE_string(etdump_path, "model.etdump", "Write ETDump data to this path."); @@ -69,9 +64,8 @@ DEFINE_int32( -1, "Number of CPU threads for inference. Defaults to -1, which implies we'll use a heuristic to derive the # of performant cores for a specific device."); -using executorch::aten::ScalarType; -using executorch::aten::Tensor; using executorch::extension::FileDataLoader; +using executorch::extension::FlatTensorDataMap; using executorch::runtime::Error; using executorch::runtime::EValue; using executorch::runtime::EventTracer; @@ -83,8 +77,6 @@ using executorch::runtime::MethodMeta; using executorch::runtime::Program; using executorch::runtime::Result; using executorch::runtime::Span; -using executorch::runtime::Tag; -using executorch::runtime::TensorInfo; /// Helper to manage resources for ETDump generation class EventTraceManager { @@ -171,43 +163,6 @@ int main(int argc, char** argv) { "FileDataLoader::from() failed: 0x%" PRIx32, (uint32_t)loader.error()); - std::vector inputs_storage; - std::vector> input_buffers; - - std::stringstream list_of_input_files(FLAGS_inputs); - std::string path; - - // First reserve memory for number of vector elements to avoid vector - // reallocations when emplacing back. 
- std::vector file_paths; - while (std::getline(list_of_input_files, path, ',')) { - file_paths.push_back(std::move(path)); - } - inputs_storage.reserve(file_paths.size()); - - for (const auto& file_path : file_paths) { - std::ifstream input_file_handle( - file_path, std::ios::binary | std::ios::ate); - - if (!input_file_handle) { - ET_LOG(Error, "Failed to open input file: %s\n", file_path.c_str()); - return 1; - } - - std::streamsize file_size = input_file_handle.tellg(); - input_file_handle.seekg(0, std::ios::beg); - - // Reserve memory for actual file contents. - inputs_storage.emplace_back(file_size, '\0'); - - if (!input_file_handle.read(&inputs_storage.back()[0], file_size)) { - ET_LOG(Error, "Failed to read input file: %s\n", file_path.c_str()); - return 1; - } - - input_buffers.emplace_back(&inputs_storage.back()[0], file_size); - } - // Parse the program file. This is immutable, and can also be reused between // multiple execution invocations across multiple threads. Result program = Program::load(&loader.get()); @@ -293,8 +248,43 @@ int main(int argc, char** argv) { // be used by a single thread at at time, but it can be reused. // EventTraceManager tracer; + + // Handle optional external tensor data loading + std::unique_ptr data_loader; + std::unique_ptr data_map; + + if (!FLAGS_data_path.empty()) { + ET_LOG( + Info, "Loading external tensor data from %s", FLAGS_data_path.c_str()); + + // Create FileDataLoader for the PTD file + Result data_loader_result = + FileDataLoader::from(FLAGS_data_path.c_str()); + ET_CHECK_MSG( + data_loader_result.ok(), + "Failed to create FileDataLoader for data path %s: 0x%" PRIx32, + FLAGS_data_path.c_str(), + (uint32_t)data_loader_result.error()); + + data_loader = + std::make_unique(std::move(data_loader_result.get())); + + // Create FlatTensorDataMap from the loaded blob + Result data_map_result = + FlatTensorDataMap::load(data_loader.get()); + ET_CHECK_MSG( + data_map_result.ok(), + "Failed to load FlatTensorDataMap from %s: 0x%" PRIx32, + FLAGS_data_path.c_str(), + (uint32_t)data_map_result.error()); + + data_map = + std::make_unique(std::move(data_map_result.get())); + ET_LOG(Info, "External tensor data loaded successfully"); + } + Result method = program->load_method( - method_name, &memory_manager, tracer.get_event_tracer()); + method_name, &memory_manager, tracer.get_event_tracer(), data_map.get()); ET_CHECK_MSG( method.ok(), "Loading of method %s failed with status 0x%" PRIx32, @@ -306,8 +296,7 @@ int main(int argc, char** argv) { // Run the model. for (uint32_t i = 0; i < FLAGS_num_executions; i++) { ET_LOG(Debug, "Preparing inputs."); - // Allocate input tensors and set all of their elements to 1 or to the - // contents of input_buffers if available. The `inputs` + // Allocate input tensors and set all of their elements to 1. The `inputs` // variable owns the allocated memory and must live past the last call to // `execute()`. // @@ -315,8 +304,7 @@ int main(int argc, char** argv) { // because inputs whose space gets reused by memory planning (if // any such inputs exist) will not be preserved for the next // execution. 
- auto inputs = executorch::extension::prepare_input_tensors( - *method, {}, input_buffers); + auto inputs = executorch::extension::prepare_input_tensors(*method); ET_CHECK_MSG( inputs.ok(), "Could not prepare inputs: 0x%" PRIx32, @@ -348,69 +336,47 @@ int main(int argc, char** argv) { std::vector outputs(method->outputs_size()); ET_LOG(Info, "%zu outputs: ", outputs.size()); Error status = method->get_outputs(outputs.data(), outputs.size()); + ET_CHECK(status == Error::Ok); - if (FLAGS_output_file.size() > 0) { - for (int i = 0; i < outputs.size(); ++i) { - if (outputs[i].isTensor()) { - Tensor tensor = outputs[i].toTensor(); - - char out_filename[255]; - snprintf(out_filename, 255, "%s-%d.bin", FLAGS_output_file.c_str(), i); - ET_LOG(Info, "Writing output to file: %s", out_filename); - FILE* out_file = fopen(out_filename, "wb"); - fwrite(tensor.const_data_ptr(), 1, tensor.nbytes(), out_file); - fclose(out_file); - } - } + // Open file to dump outputs + std::ofstream output_file("aoti_debug_data/final_runtime_output.txt"); + if (!output_file.is_open()) { + ET_LOG(Error, "Failed to open output file for dumping"); } - if (FLAGS_print_all_output) { - for (int i = 0; i < outputs.size(); ++i) { - if (outputs[i].isTensor()) { - Tensor tensor = outputs[i].toTensor(); - - for (int j = 0; j < tensor.numel(); ++j) { - if (tensor.scalar_type() == ScalarType::Int) { - printf( - "Output[%d][%d]: (int) %d\n", - i, - j, - tensor.const_data_ptr()[j]); - } else if (tensor.scalar_type() == ScalarType::Float) { - printf( - "Output[%d][%d]: (float) %f\n", - i, - j, - tensor.const_data_ptr()[j]); - } else if (tensor.scalar_type() == ScalarType::Char) { - printf( - "Output[%d][%d]: (char) %d\n", - i, - j, - tensor.const_data_ptr()[j]); - } else if (tensor.scalar_type() == ScalarType::Bool) { - printf( - "Output[%d][%d]: (bool) %s (0x%x)\n", - i, - j, - tensor.const_data_ptr()[j] ? "true " : "false", - tensor.const_data_ptr()[j]); - } - } - } else { - printf("Output[%d]: Not Tensor\n", i); + // Print the first and last 100 elements of long lists of scalars. + std::cout << executorch::extension::evalue_edge_items(100); + for (int i = 0; i < outputs.size(); ++i) { + std::cout << "Output " << i << ": " << outputs[i] << std::endl; + + // Also dump to file - extract tensor data and write comma-separated values + if (output_file.is_open() && outputs[i].isTensor()) { + auto tensor = outputs[i].toTensor(); + const void* data_ptr = tensor.const_data_ptr(); + + // assert output is in float different tensor types + const float* float_data = static_cast(data_ptr); + size_t num_elements = tensor.numel(); + + for (size_t j = 0; j < num_elements; ++j) { + if (j > 0) + output_file << ","; + output_file << float_data[j]; } - } - } else { - // Print the first and last 100 elements of long lists of scalars. - std::cout << executorch::extension::evalue_edge_items(100); - for (int i = 0; i < outputs.size(); ++i) { - std::cout << "OutputX " << i << ": " << outputs[i] << std::endl; + if (i < outputs.size() - 1) + output_file << ","; } } + if (output_file.is_open()) { + output_file.close(); + ET_LOG( + Info, + "Runtime outputs dumped to aoti_debug_data/final_runtime_output.txt"); + } + if (tracer.get_event_tracer()) { // Dump ETDump data containing profiling/debugging data to file specified in // command line flag. 
diff --git a/examples/portable/executor_runner/targets.bzl b/examples/portable/executor_runner/targets.bzl index 0af45d85075..d1304a84bcb 100644 --- a/examples/portable/executor_runner/targets.bzl +++ b/examples/portable/executor_runner/targets.bzl @@ -19,6 +19,7 @@ def define_common_targets(): "//executorch/devtools/etdump:etdump_flatcc", "//executorch/extension/data_loader:file_data_loader", "//executorch/extension/evalue_util:print_evalue", + "//executorch/extension/flat_tensor:flat_tensor_data_map", "//executorch/extension/runner_util:inputs", ], external_deps = [ @@ -38,6 +39,7 @@ def define_common_targets(): "//executorch/runtime/executor:program", "//executorch/extension/data_loader:file_data_loader", "//executorch/extension/evalue_util:print_evalue", + "//executorch/extension/flat_tensor:flat_tensor_data_map", "//executorch/extension/runner_util:inputs", "//executorch/extension/threadpool:cpuinfo_utils", "//executorch/extension/threadpool:threadpool", diff --git a/exir/backend/backend_api.py b/exir/backend/backend_api.py index dd8d97d66ac..d0225437c99 100644 --- a/exir/backend/backend_api.py +++ b/exir/backend/backend_api.py @@ -268,7 +268,9 @@ def _partition_and_lower_one_graph_module( """ Partitioned and lowered the graph module based on the partition tag, this is to handle one graph module. """ - for tag, delegation_spec in partition_result.partition_tags.items(): + for idx, (tag, delegation_spec) in enumerate( + partition_result.partition_tags.items() + ): # Create partition with nodes containing this tag. There should only be # one contained submodule per tag node_list = _get_node_list_with_same_tag( @@ -311,6 +313,7 @@ def _partition_and_lower_one_graph_module( tag, call_module_node, is_submodule, + idx == 0, ) lowered_submodule = to_backend( @@ -452,7 +455,9 @@ def _create_partitions_in_graph_module( is_submodule: bool, ) -> Dict[str, List[torch.fx.Node]]: backend_id_to_submodule_name = {} - for tag, delegation_spec in partition_result.partition_tags.items(): + for idx, (tag, delegation_spec) in enumerate( + partition_result.partition_tags.items() + ): # Create partition with nodes containing this tag. 
There should only be # one contained submodule per tag node_list = _get_node_list_with_same_tag( @@ -492,6 +497,7 @@ def _create_partitions_in_graph_module( tag, call_module_node, is_submodule, + idx == 0, ) call_module_node.meta["backend_id"] = delegation_spec.backend_id call_module_node.meta["compile_spec"] = delegation_spec.compile_specs @@ -720,6 +726,8 @@ def to_backend( fake_edge_program = copy.deepcopy(edge_program) partitioner_result = partitioner_instance(fake_edge_program) tagged_exported_program = partitioner_result.tagged_exported_program + tagged_exported_program.example_inputs = edge_program.example_inputs + method_to_tagged_exported_program[method_name] = tagged_exported_program # Check that the partitioner did not modify the original graph diff --git a/exir/emit/_emit_program.py b/exir/emit/_emit_program.py index 0618871bd40..eb84d508c2c 100644 --- a/exir/emit/_emit_program.py +++ b/exir/emit/_emit_program.py @@ -176,6 +176,7 @@ def emit_program( ) emitter.run() + plans.append(emitter.plan()) debug_handle_map[name] = emitter.debug_handle_map diff --git a/exir/lowered_backend_module.py b/exir/lowered_backend_module.py index 61414990703..3c5ee5d36b0 100644 --- a/exir/lowered_backend_module.py +++ b/exir/lowered_backend_module.py @@ -682,6 +682,7 @@ def create_exported_program_from_submodule( tag: str, call_module_node: torch.fx.Node, is_submodule: bool, + is_first_partition: bool = False, ) -> Tuple[ExportedProgram, Dict[str, InputSpec], Dict[str, OutputSpec]]: """ Creates an ExportedProgram from the given submodule using the parameters and buffers @@ -720,6 +721,11 @@ def create_exported_program_from_submodule( in_spec = pytree.tree_flatten((tuple(subgraph_signature.user_inputs), {}))[1] out_spec = pytree.tree_flatten(subgraph_signature.user_outputs)[1] + # only the example inputs of first parition equals to the example inputs of the owning program + submodule_exmaple_inputs = ( + owning_program.example_inputs if is_first_partition else None + ) + return ( ExportedProgram( root=submodule, @@ -735,6 +741,7 @@ def create_exported_program_from_submodule( ), ) ], + example_inputs=submodule_exmaple_inputs, constants=subgraph_constants, verifiers=[owning_program.verifier], ), diff --git a/exir/program/_program.py b/exir/program/_program.py index a33d715ca3b..c740bbcb7b3 100644 --- a/exir/program/_program.py +++ b/exir/program/_program.py @@ -1697,6 +1697,7 @@ def to_executorch( # noqa (FLAKE8) C901 after it has been transformed to the ExecuTorch backend. 
""" config = config if config else ExecutorchBackendConfig() + execution_programs: Dict[str, ExportedProgram] = {} for name, program in self._edge_programs.items(): if config.do_quant_fusion_and_const_prop: diff --git a/export_and_run_aoti.sh b/export_and_run_aoti.sh new file mode 100644 index 00000000000..dd4aeef1017 --- /dev/null +++ b/export_and_run_aoti.sh @@ -0,0 +1,269 @@ +#!/bin/bash + +# Script to export and run AOTI with different modes +# Usage: +# ./export_and_run_aoti.sh [mode] +# ./export_and_run_aoti.sh --mode= [--debug] [--dump] +# +# Examples: +# ./export_and_run_aoti.sh conv2d # Uses default mode (reinstall_all) +# ./export_and_run_aoti.sh conv2d inference # Uses inference mode +# ./export_and_run_aoti.sh conv2d --mode=inference # Alternative syntax +# ./export_and_run_aoti.sh conv2d --mode=inference --dump # With AOTI intermediate output dumping +# ./export_and_run_aoti.sh conv2d --mode=inference --debug --dump # With both debug and dump +# +# Available modes: reinstall_all (default), reinstall_aot, reinstall_runtime, inference, export_aoti_only +# Flags: +# --debug: Enable debug mode with extensive logging +# --dump: Enable AOTI intermediate output dumping to aoti_intermediate_output.txt +# model_arg: argument to pass to export_aoti.py + +set -e # Exit on any error + +# Parse command line arguments +MODE="reinstall_all" +MODEL_ARG="$1" +DEBUG_MODE=false +DUMP_MODE=false + +# Parse arguments for mode and debug flag +for arg in "$@"; do + case $arg in + --mode=*) + MODE="${arg#*=}" + shift + ;; + --debug) + DEBUG_MODE=true + shift + ;; + --dump) + DUMP_MODE=true + shift + ;; + reinstall_all|reinstall_aot|reinstall_runtime|inference|export_aoti_only) + # If it's the second argument and a valid mode, use it as mode + if [[ "$arg" == "$2" ]]; then + MODE="$arg" + fi + ;; + esac +done + +# Validate mode +case "$MODE" in + reinstall_all|reinstall_aot|reinstall_runtime|inference|export_aoti_only) + # Valid mode, continue + ;; + *) + echo "Error: Unknown mode '$MODE'" + echo "Available modes: reinstall_all, reinstall_aot, reinstall_runtime, inference, export_aoti_only" + echo "" + echo "Usage examples:" + echo " ./export_and_run_aoti.sh conv2d # Uses default mode" + echo " ./export_and_run_aoti.sh conv2d inference # Positional mode" + echo " ./export_and_run_aoti.sh conv2d --mode=inference # GNU-style mode" + echo " ./export_and_run_aoti.sh conv2d export_aoti_only # Export AOTI only (no runtime)" + echo " ./export_and_run_aoti.sh conv2d --mode=inference --debug # With debug options enabled" + exit 1 + ;; +esac + +echo "Running in mode: $MODE" +if [[ -n "$MODEL_ARG" ]]; then + echo "Model argument: $MODEL_ARG" +fi + +# Cleanup function to remove temporary files and directories +cleanup_temp_files() { + echo "Cleaning up temporary files and directories..." + + # Remove temporary directories + for file in *wrapper.cpp; do + if [[ -f "$file" ]]; then + basename="${file%wrapper.cpp}" + if [[ -d "$basename" ]]; then + echo "Removing directory: $basename" + rm -rf "$basename" + fi + fi + done + + # Remove temporary files with specific extensions + rm -f *.cubin + rm -f *.pte + rm -f *.so + rm -f *kernel_metadata.json + rm -f *kernel.cpp + rm -f *wrapper_metadata.json + rm -f *wrapper.cpp + rm -f *wrapper.json + rm -f aoti_intermediate_output.txt + + echo "Cleanup completed." +} + +# Run cleanup at the start +cleanup_temp_files + +# Function definitions for each step +install_executorch() { + echo "Installing executorch..." 
+ ./install_executorch.sh +} + +export_aoti_model() { + local use_aoti_only=$1 + echo "Exporting AOTI model..." + if [[ "$use_aoti_only" == "--aoti_only" ]]; then + python export_aoti.py $MODEL_ARG --aoti_only + else + python export_aoti.py $MODEL_ARG + fi +} + +clean_install_executorch() { + echo "Clean installing executorch..." + ./install_executorch.sh --clean +} + +build_runtime() { + echo "Building runtime..." + # Clean the build directory to ensure debug flags take effect + rm -rf cmake-out + mkdir -p cmake-out + cd cmake-out + + if [[ "$DEBUG_MODE" == true ]]; then + echo "Building with debug configuration..." + cmake -DEXECUTORCH_BUILD_CUDA=ON \ + -DEXECUTORCH_BUILD_EXTENSION_TENSOR=ON \ + -DEXECUTORCH_BUILD_EXECUTOR_RUNNER=ON \ + -DEXECUTORCH_LOG_LEVEL=Debug \ + -DCMAKE_BUILD_TYPE=Debug \ + -DEXECUTORCH_BUILD_EXTENSION_FLAT_TENSOR=ON \ + -DEXECUTORCH_BUILD_EXTENSION_DATA_LOADER=ON \ + .. + else + echo "Building with release configuration..." + cmake -DEXECUTORCH_BUILD_CUDA=ON \ + -DEXECUTORCH_BUILD_EXTENSION_TENSOR=ON \ + -DEXECUTORCH_BUILD_EXECUTOR_RUNNER=ON \ + -DEXECUTORCH_LOG_LEVEL=Info \ + -DCMAKE_BUILD_TYPE=Release \ + -DEXECUTORCH_BUILD_EXTENSION_FLAT_TENSOR=ON \ + -DEXECUTORCH_BUILD_EXTENSION_DATA_LOADER=ON \ + .. + fi + + cd .. + cmake --build cmake-out -j9 +} + +run_inference() { + echo "Running executor_runner with debug logging enabled..." + ./cmake-out/executor_runner --model_path aoti_model.pte --data_path aoti_cuda_blob.ptd +} + +compare_outputs() { + echo "Comparing runtime outputs with label outputs..." + python compare_outputs.py +} + +# Set up environment variables based on debug and dump flags +if [[ "$DEBUG_MODE" == true ]]; then + echo "Setting debug environment variables..." + export AOT_INDUCTOR_DEBUG_COMPILE="1" + export AOTINDUCTOR_REPRO_LEVEL=3 + + # Set intermediate value printer based on dump flag + if [[ "$DUMP_MODE" == true ]]; then + export AOT_INDUCTOR_DEBUG_INTERMEDIATE_VALUE_PRINTER="2" + export INDUCTOR_PROVENANCE=1 + export TORCH_TRACE="/home/gasoonjia/executorch/aoti_debug_data" + echo "AOTI intermediate output dumping enabled (AOT_INDUCTOR_DEBUG_INTERMEDIATE_VALUE_PRINTER=2)" + echo "Eager-AOTI relationship extration enabled (INDUCTOR_PROVENANCE=1), output to $TORCH_TRACE" + else + export AOT_INDUCTOR_DEBUG_INTERMEDIATE_VALUE_PRINTER="3" + fi + + echo "Debug variables set:" + echo " AOT_INDUCTOR_DEBUG_COMPILE=$AOT_INDUCTOR_DEBUG_COMPILE" + echo " AOTINDUCTOR_REPRO_LEVEL=$AOTINDUCTOR_REPRO_LEVEL" + echo " AOT_INDUCTOR_DEBUG_INTERMEDIATE_VALUE_PRINTER=$AOT_INDUCTOR_DEBUG_INTERMEDIATE_VALUE_PRINTER" +elif [[ "$DUMP_MODE" == true ]]; then + # Only dump mode enabled (without debug) + echo "Setting AOTI intermediate output dumping..." 
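+ # Dump-only: same dumping variables as the --debug --dump path, minus
+ # AOT_INDUCTOR_DEBUG_COMPILE and AOTINDUCTOR_REPRO_LEVEL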
+ export AOT_INDUCTOR_DEBUG_INTERMEDIATE_VALUE_PRINTER="2" + export INDUCTOR_PROVENANCE=1 + export TORCH_TRACE="/home/gasoonjia/executorch/aoti_debug_data" + echo "AOTI intermediate output dumping enabled (AOT_INDUCTOR_DEBUG_INTERMEDIATE_VALUE_PRINTER=2)" + echo " AOT_INDUCTOR_DEBUG_INTERMEDIATE_VALUE_PRINTER=$AOT_INDUCTOR_DEBUG_INTERMEDIATE_VALUE_PRINTER" + echo "Eager-AOTI relationship extration enabled (INDUCTOR_PROVENANCE=1), output to $TORCH_TRACE" +else + # Ensure debug variables are unset for non-debug/non-dump modes + unset AOT_INDUCTOR_DEBUG_COMPILE + unset AOTINDUCTOR_REPRO_LEVEL + unset AOT_INDUCTOR_DEBUG_INTERMEDIATE_VALUE_PRINTER + unset INDUCTOR_PROVENANCE + unset TORCH_TRACE +fi + +# Execute based on mode +case "$MODE" in + "reinstall_all") + echo "Mode: $MODE - Full reinstall and run" + if [[ "$DEBUG_MODE" == true ]]; then + echo "Debug options enabled with AOT Inductor debug settings" + fi + install_executorch + export_aoti_model + clean_install_executorch + build_runtime + run_inference + compare_outputs + ;; + "reinstall_aot") + echo "Mode: reinstall_aot - Reinstall AOT components and run e2e" + if [[ "$DEBUG_MODE" == true ]]; then + echo "Debug options enabled with AOT Inductor debug settings" + fi + install_executorch + export_aoti_model + run_inference + compare_outputs + ;; + "reinstall_runtime") + echo "Mode: reinstall_runtime - Rebuild runtime and run e2e" + if [[ "$DEBUG_MODE" == true ]]; then + echo "Debug options enabled with AOT Inductor debug settings" + fi + export_aoti_model + build_runtime + run_inference + compare_outputs + ;; + "inference") + echo "Mode: inference - Export model and run inference only" + if [[ "$DEBUG_MODE" == true ]]; then + echo "Debug options enabled with AOT Inductor debug settings" + fi + export_aoti_model + run_inference + compare_outputs + ;; + "export_aoti_only") + echo "Mode: export_aoti_only - Export model using pure AOTI only (no runtime or installation)" + if [[ "$DEBUG_MODE" == true ]]; then + echo "Debug options enabled with AOT Inductor debug settings" + fi + export_aoti_model "--aoti_only" + ;; + *) + echo "Error: Unknown mode '$MODE'" + echo "Available modes: reinstall_all, reinstall_aot, reinstall_runtime, inference, export_aoti_only" + exit 1 + ;; +esac + +echo "Script completed successfully!" diff --git a/export_aoti.py b/export_aoti.py new file mode 100644 index 00000000000..d0bf916f387 --- /dev/null +++ b/export_aoti.py @@ -0,0 +1,528 @@ +#!/usr/bin/env python3 +""" +Unified export script for AOTI backend. 
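+
+The default flow lowers the model through the ExecuTorch CUDA (AOTI) backend and
+writes aoti_model.pte plus an external tensor data (.ptd) file; --aoti_only
+instead compiles a standalone aoti.so via torch._inductor.aot_compile.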
+Usage: + python export_aoti.py # Uses export_model_to_et_aoti + python export_aoti.py --aoti_only # Uses export_model_to_pure_aoti + +Supported models: +- mv2: MobileNetV2 model +- linear: Simple linear layer model +- conv2d: Single Conv2d layer model +- add: Simple tensor addition model +""" + +import argparse +import copy +import os + +import shutil + +import sys +from subprocess import check_call +from typing import Any, Dict, Tuple + +import torch +from executorch.backends.cuda.cuda_partitioner import CudaPartitioner + +# from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner +from executorch.exir import to_edge, to_edge_transform_and_lower +from torch import nn +from torch.export import export +from torch.nn.attention import SDPBackend +from torchvision import models +from torchvision.models.mobilenetv2 import MobileNet_V2_Weights +from torchvision.models.resnet import ResNet18_Weights +from transformers import AutoModelForCausalLM, WhisperModel + + +# for maintaing precision of 32-bit float as much as possible +torch.backends.cuda.matmul.allow_tf32 = False +torch.backends.cudnn.allow_tf32 = False +torch.backends.cudnn.conv.fp32_precision = "fp32" + + +# Model classes +class MV2(torch.nn.Module): + def __init__(self): + super(MV2, self).__init__() + self.mv2 = models.mobilenetv2.mobilenet_v2(weights=MobileNet_V2_Weights) + + def forward(self, x: torch.Tensor): + return self.mv2(x) + + +class ResNet18(torch.nn.Module): + def __init__(self): + super(ResNet18, self).__init__() + self.resnet18 = models.resnet18(weights=ResNet18_Weights.IMAGENET1K_V1) + + def forward(self, x: torch.Tensor): + return self.resnet18(x) + + +class Linear(torch.nn.Module): + def __init__(self): + super(Linear, self).__init__() + self.linear = nn.Linear(7, 101) + + def forward(self, x: torch.Tensor): + return self.linear(x) + + +class SingleConv2d(nn.Module): + def __init__(self): + super(SingleConv2d, self).__init__() + self.conv = nn.Conv2d( + in_channels=3, out_channels=5, kernel_size=3, stride=1, padding=1 + ) + + def forward(self, x: torch.Tensor): + return self.conv(x) + + +class Add(torch.nn.Module): + def __init__(self): + super(Add, self).__init__() + + def forward(self, x: torch.Tensor, y: torch.Tensor): + return x + y + + +class DepthwiseConv(nn.Module): + def __init__(self): + super().__init__() + # 32 input channels, 32 output channels, groups=32 for depthwise + self.conv = nn.Conv2d( + in_channels=32, + out_channels=32, + kernel_size=3, + stride=1, + padding=1, + dilation=1, + groups=32, + bias=False, + ) + + def forward(self, x): + return self.conv(x) + + +class BatchNorm(nn.Module): + def __init__(self): + super().__init__() + self.bn = nn.BatchNorm2d(num_features=16) + + def forward(self, x): + return self.bn(x) + + +class SingleResNetBlock(nn.Module): + def __init__(self, in_channels=64, out_channels=64, stride=1): + super().__init__() + self.conv1 = nn.Conv2d( + in_channels, + out_channels, + kernel_size=3, + stride=stride, + padding=1, + bias=False, + ) + self.bn1 = nn.BatchNorm2d(out_channels) + self.relu = nn.ReLU(inplace=True) + self.conv2 = nn.Conv2d( + out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False + ) + self.bn2 = nn.BatchNorm2d(out_channels) + + # Skip connection - identity mapping if same channels, 1x1 conv if different + self.skip_connection = None + if stride != 1 or in_channels != out_channels: + self.skip_connection = nn.Sequential( + nn.Conv2d( + in_channels, out_channels, kernel_size=1, stride=stride, bias=False + ), 
+ nn.BatchNorm2d(out_channels), + ) + + def forward(self, x): + identity = x + + # First conv block + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + # Second conv block + out = self.conv2(out) + out = self.bn2(out) + + # Skip connection + if self.skip_connection is not None: + identity = self.skip_connection(x) + + out += identity + out = self.relu(out) + + return out + + +class Llama31(torch.nn.Module): + def __init__(self, model_id="meta-llama/Meta-Llama-3.1-8B", use_cache=False): + super(Llama31, self).__init__() + # Load Llama 3.1 model from HF + self.use_cache = use_cache + self.model = AutoModelForCausalLM.from_pretrained( + model_id, + torch_dtype=torch.float32, + device_map="cuda", + use_cache=self.use_cache, # Turn off KV cache + ) + self.model.eval() + + def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor = None): + # Disable KV cache for inference + with torch.no_grad(): + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + use_cache=self.use_cache, # Explicitly turn off KV cache + ) + return outputs.logits + + +class Whisper(torch.nn.Module): + def __init__(self, model_name="openai/whisper-tiny"): + super(Whisper, self).__init__() + # 1. Load pre-trained Whisper model (tiny version is lightweight) + self.model = WhisperModel.from_pretrained(model_name) + self.model.eval() + + def forward(self, input_features: torch.Tensor): + outputs = self.model.encoder(input_features=input_features) + + # Return both encoder and decoder hidden states for compatibility + return outputs.last_hidden_state + + +class MockConv1d(nn.Module): + def __init__(self): + super().__init__() + self.conv = nn.Conv1d( + in_channels=80, + out_channels=384, + kernel_size=3, + stride=1, + padding=1, + dilation=1, + groups=1, + bias=True, + ) + + def forward(self, x): + return self.conv(x) + + +class TransformerBlock(nn.Module): + def __init__(self, embed_dim=256, num_heads=8, ff_dim=1024, dropout=0.1): + super().__init__() + self.embed_dim = embed_dim + self.num_heads = num_heads + + # Multi-head self-attention + self.self_attn = nn.MultiheadAttention( + embed_dim=embed_dim, num_heads=num_heads, dropout=dropout, batch_first=True + ) + + # Layer normalization layers + self.norm1 = nn.LayerNorm(embed_dim) + self.norm2 = nn.LayerNorm(embed_dim) + + # Feed-forward network + self.ffn = nn.Sequential( + nn.Linear(embed_dim, ff_dim), + nn.ReLU(), + nn.Dropout(dropout), + nn.Linear(ff_dim, embed_dim), + nn.Dropout(dropout), + ) + + def forward(self, x): + # Self-attention block with residual connection + attn_output, _ = self.self_attn(x, x, x) + x = self.norm1(x + attn_output) + + # Feed-forward block with residual connection + ff_output = self.ffn(x) + x = self.norm2(x + ff_output) + + return x + + +# Model registry mapping model names to their configurations +MODEL_REGISTRY: Dict[str, Dict[str, Any]] = { + "mv2": { + "model_class": MV2, + "input_shapes": [(1, 3, 224, 224)], + "description": "MobileNetV2 model", + }, + "resnet18": { + "model_class": ResNet18, + "input_shapes": [(1, 3, 224, 224)], + "description": "ResNet18 model", + }, + "linear": { + "model_class": Linear, + "input_shapes": [(127, 7)], + "description": "Simple linear layer model", + }, + "conv2d": { + "model_class": SingleConv2d, + "input_shapes": [(4, 3, 8, 8)], + "description": "Single Conv2d layer model", + }, + "depthwise_conv": { + "model_class": DepthwiseConv, + "input_shapes": [(1, 32, 112, 112)], + "description": "Single Depthwise Conv2d layer model", + }, + "add": { + 
"model_class": Add, + "input_shapes": [(10,), (10,)], + "description": "Simple tensor addition model", + }, + "batchnorm": { + "model_class": BatchNorm, + "input_shapes": [(1, 16, 32, 32)], + "description": "Single BatchNorm2d layer model", + }, + "single_resnet_block": { + "model_class": SingleResNetBlock, + "input_shapes": [(1, 64, 8, 8)], + "description": "Single ResNet block with skip connection", + }, + "llama31": { + "model_class": Llama31, + "input_shapes": [(1, 32)], # batch_size=1, sequence_length=128 + "description": "Llama 3.1 model with KV cache disabled", + }, + "whisper": { + "model_class": Whisper, + "input_shapes": [(1, 80, 3000)], + "description": "OpenAI Whisper ASR model. now is encoder only", + }, + "conv1d": { + "model_class": MockConv1d, + "input_shapes": [(1, 80, 3000)], + "description": "Conv1d layer with 80 input channels, 384 output channels", + }, + "transformer_block": { + "model_class": TransformerBlock, + "input_shapes": [(4, 32, 256)], # batch_size=4, seq_len=32, embed_dim=256 + "description": "Single transformer block with multi-head attention and feed-forward network", + }, +} + + +def get_model_and_inputs( + model_name: str, +) -> Tuple[torch.nn.Module, Tuple[torch.Tensor, ...]]: + """Get model and example inputs based on model name.""" + # + if model_name not in MODEL_REGISTRY: + available_models = ", ".join(MODEL_REGISTRY.keys()) + raise ValueError( + f"Unsupported model: {model_name}. Available models: {available_models}" + ) + + model_config = MODEL_REGISTRY[model_name] + model_class = model_config["model_class"] + input_shapes = model_config["input_shapes"] + device = "cpu" + + # Create model instance + model = model_class().to(device).eval() + + # Create example inputs (support multiple inputs) + example_inputs = tuple( + ( + torch.randint(0, 10000, size=shape, device=device) + if model_name == "llama31" + else torch.randn(*shape, device=device) + ) + for shape in input_shapes + ) + + return model, example_inputs + + +def export_model_to_et_aoti( + model, example_inputs, output_pte_path="aoti_model.pte", output_data_dir=None +): + """Export model through the AOTI pipeline.""" + all_one_input = tuple( + torch.ones_like(example_input) for example_input in example_inputs + ) + + label_output = model(*all_one_input) + print("label", label_output) + + # Create directory if it doesn't exist + os.makedirs("aoti_debug_data", exist_ok=True) + + # Dump label to file + with open("aoti_debug_data/label_output.txt", "w") as f: + if isinstance(label_output, tuple): + # Multiple outputs + all_elements = [] + for tensor in label_output: + if tensor.numel() > 0: + all_elements.extend(tensor.flatten().tolist()) + f.write(",".join(map(str, all_elements))) + else: + # Single output + if label_output.numel() > 0: + f.write(",".join(map(str, label_output.flatten().tolist()))) + + print(f"Starting export process...") + + print("Step 1: Converting to ATen dialect...") + with torch.nn.attention.sdpa_kernel( + [SDPBackend.MATH] # pyre-fixme[16] + ), torch.no_grad(): + # 1. torch.export: Defines the program with the ATen operator set. + aten_dialect = export(model, example_inputs, strict=False) + + # print(aten_dialect) + # exit(0) + + # 2. to_edge: Make optimizations for Edge devices + # aoti part should be decomposed by the internal torch._inductor.aot_compile + # we should preserve the lowerable part and waiting for aoti backend handle that + # Q: maybe need to turn on fallback_random? 
+ + edge_program = to_edge_transform_and_lower( + aten_dialect, partitioner=[CudaPartitioner([])] + ) + + # edge_program = to_edge(aten_dialect) + + print(edge_program.exported_program()) + + # 3. to_executorch: Convert the graph to an ExecuTorch program + print("Step 4: Converting to ExecuTorch program...") + executorch_program = edge_program.to_executorch() + print("To executorch done.") + + # 4. Save the compiled .pte program + if output_data_dir is None: + output_data_dir = os.getcwd() + + print(f"Step 5: Saving pte to {output_pte_path} and ptd to {output_data_dir}") + with open(output_pte_path, "wb") as file: + file.write(executorch_program.buffer) + + print(f"size of Named Data: {len(executorch_program._tensor_data)}") + + executorch_program.write_tensor_data_to_file(output_data_dir) + + print( + f"Export completed successfully! PTE saved to {output_pte_path} and ptd saved to {output_data_dir}" + ) + + +def export_model_to_pure_aoti(model, example_inputs): + """Export model through the AOTI pipeline.""" + all_one_input = tuple( + torch.ones_like(example_input) for example_input in example_inputs + ) + + print("label", model(*all_one_input)) + + print(f"Starting export process...") + + # 1. torch.export: Defines the program with the ATen operator set. + print("Step 1: Converting to ATen dialect...") + aten_dialect = export(model, example_inputs) + + # 2. torch._inductor.aot_compile to aoti delegate + aten_dialect_module = aten_dialect.module() + + output_path = os.path.join(os.getcwd(), "aoti.so") + + options: dict[str, Any] = { + "aot_inductor.package_constants_in_so": True, + "aot_inductor.output_path": output_path, + "aot_inductor.debug_compile": True, + "aot_inductor.repro_level": 3, + "aot_inductor.debug_intermediate_value_printer": "2", + "max_autotune": True, + "max_autotune_gemm_backends": "TRITON", + "max_autotune_conv_backends": "TRITON", + } + + so_path = torch._inductor.aot_compile(aten_dialect_module, example_inputs, options=options) # type: ignore[arg-type] + + assert so_path == output_path, f"Expected {output_path} but got {so_path}" + + check_call( + f"patchelf --remove-needed libtorch.so --remove-needed libc10.so --remove-needed libtorch_cuda.so --remove-needed libc10_cuda.so --remove-needed libtorch_cpu.so --add-needed libcudart.so {output_path}", + shell=True, + ) + + +def main(): + # Set up argument parser + parser = argparse.ArgumentParser( + description="Unified export script for AOTI backend", + formatter_class=argparse.RawDescriptionHelpFormatter, + ) + + # Add model name as positional argument + parser.add_argument( + "model_name", + help="Name of the model to export", + choices=list(MODEL_REGISTRY.keys()), + metavar="model_name", + ) + + # Add the --aoti_only flag + parser.add_argument( + "--aoti_only", + action="store_true", + help="Use export_model_to_pure_aoti instead of export_model_to_et_aoti", + ) + + # Parse arguments + args = parser.parse_args() + + # Show available models and descriptions in help + if len(sys.argv) == 1: + parser.print_help() + print(f"\nAvailable models: {', '.join(MODEL_REGISTRY.keys())}") + print("\nModel descriptions:") + for name, config in MODEL_REGISTRY.items(): + print(f" {name}: {config['description']}") + sys.exit(1) + + try: + model, example_inputs = get_model_and_inputs(args.model_name) + + # Choose export function based on --aoti_only flag + if args.aoti_only: + print("Using export_model_to_pure_aoti...") + export_model_to_pure_aoti(model, example_inputs) + else: + print("Using export_model_to_et_aoti...") + 
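+            # Default path: writes aoti_model.pte plus the delegate tensor data
+            # (.ptd) into the current working directory.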
export_model_to_et_aoti(model, example_inputs) + + except ValueError as e: + print(f"Error: {e}") + sys.exit(1) + except Exception as e: + print(f"Unexpected error: {e}") + sys.exit(1) + + +if __name__ == "__main__": + main() diff --git a/extension/llm/tokenizers b/extension/llm/tokenizers index 4ed91cc545e..f09feca1584 160000 --- a/extension/llm/tokenizers +++ b/extension/llm/tokenizers @@ -1 +1 @@ -Subproject commit 4ed91cc545e9ed7098e53747656eb7eff24eb305 +Subproject commit f09feca15849a790c05b3b7855e7c62ce26ba94b diff --git a/install_requirements.py b/install_requirements.py index cbae175e276..0e0084fe3dd 100644 --- a/install_requirements.py +++ b/install_requirements.py @@ -7,60 +7,21 @@ import argparse import os -import platform -import re import subprocess import sys +from install_utils import determine_torch_url, is_intel_mac_os, python_is_compatible -def python_is_compatible(): - # Scrape the version range from pyproject.toml, which should be in the current directory. - version_specifier = None - with open("pyproject.toml", "r") as file: - for line in file: - if line.startswith("requires-python"): - match = re.search(r'"([^"]*)"', line) - if match: - version_specifier = match.group(1) - break - - if not version_specifier: - print( - "WARNING: Skipping python version check: version range not found", - file=sys.stderr, - ) - return False - - # Install the packaging module if necessary. - try: - import packaging - except ImportError: - subprocess.run( - [sys.executable, "-m", "pip", "install", "packaging"], check=True - ) - # Compare the current python version to the range in version_specifier. Exits - # with status 1 if the version is not compatible, or with status 0 if the - # version is compatible or the logic itself fails. - try: - import packaging.specifiers - import packaging.version - - python_version = packaging.version.parse(platform.python_version()) - version_range = packaging.specifiers.SpecifierSet(version_specifier) - if python_version not in version_range: - print( - f'ERROR: ExecuTorch does not support python version {python_version}: must satisfy "{version_specifier}"', - file=sys.stderr, - ) - return False - except Exception as e: - print(f"WARNING: Skipping python version check: {e}", file=sys.stderr) - return True - - -# The pip repository that hosts nightly torch packages. -TORCH_NIGHTLY_URL = "https://download.pytorch.org/whl/nightly/cpu" +# This will be dynamically set based on CUDA availability and CUDA backend enabled/disabled. +TORCH_NIGHTLY_URL_BASE = "https://download.pytorch.org/whl/nightly" +# Supported CUDA versions - modify this to add/remove supported versions +# Format: tuple of (major, minor) version numbers +SUPPORTED_CUDA_VERSIONS = [ + (12, 6), + (12, 8), + (12, 9), +] # Since ExecuTorch often uses main-branch features of pytorch, only the nightly # pip versions will have the required features. @@ -71,7 +32,10 @@ def python_is_compatible(): # # NOTE: If you're changing, make the corresponding change in .ci/docker/ci_commit_pins/pytorch.txt # by picking the hash from the same date in https://hud.pytorch.org/hud/pytorch/pytorch/nightly/ -NIGHTLY_VERSION = "dev20250906" +# +# NOTE: If you're changing, make the corresponding supported CUDA versions in +# SUPPORTED_CUDA_VERSIONS above if needed. 
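+# NIGHTLY_VERSION is appended to the pinned wheel versions below, producing pins
+# such as "torch==2.10.0.dev20250915" and "torchvision==0.25.0.dev20250915".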
+NIGHTLY_VERSION = "dev20250915" def install_requirements(use_pytorch_nightly): @@ -84,12 +48,15 @@ def install_requirements(use_pytorch_nightly): ) sys.exit(1) + # Determine the appropriate PyTorch URL based on CUDA delegate status + torch_url = determine_torch_url(TORCH_NIGHTLY_URL_BASE, SUPPORTED_CUDA_VERSIONS) + # pip packages needed by exir. TORCH_PACKAGE = [ # Setting use_pytorch_nightly to false to test the pinned PyTorch commit. Note # that we don't need to set any version number there because they have already # been installed on CI before this step, so pip won't reinstall them - f"torch==2.9.0.{NIGHTLY_VERSION}" if use_pytorch_nightly else "torch", + f"torch==2.10.0.{NIGHTLY_VERSION}" if use_pytorch_nightly else "torch", ] # Install the requirements for core ExecuTorch package. @@ -105,7 +72,7 @@ def install_requirements(use_pytorch_nightly): "requirements-dev.txt", *TORCH_PACKAGE, "--extra-index-url", - TORCH_NIGHTLY_URL, + torch_url, ], check=True, ) @@ -147,10 +114,13 @@ def install_requirements(use_pytorch_nightly): def install_optional_example_requirements(use_pytorch_nightly): + # Determine the appropriate PyTorch URL based on CUDA delegate status + torch_url = determine_torch_url(TORCH_NIGHTLY_URL_BASE, SUPPORTED_CUDA_VERSIONS) + print("Installing torch domain libraries") DOMAIN_LIBRARIES = [ ( - f"torchvision==0.24.0.{NIGHTLY_VERSION}" + f"torchvision==0.25.0.{NIGHTLY_VERSION}" if use_pytorch_nightly else "torchvision" ), @@ -165,7 +135,7 @@ def install_optional_example_requirements(use_pytorch_nightly): "install", *DOMAIN_LIBRARIES, "--extra-index-url", - TORCH_NIGHTLY_URL, + torch_url, ], check=True, ) @@ -180,7 +150,7 @@ def install_optional_example_requirements(use_pytorch_nightly): "-r", "requirements-examples.txt", "--extra-index-url", - TORCH_NIGHTLY_URL, + torch_url, "--upgrade-strategy", "only-if-needed", ], @@ -188,17 +158,6 @@ def install_optional_example_requirements(use_pytorch_nightly): ) -# Prebuilt binaries for Intel-based macOS are no longer available on PyPI; users must compile from source. -# PyTorch stopped building macOS x86_64 binaries since version 2.3.0 (January 2024). -def is_intel_mac_os(): - # Returns True if running on Intel macOS. - return platform.system().lower() == "darwin" and platform.machine().lower() in ( - "x86", - "x86_64", - "i386", - ) - - def main(args): parser = argparse.ArgumentParser() parser.add_argument( diff --git a/install_utils.py b/install_utils.py new file mode 100644 index 00000000000..19da1b2193b --- /dev/null +++ b/install_utils.py @@ -0,0 +1,201 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# Copyright 2024-25 Arm Limited and/or its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import os +import platform +import re +import subprocess + + +def _is_cuda_enabled(): + """Check if CUDA delegate is enabled via CMAKE_ARGS environment variable.""" + cmake_args = os.environ.get("CMAKE_ARGS", "") + return "-DEXECUTORCH_BUILD_CUDA=ON" in cmake_args + + +def _cuda_version_to_pytorch_suffix(major, minor): + """ + Generate PyTorch CUDA wheel suffix from CUDA version numbers. + + Args: + major: CUDA major version (e.g., 12) + minor: CUDA minor version (e.g., 6) + + Returns: + PyTorch wheel suffix string (e.g., "cu126") + """ + return f"cu{major}{minor}" + + +def _get_cuda_version(supported_cuda_versions): + """ + Get the CUDA version installed on the system using nvcc command. 
+ Returns a tuple (major, minor). + + Args: + supported_cuda_versions: List of supported CUDA versions as tuples + + Raises: + RuntimeError: if nvcc is not found or version cannot be parsed + """ + try: + # Get CUDA version from nvcc (CUDA compiler) + nvcc_result = subprocess.run( + ["nvcc", "--version"], capture_output=True, text=True, check=True + ) + # Parse nvcc output for CUDA version + # Output contains line like "Cuda compilation tools, release 12.6, V12.6.68" + match = re.search(r"release (\d+)\.(\d+)", nvcc_result.stdout) + if match: + major, minor = int(match.group(1)), int(match.group(2)) + + # Check if the detected version is supported + if (major, minor) not in supported_cuda_versions: + available_versions = ", ".join( + [f"{maj}.{min}" for maj, min in supported_cuda_versions] + ) + raise RuntimeError( + f"Detected CUDA version {major}.{minor} is not supported. " + f"Only the following CUDA versions are supported: {available_versions}. " + f"Please install a supported CUDA version or try on CPU-only delegates." + ) + + return (major, minor) + else: + raise RuntimeError( + "CUDA delegate is enabled but could not parse CUDA version from nvcc output. " + "Please ensure CUDA is properly installed or try on CPU-only delegates." + ) + except FileNotFoundError: + raise RuntimeError( + "CUDA delegate is enabled but nvcc (CUDA compiler) is not found in PATH. " + "Please install CUDA toolkit or try on CPU-only delegates." + ) + except subprocess.CalledProcessError as e: + raise RuntimeError( + f"CUDA delegate is enabled but nvcc command failed with error: {e}. " + "Please ensure CUDA is properly installed or try on CPU-only delegates." + ) + + +def _get_pytorch_cuda_url(cuda_version, torch_nightly_url_base): + """ + Get the appropriate PyTorch CUDA URL for the given CUDA version. + + Args: + cuda_version: tuple of (major, minor) version numbers + torch_nightly_url_base: Base URL for PyTorch nightly packages + + Returns: + URL string for PyTorch CUDA packages + """ + major, minor = cuda_version + # Generate CUDA suffix (version validation is already done in _get_cuda_version) + cuda_suffix = _cuda_version_to_pytorch_suffix(major, minor) + + return f"{torch_nightly_url_base}/{cuda_suffix}" + + +# Global variable for caching torch URL +_torch_url_cache = "" + + +def determine_torch_url(torch_nightly_url_base, supported_cuda_versions): + """ + Determine the appropriate PyTorch installation URL based on CUDA availability and CMAKE_ARGS. + Uses caching to avoid redundant CUDA detection and print statements. 
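+    For example, "-DEXECUTORCH_BUILD_CUDA=ON" in CMAKE_ARGS with CUDA 12.6 installed
+    resolves to "{torch_nightly_url_base}/cu126"; without that flag the CPU index
+    "{torch_nightly_url_base}/cpu" is used.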
+ + Args: + torch_nightly_url_base: Base URL for PyTorch nightly packages + supported_cuda_versions: List of supported CUDA versions as tuples + + Returns: + URL string for PyTorch packages + """ + global _torch_url_cache + + # Return cached URL if already determined + if _torch_url_cache: + return _torch_url_cache + + # Check if CUDA delegate is enabled + if not _is_cuda_enabled(): + print("CUDA delegate not enabled, using CPU-only PyTorch") + _torch_url_cache = f"{torch_nightly_url_base}/cpu" + return _torch_url_cache + + print("CUDA delegate enabled, detecting CUDA version...") + + # Get CUDA version + cuda_version = _get_cuda_version(supported_cuda_versions) + + major, minor = cuda_version + print(f"Detected CUDA version: {major}.{minor}") + + # Get appropriate PyTorch CUDA URL + torch_url = _get_pytorch_cuda_url(cuda_version, torch_nightly_url_base) + print(f"Using PyTorch URL: {torch_url}") + + # Cache the result + _torch_url_cache = torch_url + return torch_url + + +# Prebuilt binaries for Intel-based macOS are no longer available on PyPI; users must compile from source. +# PyTorch stopped building macOS x86_64 binaries since version 2.3.0 (January 2024). +def is_intel_mac_os(): + # Returns True if running on Intel macOS. + return platform.system().lower() == "darwin" and platform.machine().lower() in ( + "x86", + "x86_64", + "i386", + ) + + +def python_is_compatible(): + # Scrape the version range from pyproject.toml, which should be in the current directory. + version_specifier = None + with open("pyproject.toml", "r") as file: + for line in file: + if line.startswith("requires-python"): + match = re.search(r'"([^"]*)"', line) + if match: + version_specifier = match.group(1) + break + + if not version_specifier: + print( + "WARNING: Skipping python version check: version range not found", + file=sys.stderr, + ) + return False + + # Install the packaging module if necessary. + try: + import packaging + except ImportError: + subprocess.run( + [sys.executable, "-m", "pip", "install", "packaging"], check=True + ) + # Compare the current python version to the range in version_specifier. Exits + # with status 1 if the version is not compatible, or with status 0 if the + # version is compatible or the logic itself fails. + try: + import packaging.specifiers + import packaging.version + + python_version = packaging.version.parse(platform.python_version()) + version_range = packaging.specifiers.SpecifierSet(version_specifier) + if python_version not in version_range: + print( + f'ERROR: ExecuTorch does not support python version {python_version}: must satisfy "{version_specifier}"', + file=sys.stderr, + ) + return False + except Exception as e: + print(f"WARNING: Skipping python version check: {e}", file=sys.stderr) + return True diff --git a/requirements-dev.txt b/requirements-dev.txt index 9df5e7b93ed..8c8f518a5ea 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -9,3 +9,4 @@ zstd # Imported by resolve_buck.py. certifi # Imported by resolve_buck.py. lintrunner==0.12.7 lintrunner-adapters==0.12.6 +patchelf diff --git a/requirements-examples.txt b/requirements-examples.txt index 0923cf8fefc..26ac1ad9279 100644 --- a/requirements-examples.txt +++ b/requirements-examples.txt @@ -4,4 +4,4 @@ datasets == 3.6.0 # 4.0.0 deprecates trust_remote_code and load scripts. 
For now timm == 1.0.7 torchsr == 1.0.4 torchtune >= 0.6.1 -transformers == 4.53.1 +transformers == 4.52.4 diff --git a/tools/cmake/executorch-config.cmake b/tools/cmake/executorch-config.cmake index 6c27e8ba616..ba9a686ccb9 100644 --- a/tools/cmake/executorch-config.cmake +++ b/tools/cmake/executorch-config.cmake @@ -53,6 +53,7 @@ set(EXECUTORCH_FOUND ON) include("${CMAKE_CURRENT_LIST_DIR}/ExecuTorchTargets.cmake") set(optional_lib_list + aoti_backend flatccrt etdump bundled_program diff --git a/tools/cmake/preset/default.cmake b/tools/cmake/preset/default.cmake index fb0dc0a4ade..fb993f7d5f0 100644 --- a/tools/cmake/preset/default.cmake +++ b/tools/cmake/preset/default.cmake @@ -160,6 +160,10 @@ define_overridable_option( OFF ) +define_overridable_option( + EXECUTORCH_BUILD_CUDA "Build the AOTI CUDA backend" BOOL OFF +) + if(EXECUTORCH_BUILD_ARM_BAREMETAL) set(_default_executorch_build_pthreadpool OFF) set(_default_executorch_build_cpuinfo OFF) @@ -317,6 +321,10 @@ check_required_options_on( EXECUTORCH_BUILD_PTHREADPOOL ) +check_required_options_on( + IF_ON EXECUTORCH_BUILD_CUDA REQUIRES EXECUTORCH_BUILD_EXTENSION_TENSOR +) + check_conflicting_options_on( IF_ON EXECUTORCH_BUILD_ARM_BAREMETAL CONFLICTS_WITH EXECUTORCH_BUILD_PTHREADPOOL EXECUTORCH_BUILD_CPUINFO
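# A minimal usage sketch (assumptions: a supported CUDA toolkit with nvcc on PATH;
# the exact commands are illustrative, not part of this patch):
#
#   export CMAKE_ARGS="-DEXECUTORCH_BUILD_CUDA=ON"   # install_utils.py keys off this flag
#   python install_requirements.py                   # resolves the matching cuXY nightly index
#   python export_aoti.py mv2                        # export through the ExecuTorch CUDA/AOTI path
#   python export_aoti.py mv2 --aoti_only            # or: plain torch._inductor.aot_compile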