diff --git a/.ci/docker/ci_commit_pins/pytorch.txt b/.ci/docker/ci_commit_pins/pytorch.txt index 8c9330d6f2c..e3a53c8bcb5 100644 --- a/.ci/docker/ci_commit_pins/pytorch.txt +++ b/.ci/docker/ci_commit_pins/pytorch.txt @@ -1 +1 @@ -4d4abec80f03cd8fdefe1d9cb3a60d3690cd777e +53a2908a10f414a2f85caa06703a26a40e873869 diff --git a/.ci/scripts/test-cuda-build.sh b/.ci/scripts/test-cuda-build.sh new file mode 100755 index 00000000000..9981eb7ec87 --- /dev/null +++ b/.ci/scripts/test-cuda-build.sh @@ -0,0 +1,103 @@ +#!/bin/bash +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +set -exu + +# Source the conda setup +bash .ci/scripts/setup-conda.sh +eval "$(conda shell.bash hook)" + +# Set up CONDA_RUN variable if not already set +# This is needed for compatibility with pytorch/test-infra workflows +export CONDA_RUN="${CONDA_RUN:-conda run --no-capture-output -p ${CONDA_PREFIX:-$HOME/miniconda3/envs/ci}}" + +CUDA_VERSION=${1:-"12.6"} + +echo "=== Testing ExecutorTorch CUDA ${CUDA_VERSION} Build ===" + +# Function to build and test ExecutorTorch with CUDA support +test_executorch_cuda_build() { + local cuda_version=$1 + + echo "Building ExecutorTorch with CUDA ${cuda_version} support..." + echo "ExecutorTorch will automatically detect CUDA and install appropriate PyTorch wheel" + + # Check available resources before starting + echo "=== System Information ===" + echo "Available memory: $(free -h | grep Mem | awk '{print $2}')" + echo "Available disk space: $(df -h . | tail -1 | awk '{print $4}')" + echo "CPU cores: $(nproc)" + echo "CUDA version check:" + nvcc --version || echo "nvcc not found" + nvidia-smi || echo "nvidia-smi not found" + + # Set CMAKE_ARGS to enable CUDA build - ExecutorTorch will handle PyTorch installation automatically + export CMAKE_ARGS="-DEXECUTORCH_BUILD_CUDA=ON" + + echo "=== Starting ExecutorTorch Installation ===" + # Install ExecutorTorch with CUDA support with timeout and error handling + timeout 5400 ./install_executorch.sh || { + local exit_code=$? 
+ echo "ERROR: install_executorch.sh failed with exit code: $exit_code" + if [ $exit_code -eq 124 ]; then + echo "ERROR: Installation timed out after 90 minutes" + fi + exit $exit_code + } + + echo "SUCCESS: ExecutorTorch CUDA build completed" + + # Verify the installation + echo "=== Verifying ExecutorTorch CUDA Installation ===" + + # Test that ExecutorTorch was built successfully + ${CONDA_RUN} python -c " +import executorch +print('SUCCESS: ExecutorTorch imported successfully') +" + + # Test CUDA availability and show details + ${CONDA_RUN} python -c " +try: + import torch + print('INFO: PyTorch version:', torch.__version__) + print('INFO: CUDA available:', torch.cuda.is_available()) + + if torch.cuda.is_available(): + print('SUCCESS: CUDA is available for ExecutorTorch') + print('INFO: CUDA version:', torch.version.cuda) + print('INFO: GPU device count:', torch.cuda.device_count()) + print('INFO: Current GPU device:', torch.cuda.current_device()) + print('INFO: GPU device name:', torch.cuda.get_device_name()) + + # Test basic CUDA tensor operation + device = torch.device('cuda') + x = torch.randn(10, 10).to(device) + y = torch.randn(10, 10).to(device) + z = torch.mm(x, y) + print('SUCCESS: CUDA tensor operation completed on device:', z.device) + print('INFO: Result tensor shape:', z.shape) + + print('SUCCESS: ExecutorTorch CUDA integration verified') + else: + print('WARNING: CUDA not detected, but ExecutorTorch built successfully') + exit(1) +except Exception as e: + print('ERROR: ExecutorTorch CUDA test failed:', e) + exit(1) +" + + echo "SUCCESS: ExecutorTorch CUDA ${cuda_version} build and verification completed successfully" +} + +# Main execution +echo "Current working directory: $(pwd)" +echo "Directory contents:" +ls -la + +# Run the CUDA build test +test_executorch_cuda_build "${CUDA_VERSION}" diff --git a/.ci/scripts/test-cuda-export-aoti.sh b/.ci/scripts/test-cuda-export-aoti.sh new file mode 100755 index 00000000000..6ea701b8f4b --- /dev/null +++ b/.ci/scripts/test-cuda-export-aoti.sh @@ -0,0 +1,105 @@ +#!/bin/bash +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +set -exu + +# shellcheck source=/dev/null +source "$(dirname "${BASH_SOURCE[0]}")/utils.sh" + +CUDA_VERSION=${1:-"12.6"} + +echo "=== Testing ExecutorTorch CUDA AOTI Export ${CUDA_VERSION} ===" + +# Function to test CUDA AOTI export functionality +test_cuda_aoti_export() { + local cuda_version=$1 + + echo "Testing CUDA AOTI export with CUDA ${cuda_version} support..." + + # Check available resources before starting + echo "=== System Information ===" + echo "Available memory: $(free -h | grep Mem | awk '{print $2}')" + echo "Available disk space: $(df -h . | tail -1 | awk '{print $4}')" + echo "CPU cores: $(nproc)" + echo "CUDA version check:" + nvcc --version || echo "nvcc not found" + nvidia-smi || echo "nvidia-smi not found" + + # Set up environment for CUDA builds + export CMAKE_ARGS="-DEXECUTORCH_BUILD_CUDA=ON" + + echo "=== Installing ExecutorTorch with CUDA support ===" + # Install ExecutorTorch with CUDA support with timeout and error handling + timeout 5400 ./install_executorch.sh || { + local exit_code=$? 
+ echo "ERROR: install_executorch.sh failed with exit code: $exit_code" + if [ $exit_code -eq 124 ]; then + echo "ERROR: Installation timed out after 90 minutes" + fi + exit $exit_code + } + + echo "SUCCESS: ExecutorTorch CUDA installation completed" + + # Verify the installation + echo "=== Verifying ExecutorTorch CUDA Installation ===" + + # Test that ExecutorTorch was built successfully + python -c " +import executorch +print('SUCCESS: ExecutorTorch imported successfully') +" + + # Test CUDA availability and show details + python -c " +try: + import torch + print('INFO: PyTorch version:', torch.__version__) + print('INFO: CUDA available:', torch.cuda.is_available()) + + if torch.cuda.is_available(): + print('SUCCESS: CUDA is available for ExecutorTorch') + print('INFO: CUDA version:', torch.version.cuda) + print('INFO: GPU device count:', torch.cuda.device_count()) + print('INFO: Current GPU device:', torch.cuda.current_device()) + print('INFO: GPU device name:', torch.cuda.get_device_name()) + + # Test basic CUDA tensor operation + device = torch.device('cuda') + x = torch.randn(10, 10).to(device) + y = torch.randn(10, 10).to(device) + z = torch.mm(x, y) + print('SUCCESS: CUDA tensor operation completed on device:', z.device) + print('INFO: Result tensor shape:', z.shape) + + print('SUCCESS: ExecutorTorch CUDA integration verified') + else: + print('WARNING: CUDA not detected, but ExecutorTorch built successfully') + exit(1) +except Exception as e: + print('ERROR: ExecutorTorch CUDA test failed:', e) + exit(1) +" + + echo "=== Running CUDA AOTI Export Tests ===" + # Run the CUDA AOTI export tests using the Python script + python .ci/scripts/test_cuda_export_aoti.py \ + --models linear conv2d add resnet18 \ + --export-mode export_aoti_only \ + --timeout 600 \ + --cleanup + + echo "SUCCESS: ExecutorTorch CUDA AOTI export ${cuda_version} tests completed successfully" +} + +# Main execution +echo "Current working directory: $(pwd)" +echo "Directory contents:" +ls -la + +# Run the CUDA AOTI export test +test_cuda_aoti_export "${CUDA_VERSION}" diff --git a/.ci/scripts/test_cuda_export_aoti.py b/.ci/scripts/test_cuda_export_aoti.py new file mode 100755 index 00000000000..3748dc5fe33 --- /dev/null +++ b/.ci/scripts/test_cuda_export_aoti.py @@ -0,0 +1,228 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Test script for CUDA AOTI export functionality. +This script tests basic CUDA export functionality for a subset of models: +linear, conv2d, add, and resnet18. 
+""" + +import argparse +import os +import subprocess +import sys +from typing import List, Optional + + +def run_command( + cmd: List[str], cwd: Optional[str] = None, timeout: int = 300 +) -> subprocess.CompletedProcess: + """Run a command with proper error handling and timeout.""" + print(f"Running command: {' '.join(cmd)}") + if cwd: + print(f"Working directory: {cwd}") + + try: + result = subprocess.run( + cmd, + cwd=cwd, + capture_output=True, + text=True, + timeout=timeout, + check=False, # We'll handle the return code ourselves + ) + + if result.stdout: + print("STDOUT:") + print(result.stdout) + if result.stderr: + print("STDERR:") + print(result.stderr) + + return result + except subprocess.TimeoutExpired as e: + print(f"ERROR: Command timed out after {timeout} seconds") + raise e + except Exception as e: + print(f"ERROR: Failed to run command: {e}") + raise e + + +def test_cuda_export( + model_name: str, export_mode: str = "export_aoti_only", timeout: int = 300 +) -> bool: + """Test CUDA export for a specific model.""" + print(f"\n{'='*60}") + print(f"Testing CUDA export for model: {model_name}") + print(f"Export mode: {export_mode}") + print(f"{'='*60}") + + try: + # Run the export using export_aoti.py + cmd = ["python", "export_aoti.py", model_name] + if export_mode == "export_aoti_only": + cmd.append("--aoti_only") + + result = run_command(cmd, timeout=timeout) + + if result.returncode == 0: + print(f"SUCCESS: {model_name} export completed successfully") + return True + else: + print( + f"ERROR: {model_name} export failed with return code {result.returncode}" + ) + return False + + except subprocess.TimeoutExpired: + print(f"ERROR: {model_name} export timed out after {timeout} seconds") + return False + except Exception as e: + print(f"ERROR: {model_name} export failed with exception: {e}") + return False + + +def cleanup_temp_files(): + """Clean up temporary files generated during export.""" + print("Cleaning up temporary files...") + + # List of file patterns to clean up + cleanup_patterns = [ + "*.cubin", + "*.pte", + "*.so", + "*kernel_metadata.json", + "*kernel.cpp", + "*wrapper_metadata.json", + "*wrapper.cpp", + "*wrapper.json", + "aoti_intermediate_output.txt", + ] + + # Remove files matching patterns + for pattern in cleanup_patterns: + try: + import glob + + files = glob.glob(pattern) + for file in files: + if os.path.isfile(file): + os.remove(file) + print(f"Removed file: {file}") + except Exception as e: + print(f"Warning: Failed to remove {pattern}: {e}") + + # Remove temporary directories created by wrappers + try: + import glob + + for wrapper_file in glob.glob("*wrapper.cpp"): + basename = wrapper_file.replace("wrapper.cpp", "") + if os.path.isdir(basename): + import shutil + + shutil.rmtree(basename) + print(f"Removed directory: {basename}") + except Exception as e: + print(f"Warning: Failed to remove wrapper directories: {e}") + + print("Cleanup completed.") + + +def main(): + """Main function to test CUDA export for specified models.""" + parser = argparse.ArgumentParser( + description="Test CUDA AOTI export functionality", + formatter_class=argparse.RawDescriptionHelpFormatter, + ) + + parser.add_argument( + "--models", + nargs="+", + default=["linear", "conv2d", "add", "resnet18"], + help="List of models to test (default: linear, conv2d, add, resnet18)", + ) + + parser.add_argument( + "--export-mode", + choices=["export_aoti_only", "full"], + default="export_aoti_only", + help="Export mode: export_aoti_only (AOTI only) or full (full pipeline)", + ) + + 
parser.add_argument( + "--timeout", + type=int, + default=300, + help="Timeout for each model export in seconds (default: 300)", + ) + + parser.add_argument( + "--cleanup", + action="store_true", + default=True, + help="Clean up temporary files after testing (default: True)", + ) + + args = parser.parse_args() + + print("CUDA AOTI Export Test") + print("=" * 60) + print(f"Models to test: {args.models}") + print(f"Export mode: {args.export_mode}") + print(f"Timeout per model: {args.timeout} seconds") + print(f"Cleanup enabled: {args.cleanup}") + print("=" * 60) + + # Check if we're in the correct directory (should have export_aoti.py) + if not os.path.exists("export_aoti.py"): + print("ERROR: export_aoti.py not found in current directory") + print("Please run this script from the executorch root directory") + sys.exit(1) + + # Test each model + successful_models = [] + failed_models = [] + + for model in args.models: + # Clean up before each test + if args.cleanup: + cleanup_temp_files() + + success = test_cuda_export(model, args.export_mode, args.timeout) + + if success: + successful_models.append(model) + else: + failed_models.append(model) + + # Final cleanup + if args.cleanup: + cleanup_temp_files() + + # Print summary + print("\n" + "=" * 60) + print("CUDA AOTI Export Test Summary") + print("=" * 60) + print(f"Total models tested: {len(args.models)}") + print(f"Successful exports: {len(successful_models)}") + print(f"Failed exports: {len(failed_models)}") + + if successful_models: + print(f"Successful models: {', '.join(successful_models)}") + + if failed_models: + print(f"Failed models: {', '.join(failed_models)}") + print("\nERROR: One or more model exports failed!") + sys.exit(1) + else: + print("\nSUCCESS: All model exports completed successfully!") + sys.exit(0) + + +if __name__ == "__main__": + main() diff --git a/.github/workflows/test-backend-cuda.yml b/.github/workflows/test-backend-cuda.yml new file mode 100644 index 00000000000..bc8063b5b73 --- /dev/null +++ b/.github/workflows/test-backend-cuda.yml @@ -0,0 +1,68 @@ +# Test ExecutorTorch CUDA AOTI Export Functionality +# This workflow tests whether ExecutorTorch can successfully export models using CUDA AOTI +# across different CUDA versions (12.6, 12.8, 12.9) for a subset of models: +# linear, conv2d, add, and resnet18 +# +# The test focuses on export-only functionality and verifies that no errors are raised +# during the AOTI export process. 
+ +name: Test CUDA AOTI Export + +on: + pull_request: + push: + branches: + - main + - release/* + +concurrency: + group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }}-${{ github.event_name == 'workflow_dispatch' }}-${{ github.event_name == 'schedule' }} + cancel-in-progress: false + +jobs: + test-cuda-aoti-export: + strategy: + fail-fast: false + matrix: + cuda-version: ["12.6", "12.8", "12.9"] + + name: test-executorch-cuda-aoti-export-${{ matrix.cuda-version }} + uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main + permissions: + id-token: write + contents: read + with: + timeout: 120 + runner: linux.g5.4xlarge.nvidia.gpu + gpu-arch-type: cuda + gpu-arch-version: ${{ matrix.cuda-version }} + submodules: recursive + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} + script: | + set -eux + + # The generic Linux job chooses to use base env, not the one setup by the image + CONDA_ENV=$(conda env list --json | jq -r ".envs | .[-1]") + if [ -n "$CONDA_ENV" ]; then + conda activate "${CONDA_ENV}" + fi + + # Test ExecutorTorch CUDA AOTI export - ExecutorTorch will automatically detect CUDA version + # and install the appropriate PyTorch wheel when CMAKE_ARGS="-DEXECUTORCH_BUILD_CUDA=ON" + PYTHON_EXECUTABLE=python bash .ci/scripts/test-cuda-export-aoti.sh "${{ matrix.cuda-version }}" + + # This job will fail if any of the CUDA AOTI export tests fail + check-all-cuda-aoti-exports: + needs: test-cuda-aoti-export + runs-on: ubuntu-latest + if: always() + steps: + - name: Check if all CUDA AOTI export tests succeeded + run: | + if [[ "${{ needs.test-cuda-aoti-export.result }}" != "success" ]]; then + echo "ERROR: One or more ExecutorTorch CUDA AOTI export tests failed!" + echo "CUDA AOTI export test results: ${{ needs.test-cuda-aoti-export.result }}" + exit 1 + else + echo "SUCCESS: All ExecutorTorch CUDA AOTI export tests (12.6, 12.8, 12.9) completed successfully!" + fi \ No newline at end of file diff --git a/.github/workflows/test-cuda-builds.yml b/.github/workflows/test-cuda-builds.yml new file mode 100644 index 00000000000..13fbd427310 --- /dev/null +++ b/.github/workflows/test-cuda-builds.yml @@ -0,0 +1,66 @@ +# Test ExecutorTorch CUDA Build Compatibility +# This workflow tests whether ExecutorTorch can be successfully built with CUDA support +# across different CUDA versions (12.6, 12.8, 12.9) using the command: +# CMAKE_ARGS="-DEXECUTORCH_BUILD_CUDA=ON" ./install_executorch.sh +# +# Note: ExecutorTorch automatically detects the system CUDA version using nvcc and +# installs the appropriate PyTorch wheel. No manual CUDA/PyTorch installation needed. 
+ +name: Test CUDA Builds + +on: + pull_request: + push: + branches: + - main + - release/* + +concurrency: + group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }}-${{ github.event_name == 'workflow_dispatch' }}-${{ github.event_name == 'schedule' }} + cancel-in-progress: false + +jobs: + test-cuda-builds: + strategy: + fail-fast: false + matrix: + cuda-version: ["12.6", "12.8", "12.9"] + + name: test-executorch-cuda-build-${{ matrix.cuda-version }} + uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main + permissions: + id-token: write + contents: read + with: + timeout: 90 + runner: linux.g5.4xlarge.nvidia.gpu + gpu-arch-type: cuda + gpu-arch-version: ${{ matrix.cuda-version }} + docker-image: nvidia/cuda:${{ matrix.cuda-version }}.0-devel-ubuntu22.04 + submodules: recursive + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} + script: | + set -eux + + # This is needed to get the prebuilt PyTorch wheel from S3 + ${CONDA_RUN} --no-capture-output pip install awscli==1.37.21 + + # Test ExecutorTorch CUDA build - ExecutorTorch will automatically detect CUDA version + # and install the appropriate PyTorch wheel when CMAKE_ARGS="-DEXECUTORCH_BUILD_CUDA=ON" + source .ci/scripts/test-cuda-build.sh "${{ matrix.cuda-version }}" + + # This job will fail if any of the CUDA versions fail + check-all-cuda-builds: + needs: test-cuda-builds + runs-on: ubuntu-latest + if: always() + steps: + - name: Check if all CUDA builds succeeded + run: | + if [[ "${{ needs.test-cuda-builds.result }}" != "success" ]]; then + echo "ERROR: One or more ExecutorTorch CUDA builds failed!" + echo "CUDA build results: ${{ needs.test-cuda-builds.result }}" + exit 1 + else + echo "SUCCESS: All ExecutorTorch CUDA builds (12.6, 12.8, 12.9) completed successfully!" + fi diff --git a/.gitignore b/.gitignore index b166f8c9512..63d0bdce964 100644 --- a/.gitignore +++ b/.gitignore @@ -34,6 +34,14 @@ tokenizer.json !test_bpe_tokenizer.bin !test_tiktoken_tokenizer.model +# AOTI temporary files +*.cubin +*kernel_metadata.json +*kernel.cpp +*wrapper_metadata.json +*wrapper.cpp +*wrapper.json + # Editor temporaries *.idea *.sw[a-z] diff --git a/backends/cuda/__init__.py b/backends/cuda/__init__.py new file mode 100644 index 00000000000..2e41cd717f6 --- /dev/null +++ b/backends/cuda/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. diff --git a/backends/cuda/cuda_backend.py b/backends/cuda/cuda_backend.py new file mode 100644 index 00000000000..99599de6b6c --- /dev/null +++ b/backends/cuda/cuda_backend.py @@ -0,0 +1,116 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +import contextlib +import copy +import os +import typing + +from subprocess import check_call +from typing import Any, Dict, final, List, Optional, Set + +import torch +from executorch.exir._serialize._named_data_store import NamedDataStore +from executorch.exir.backend.backend_details import ( + BackendDetails, + ExportedProgram, + PreprocessResult, +) +from executorch.exir.backend.compile_spec_schema import CompileSpec +from torch._inductor.codegen.cpp_wrapper_cpu import CppWrapperCpu +from torch.export.passes import move_to_device_pass + + +# exist fallback operators in et namespace; +supported_fallback_kernels: Dict[str, Any] = {} + +# required fallback kernels but not supported +missing_fallback_kernels: Set[str] = set() + + +# context manager for non-fallback guarantee +# it will raise exception when generating fallback kernels during aoti compile +@contextlib.contextmanager +def collect_unsupported_fallback_kernels(): + original_generate_c_shim_extern_kernel_call = ( + CppWrapperCpu.generate_c_shim_extern_kernel_call + ) + + def generate_c_shim_extern_kernel_call_and_collect_unsupported_kernels( + self, + kernel: str, + args: list[str], + device: str, + *, + debug_args: Optional[list[str]] = None, + ): + if kernel not in supported_fallback_kernels: + missing_fallback_kernels.add(kernel) + + original_generate_c_shim_extern_kernel_call( + self, kernel, args, device, debug_args=debug_args + ) + + CppWrapperCpu.generate_c_shim_extern_kernel_call = ( + generate_c_shim_extern_kernel_call_and_collect_unsupported_kernels + ) + try: + yield + finally: + CppWrapperCpu.generate_c_shim_extern_kernel_call = ( + original_generate_c_shim_extern_kernel_call + ) + + +@final +class CudaBackend(BackendDetails): + @staticmethod + def preprocess( + edge_program: ExportedProgram, + compile_specs: List[CompileSpec], + ) -> PreprocessResult: + + named_data_store = NamedDataStore() + + # Move the edge_program from CPU to CUDA for aoti compile + cuda_edge_program = move_to_device_pass(edge_program, "cuda") + + edge_program_module = cuda_edge_program.module() + args, kwargs = cuda_edge_program.example_inputs + + output_path = os.path.join(os.getcwd(), "aoti.so") + + options: dict[str, typing.Any] = { + "aot_inductor.embed_kernel_binary": True, + "aot_inductor.link_libtorch": False, + "aot_inductor.package_constants_in_so": True, + "aot_inductor.output_path": output_path, + "aot_inductor.debug_compile": True, + "aot_inductor.force_mmap_weights": False, + "max_autotune": True, + "max_autotune_gemm_backends": "TRITON", + "max_autotune_conv_backends": "TRITON", + } + + with collect_unsupported_fallback_kernels(): + so_path = torch._inductor.aot_compile(edge_program_module, args, kwargs, options=options) # type: ignore[arg-type] + if len(missing_fallback_kernels) > 0: + formatted_kernels = "\n - ".join(sorted(missing_fallback_kernels)) + raise RuntimeError( + f"Missing fallback kernels ({len(missing_fallback_kernels)} total):\n - {formatted_kernels}\n" + "Please add them to the AOTI backend." + ) + + with open(so_path, "rb") as f: + so_data = f.read() + + named_data_store.add_named_data("so_blob", so_data, 1, "aoti_cuda_blob") + + return PreprocessResult( + processed_bytes=b"", + debug_handle_map={}, + data_store_output=named_data_store.get_named_data_store_output(), + ) diff --git a/backends/cuda/cuda_partitioner.py b/backends/cuda/cuda_partitioner.py new file mode 100644 index 00000000000..227d13ba093 --- /dev/null +++ b/backends/cuda/cuda_partitioner.py @@ -0,0 +1,61 @@ +# Copyright (c) Meta Platforms, Inc. 
and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-unsafe + +from typing import Callable, Dict, final, List, Optional, Tuple + +import torch +from executorch.backends.cuda.cuda_backend import CudaBackend # usort: skip +from executorch.exir.backend.compile_spec_schema import CompileSpec +from executorch.exir.backend.partitioner import ( + DelegationSpec, + Partitioner, + PartitionResult, +) +from executorch.exir.backend.utils import tag_constant_data +from torch.export.exported_program import ExportedProgram + + +@final +class CudaPartitioner(Partitioner): + def __init__(self, compile_spec: List[CompileSpec]) -> None: + self.delegation_spec = DelegationSpec(CudaBackend.__name__, compile_spec) + + def partition(self, exported_program: ExportedProgram) -> PartitionResult: + """ + Fully delegate the graph to AOTInductor by tagging all nodes as a single partition. + """ + + partition_tags: Dict[str, DelegationSpec] = {} + for node in exported_program.graph.nodes: + if node.op != "call_function": + continue + tag = f"tag0" + node.meta["delegation_tag"] = tag + partition_tags[tag] = self.delegation_spec + + tag_constant_data(exported_program) + + return PartitionResult( + tagged_exported_program=exported_program, partition_tags=partition_tags + ) + + def ops_to_not_decompose( + self, ep: ExportedProgram + ) -> Tuple[List[torch._ops.OpOverload], Optional[Callable[[torch.fx.Node], bool]]]: + """ + Return a list of operations that should not be decomposed and let the AOT compiler handle them. + Currently we skip decomposing all ops and let the AOT compiler handle them. + """ + do_not_decompose = set() + + for node in ep.graph.nodes: + if node.op == "call_function" and isinstance( + node.target, torch._ops.OpOverload + ): + do_not_decompose.add(node.target) + return list(do_not_decompose), None diff --git a/exir/backend/backend_api.py b/exir/backend/backend_api.py index dd8d97d66ac..d0225437c99 100644 --- a/exir/backend/backend_api.py +++ b/exir/backend/backend_api.py @@ -268,7 +268,9 @@ def _partition_and_lower_one_graph_module( """ Partitioned and lowered the graph module based on the partition tag, this is to handle one graph module. """ - for tag, delegation_spec in partition_result.partition_tags.items(): + for idx, (tag, delegation_spec) in enumerate( + partition_result.partition_tags.items() + ): # Create partition with nodes containing this tag. There should only be # one contained submodule per tag node_list = _get_node_list_with_same_tag( @@ -311,6 +313,7 @@ def _partition_and_lower_one_graph_module( tag, call_module_node, is_submodule, + idx == 0, ) lowered_submodule = to_backend( @@ -452,7 +455,9 @@ def _create_partitions_in_graph_module( is_submodule: bool, ) -> Dict[str, List[torch.fx.Node]]: backend_id_to_submodule_name = {} - for tag, delegation_spec in partition_result.partition_tags.items(): + for idx, (tag, delegation_spec) in enumerate( + partition_result.partition_tags.items() + ): # Create partition with nodes containing this tag. 
There should only be # one contained submodule per tag node_list = _get_node_list_with_same_tag( @@ -492,6 +497,7 @@ def _create_partitions_in_graph_module( tag, call_module_node, is_submodule, + idx == 0, ) call_module_node.meta["backend_id"] = delegation_spec.backend_id call_module_node.meta["compile_spec"] = delegation_spec.compile_specs @@ -720,6 +726,8 @@ def to_backend( fake_edge_program = copy.deepcopy(edge_program) partitioner_result = partitioner_instance(fake_edge_program) tagged_exported_program = partitioner_result.tagged_exported_program + tagged_exported_program.example_inputs = edge_program.example_inputs + method_to_tagged_exported_program[method_name] = tagged_exported_program # Check that the partitioner did not modify the original graph diff --git a/exir/emit/_emit_program.py b/exir/emit/_emit_program.py index 0618871bd40..eb84d508c2c 100644 --- a/exir/emit/_emit_program.py +++ b/exir/emit/_emit_program.py @@ -176,6 +176,7 @@ def emit_program( ) emitter.run() + plans.append(emitter.plan()) debug_handle_map[name] = emitter.debug_handle_map diff --git a/exir/lowered_backend_module.py b/exir/lowered_backend_module.py index 61414990703..3c5ee5d36b0 100644 --- a/exir/lowered_backend_module.py +++ b/exir/lowered_backend_module.py @@ -682,6 +682,7 @@ def create_exported_program_from_submodule( tag: str, call_module_node: torch.fx.Node, is_submodule: bool, + is_first_partition: bool = False, ) -> Tuple[ExportedProgram, Dict[str, InputSpec], Dict[str, OutputSpec]]: """ Creates an ExportedProgram from the given submodule using the parameters and buffers @@ -720,6 +721,11 @@ def create_exported_program_from_submodule( in_spec = pytree.tree_flatten((tuple(subgraph_signature.user_inputs), {}))[1] out_spec = pytree.tree_flatten(subgraph_signature.user_outputs)[1] + # only the example inputs of first parition equals to the example inputs of the owning program + submodule_exmaple_inputs = ( + owning_program.example_inputs if is_first_partition else None + ) + return ( ExportedProgram( root=submodule, @@ -735,6 +741,7 @@ def create_exported_program_from_submodule( ), ) ], + example_inputs=submodule_exmaple_inputs, constants=subgraph_constants, verifiers=[owning_program.verifier], ), diff --git a/exir/program/_program.py b/exir/program/_program.py index a33d715ca3b..c740bbcb7b3 100644 --- a/exir/program/_program.py +++ b/exir/program/_program.py @@ -1697,6 +1697,7 @@ def to_executorch( # noqa (FLAKE8) C901 after it has been transformed to the ExecuTorch backend. """ config = config if config else ExecutorchBackendConfig() + execution_programs: Dict[str, ExportedProgram] = {} for name, program in self._edge_programs.items(): if config.do_quant_fusion_and_const_prop: diff --git a/install_requirements.py b/install_requirements.py index cbae175e276..409ed083970 100644 --- a/install_requirements.py +++ b/install_requirements.py @@ -7,60 +7,22 @@ import argparse import os -import platform -import re import subprocess import sys - -def python_is_compatible(): - # Scrape the version range from pyproject.toml, which should be in the current directory. - version_specifier = None - with open("pyproject.toml", "r") as file: - for line in file: - if line.startswith("requires-python"): - match = re.search(r'"([^"]*)"', line) - if match: - version_specifier = match.group(1) - break - - if not version_specifier: - print( - "WARNING: Skipping python version check: version range not found", - file=sys.stderr, - ) - return False - - # Install the packaging module if necessary. 
- try: - import packaging - except ImportError: - subprocess.run( - [sys.executable, "-m", "pip", "install", "packaging"], check=True - ) - # Compare the current python version to the range in version_specifier. Exits - # with status 1 if the version is not compatible, or with status 0 if the - # version is compatible or the logic itself fails. - try: - import packaging.specifiers - import packaging.version - - python_version = packaging.version.parse(platform.python_version()) - version_range = packaging.specifiers.SpecifierSet(version_specifier) - if python_version not in version_range: - print( - f'ERROR: ExecuTorch does not support python version {python_version}: must satisfy "{version_specifier}"', - file=sys.stderr, - ) - return False - except Exception as e: - print(f"WARNING: Skipping python version check: {e}", file=sys.stderr) - return True - +from install_utils import determine_torch_url, is_intel_mac_os, python_is_compatible # The pip repository that hosts nightly torch packages. -TORCH_NIGHTLY_URL = "https://download.pytorch.org/whl/nightly/cpu" +# This will be dynamically set based on CUDA availability and CUDA backend enabled/disabled. +TORCH_NIGHTLY_URL_BASE = "https://download.pytorch.org/whl/nightly" +# Supported CUDA versions - modify this to add/remove supported versions +# Format: tuple of (major, minor) version numbers +SUPPORTED_CUDA_VERSIONS = [ + (12, 6), + (12, 8), + (12, 9), +] # Since ExecuTorch often uses main-branch features of pytorch, only the nightly # pip versions will have the required features. @@ -71,7 +33,10 @@ def python_is_compatible(): # # NOTE: If you're changing, make the corresponding change in .ci/docker/ci_commit_pins/pytorch.txt # by picking the hash from the same date in https://hud.pytorch.org/hud/pytorch/pytorch/nightly/ -NIGHTLY_VERSION = "dev20250906" +# +# NOTE: If you're changing, make the corresponding supported CUDA versions in +# SUPPORTED_CUDA_VERSIONS above if needed. +NIGHTLY_VERSION = "dev20250915" def install_requirements(use_pytorch_nightly): @@ -84,12 +49,15 @@ def install_requirements(use_pytorch_nightly): ) sys.exit(1) + # Determine the appropriate PyTorch URL based on CUDA delegate status + torch_url = determine_torch_url(TORCH_NIGHTLY_URL_BASE, SUPPORTED_CUDA_VERSIONS) + # pip packages needed by exir. TORCH_PACKAGE = [ # Setting use_pytorch_nightly to false to test the pinned PyTorch commit. Note # that we don't need to set any version number there because they have already # been installed on CI before this step, so pip won't reinstall them - f"torch==2.9.0.{NIGHTLY_VERSION}" if use_pytorch_nightly else "torch", + f"torch==2.10.0.{NIGHTLY_VERSION}" if use_pytorch_nightly else "torch", ] # Install the requirements for core ExecuTorch package. 
@@ -105,7 +73,7 @@ def install_requirements(use_pytorch_nightly): "requirements-dev.txt", *TORCH_PACKAGE, "--extra-index-url", - TORCH_NIGHTLY_URL, + torch_url, ], check=True, ) @@ -147,10 +115,13 @@ def install_requirements(use_pytorch_nightly): def install_optional_example_requirements(use_pytorch_nightly): + # Determine the appropriate PyTorch URL based on CUDA delegate status + torch_url = determine_torch_url(TORCH_NIGHTLY_URL_BASE, SUPPORTED_CUDA_VERSIONS) + print("Installing torch domain libraries") DOMAIN_LIBRARIES = [ ( - f"torchvision==0.24.0.{NIGHTLY_VERSION}" + f"torchvision==0.25.0.{NIGHTLY_VERSION}" if use_pytorch_nightly else "torchvision" ), @@ -165,7 +136,7 @@ def install_optional_example_requirements(use_pytorch_nightly): "install", *DOMAIN_LIBRARIES, "--extra-index-url", - TORCH_NIGHTLY_URL, + torch_url, ], check=True, ) @@ -180,7 +151,7 @@ def install_optional_example_requirements(use_pytorch_nightly): "-r", "requirements-examples.txt", "--extra-index-url", - TORCH_NIGHTLY_URL, + torch_url, "--upgrade-strategy", "only-if-needed", ], @@ -188,17 +159,6 @@ def install_optional_example_requirements(use_pytorch_nightly): ) -# Prebuilt binaries for Intel-based macOS are no longer available on PyPI; users must compile from source. -# PyTorch stopped building macOS x86_64 binaries since version 2.3.0 (January 2024). -def is_intel_mac_os(): - # Returns True if running on Intel macOS. - return platform.system().lower() == "darwin" and platform.machine().lower() in ( - "x86", - "x86_64", - "i386", - ) - - def main(args): parser = argparse.ArgumentParser() parser.add_argument( diff --git a/install_utils.py b/install_utils.py new file mode 100644 index 00000000000..fdd6c4bd93c --- /dev/null +++ b/install_utils.py @@ -0,0 +1,202 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# Copyright 2024-25 Arm Limited and/or its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import os +import platform +import re +import subprocess +import sys + + +def _is_cuda_enabled(): + """Check if CUDA delegate is enabled via CMAKE_ARGS environment variable.""" + cmake_args = os.environ.get("CMAKE_ARGS", "") + return "-DEXECUTORCH_BUILD_CUDA=ON" in cmake_args + + +def _cuda_version_to_pytorch_suffix(major, minor): + """ + Generate PyTorch CUDA wheel suffix from CUDA version numbers. + + Args: + major: CUDA major version (e.g., 12) + minor: CUDA minor version (e.g., 6) + + Returns: + PyTorch wheel suffix string (e.g., "cu126") + """ + return f"cu{major}{minor}" + + +def _get_cuda_version(supported_cuda_versions): + """ + Get the CUDA version installed on the system using nvcc command. + Returns a tuple (major, minor). 
+ + Args: + supported_cuda_versions: List of supported CUDA versions as tuples + + Raises: + RuntimeError: if nvcc is not found or version cannot be parsed + """ + try: + # Get CUDA version from nvcc (CUDA compiler) + nvcc_result = subprocess.run( + ["nvcc", "--version"], capture_output=True, text=True, check=True + ) + # Parse nvcc output for CUDA version + # Output contains line like "Cuda compilation tools, release 12.6, V12.6.68" + match = re.search(r"release (\d+)\.(\d+)", nvcc_result.stdout) + if match: + major, minor = int(match.group(1)), int(match.group(2)) + + # Check if the detected version is supported + if (major, minor) not in supported_cuda_versions: + available_versions = ", ".join( + [f"{maj}.{min}" for maj, min in supported_cuda_versions] + ) + raise RuntimeError( + f"Detected CUDA version {major}.{minor} is not supported. " + f"Only the following CUDA versions are supported: {available_versions}. " + f"Please install a supported CUDA version or try on CPU-only delegates." + ) + + return (major, minor) + else: + raise RuntimeError( + "CUDA delegate is enabled but could not parse CUDA version from nvcc output. " + "Please ensure CUDA is properly installed or try on CPU-only delegates." + ) + except FileNotFoundError: + raise RuntimeError( + "CUDA delegate is enabled but nvcc (CUDA compiler) is not found in PATH. " + "Please install CUDA toolkit or try on CPU-only delegates." + ) + except subprocess.CalledProcessError as e: + raise RuntimeError( + f"CUDA delegate is enabled but nvcc command failed with error: {e}. " + "Please ensure CUDA is properly installed or try on CPU-only delegates." + ) + + +def _get_pytorch_cuda_url(cuda_version, torch_nightly_url_base): + """ + Get the appropriate PyTorch CUDA URL for the given CUDA version. + + Args: + cuda_version: tuple of (major, minor) version numbers + torch_nightly_url_base: Base URL for PyTorch nightly packages + + Returns: + URL string for PyTorch CUDA packages + """ + major, minor = cuda_version + # Generate CUDA suffix (version validation is already done in _get_cuda_version) + cuda_suffix = _cuda_version_to_pytorch_suffix(major, minor) + + return f"{torch_nightly_url_base}/{cuda_suffix}" + + +# Global variable for caching torch URL +_torch_url_cache = "" + + +def determine_torch_url(torch_nightly_url_base, supported_cuda_versions): + """ + Determine the appropriate PyTorch installation URL based on CUDA availability and CMAKE_ARGS. + Uses caching to avoid redundant CUDA detection and print statements. 
+ + Args: + torch_nightly_url_base: Base URL for PyTorch nightly packages + supported_cuda_versions: List of supported CUDA versions as tuples + + Returns: + URL string for PyTorch packages + """ + global _torch_url_cache + + # Return cached URL if already determined + if _torch_url_cache: + return _torch_url_cache + + # Check if CUDA delegate is enabled + if not _is_cuda_enabled(): + print("CUDA delegate not enabled, using CPU-only PyTorch") + _torch_url_cache = f"{torch_nightly_url_base}/cpu" + return _torch_url_cache + + print("CUDA delegate enabled, detecting CUDA version...") + + # Get CUDA version + cuda_version = _get_cuda_version(supported_cuda_versions) + + major, minor = cuda_version + print(f"Detected CUDA version: {major}.{minor}") + + # Get appropriate PyTorch CUDA URL + torch_url = _get_pytorch_cuda_url(cuda_version, torch_nightly_url_base) + print(f"Using PyTorch URL: {torch_url}") + + # Cache the result + _torch_url_cache = torch_url + return torch_url + + +# Prebuilt binaries for Intel-based macOS are no longer available on PyPI; users must compile from source. +# PyTorch stopped building macOS x86_64 binaries since version 2.3.0 (January 2024). +def is_intel_mac_os(): + # Returns True if running on Intel macOS. + return platform.system().lower() == "darwin" and platform.machine().lower() in ( + "x86", + "x86_64", + "i386", + ) + + +def python_is_compatible(): + # Scrape the version range from pyproject.toml, which should be in the current directory. + version_specifier = None + with open("pyproject.toml", "r") as file: + for line in file: + if line.startswith("requires-python"): + match = re.search(r'"([^"]*)"', line) + if match: + version_specifier = match.group(1) + break + + if not version_specifier: + print( + "WARNING: Skipping python version check: version range not found", + file=sys.stderr, + ) + return False + + # Install the packaging module if necessary. + try: + import packaging + except ImportError: + subprocess.run( + [sys.executable, "-m", "pip", "install", "packaging"], check=True + ) + # Compare the current python version to the range in version_specifier. Exits + # with status 1 if the version is not compatible, or with status 0 if the + # version is compatible or the logic itself fails. + try: + import packaging.specifiers + import packaging.version + + python_version = packaging.version.parse(platform.python_version()) + version_range = packaging.specifiers.SpecifierSet(version_specifier) + if python_version not in version_range: + print( + f'ERROR: ExecuTorch does not support python version {python_version}: must satisfy "{version_specifier}"', + file=sys.stderr, + ) + return False + except Exception as e: + print(f"WARNING: Skipping python version check: {e}", file=sys.stderr) + return True
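Below is a minimal end-to-end sketch, not part of this patch, of how the new CUDA backend under backends/cuda/ might be driven from Python. It assumes a CUDA-capable machine and an ExecuTorch build installed with CMAKE_ARGS="-DEXECUTORCH_BUILD_CUDA=ON" ./install_executorch.sh, as exercised by the CI scripts above. The toy model, the output file name, and the empty compile-spec list are illustrative; the CI itself goes through export_aoti.py, which is not included in this diff.

# Hypothetical usage sketch for the CudaPartitioner/CudaBackend added above.
# A CUDA device is needed at export time, since CudaBackend.preprocess moves
# the program to "cuda" and calls torch._inductor.aot_compile on it.
import torch

from executorch.backends.cuda.cuda_partitioner import CudaPartitioner
from executorch.exir import to_edge_transform_and_lower


class TinyLinear(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.linear = torch.nn.Linear(16, 16)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x)


model = TinyLinear().eval()
example_inputs = (torch.randn(1, 16),)

# torch.export records example_inputs on the ExportedProgram; the exir changes
# in this diff propagate them to the first (and here only) partition so that
# CudaBackend.preprocess can pass them to AOTInductor.
exported = torch.export.export(model, example_inputs)

# No compile specs are defined for the CUDA backend yet, so pass an empty list.
partitioner = CudaPartitioner([])

executorch_program = to_edge_transform_and_lower(
    exported, partitioner=[partitioner]
).to_executorch()

with open("tiny_linear_cuda.pte", "wb") as f:
    f.write(executorch_program.buffer)

The resulting .pte carries the AOTI shared object as named data ("so_blob" in the "aoti_cuda_blob" store) rather than in processed_bytes, per CudaBackend.preprocess; running it presumably requires a runtime built with the matching CUDA delegate, which is outside the scope of this patch.

The wheel-index selection added in install_utils.py can also be exercised directly. A small, assumption-laden check (run from the repository root, where install_utils.py lives; the result depends on the local CMAKE_ARGS and nvcc installation):

# Hypothetical check of determine_torch_url(); with CUDA 12.6 detected via
# nvcc this resolves to https://download.pytorch.org/whl/nightly/cu126, and
# without -DEXECUTORCH_BUILD_CUDA=ON it falls back to the /cpu index. An
# unsupported or missing nvcc raises RuntimeError, as implemented above.
import os

from install_utils import determine_torch_url

os.environ["CMAKE_ARGS"] = "-DEXECUTORCH_BUILD_CUDA=ON"
print(
    determine_torch_url(
        "https://download.pytorch.org/whl/nightly",
        [(12, 6), (12, 8), (12, 9)],
    )
)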