Skip to content

Commit 3b02829

Browse files
committed
refactor aoti-driven backends
1 parent 32c14b1 commit 3b02829

30 files changed

+485
-333
lines changed

CMakeLists.txt

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,6 @@
5050
cmake_minimum_required(VERSION 3.29)
5151
project(executorch)
5252

53-
5453
set(EXECUTORCH_ROOT ${CMAKE_CURRENT_SOURCE_DIR})
5554

5655
include(${PROJECT_SOURCE_DIR}/tools/cmake/common/preset.cmake)
@@ -592,9 +591,13 @@ if(EXECUTORCH_BUILD_CORTEX_M)
592591
list(APPEND _executorch_backends cortex_m_backend)
593592
endif()
594593

595-
if(EXECUTORCH_BUILD_AOTI)
594+
if(EXECUTORCH_BUILD_CUDA)
595+
# Build common AOTI functionality (required for CUDA)
596596
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backends/aoti)
597-
list(APPEND _executorch_backends aoti_backend)
597+
# Build CUDA-specific AOTI functionality
598+
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backends/aoti/cuda)
599+
# Add aoti_cuda to backends - it already depends on aoti_common
600+
list(APPEND _executorch_backends aoti_cuda)
598601
endif()
599602

600603
if(EXECUTORCH_BUILD_EXTENSION_APPLE)

backends/aoti/CMakeLists.txt

Lines changed: 17 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -21,48 +21,34 @@ if(NOT EXECUTORCH_ROOT)
2121
set(EXECUTORCH_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/../..)
2222
endif()
2323

24-
find_package(CUDAToolkit REQUIRED)
25-
2624
# Use ExecutorTorch's standard way to find PyTorch libraries for AOTI
2725
include(${EXECUTORCH_ROOT}/tools/cmake/Utils.cmake)
2826
find_package_torch()
2927

30-
set(_aoti_sources
31-
runtime/aoti_backend.cpp
32-
runtime/aoti_model_container.cpp
33-
runtime/shims/memory.cpp
34-
runtime/shims/tensor_attribute.cpp
35-
runtime/shims/utils.cpp)
36-
add_library(aoti_backend STATIC ${_aoti_sources})
28+
# Common AOTI functionality (non-CUDA)
29+
set(_aoti_common_sources aoti_model_container.cpp common_shims.cpp utils.cpp)
30+
add_library(aoti_common STATIC ${_aoti_common_sources})
3731
target_include_directories(
38-
aoti_backend
39-
PUBLIC
40-
${CUDAToolkit_INCLUDE_DIRS}
41-
$<BUILD_INTERFACE:${EXECUTORCH_ROOT}>
42-
$<INSTALL_INTERFACE:include>
43-
# PyTorch AOTI headers from ExecutorTorch's torch detection
44-
${TORCH_INCLUDE_DIRS}
32+
aoti_common
33+
PUBLIC $<BUILD_INTERFACE:${EXECUTORCH_ROOT}> $<INSTALL_INTERFACE:include>
34+
# PyTorch AOTI headers from ExecutorTorch's torch detection
35+
${TORCH_INCLUDE_DIRS}
4536
)
46-
target_compile_options(aoti_backend PUBLIC -fexceptions -frtti -fPIC)
37+
target_compile_options(aoti_common PUBLIC -fexceptions -frtti -fPIC)
4738
# Ensure symbols are exported properly
48-
target_link_options(aoti_backend PUBLIC -Wl,--export-dynamic)
39+
target_link_options(aoti_common PUBLIC -Wl,--export-dynamic)
4940

50-
# Link against CUDA::cudart, PyTorch libraries and standard libraries
41+
# Link against PyTorch libraries and standard libraries
5142
target_link_libraries(
52-
aoti_backend
53-
PUBLIC
54-
extension_tensor
55-
CUDA::cudart
56-
${CMAKE_DL_LIBS}
57-
# Link PyTorch libraries for AOTI CUDA functions
58-
${TORCH_LIBRARIES}
43+
aoti_common
44+
PUBLIC extension_tensor ${CMAKE_DL_LIBS}
45+
# Link PyTorch libraries for AOTI functions
46+
${TORCH_LIBRARIES}
5947
)
60-
# If you need other CUDA libraries, link them similarly:
61-
# target_link_libraries(aoti_backend PUBLIC CUDA::cublas CUDA::cufft ...)
62-
# If you have a custom function, keep it
63-
executorch_target_link_options_shared_lib(aoti_backend)
48+
executorch_target_link_options_shared_lib(aoti_common)
49+
6450
install(
65-
TARGETS aoti_backend
51+
TARGETS aoti_common
6652
EXPORT ExecuTorchTargets
6753
DESTINATION lib
6854
)

backends/aoti/README.md

Lines changed: 0 additions & 2 deletions
This file was deleted.
File renamed without changes.

backends/aoti/runtime/aoti_model_container.h renamed to backends/aoti/aoti_model_container.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
#include <executorch/extension/tensor/tensor.h>
1212
#include <executorch/runtime/core/error.h>
13-
#include "shims/memory.h"
13+
#include "cuda/runtime/shims/memory.h"
1414

1515
namespace executorch {
1616
namespace backends {

backends/aoti/runtime/shims/tensor_attribute.cpp renamed to backends/aoti/common_shims.cpp

Lines changed: 22 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -6,20 +6,31 @@
66
* LICENSE file in the root directory of this source tree.
77
*/
88

9-
#include "tensor_attribute.h"
9+
#include "common_shims.h"
10+
#include <executorch/runtime/platform/log.h>
11+
#include <cstdint>
12+
#include <cstdio>
13+
#include <fstream>
1014
#include <iostream>
11-
#include "utils.h"
15+
#include <stdexcept>
1216

1317
namespace executorch {
1418
namespace backends {
1519
namespace aoti {
1620

21+
namespace internal {
22+
// Constants for file operations
23+
const char* const TENSOR_OUTPUT_FILENAME =
24+
"/home/gasoonjia/executorch/aoti_intermediate_output.txt";
25+
} // namespace internal
26+
1727
// Global storage for tensor metadata
1828
std::unordered_map<Tensor*, std::vector<int64_t>> tensor_to_sizes;
1929
std::unordered_map<Tensor*, std::vector<int64_t>> tensor_to_strides;
2030

2131
extern "C" {
2232

33+
// Autograd mode functions
2334
int32_t aoti_torch_grad_mode_is_enabled() {
2435
// No autograd ever
2536
return false;
@@ -31,6 +42,7 @@ void aoti_torch_grad_mode_set_enabled(bool enabled) {
3142
}
3243
}
3344

45+
// Tensor attribute operations
3446
AOTITorchError aoti_torch_get_data_ptr(
3547
AOTITensorHandle tensor,
3648
void** ret_data_ptr) {
@@ -69,12 +81,6 @@ AOTITorchError aoti_torch_get_dtype(
6981
int32_t* ret_dtype) {
7082
*ret_dtype = static_cast<int32_t>(tensor->scalar_type());
7183

72-
// ASSERTION: Only float32 tensors are supported
73-
AOTITorchError dtype_error = validate_dtype(*ret_dtype);
74-
if (dtype_error != Error::Ok) {
75-
return dtype_error;
76-
}
77-
7884
return Error::Ok;
7985
}
8086

@@ -100,13 +106,6 @@ AOTITorchError aoti_torch_get_storage_size(
100106
throw std::runtime_error("Cannot get storage size on ETensor");
101107
}
102108

103-
AOTITorchError aoti_torch_get_device_type(
104-
AOTITensorHandle tensor,
105-
int32_t* ret_device_type) {
106-
// All tensors in aoti-cuda delegate are on CUDA
107-
*ret_device_type = aoti_torch_device_type_cuda();
108-
return Error::Ok;
109-
}
110109

111110
AOTITorchError aoti_torch_get_device_index(
112111
AOTITensorHandle tensor,
@@ -121,6 +120,7 @@ AOTITorchError aoti_torch_get_dim(AOTITensorHandle tensor, int64_t* ret_dim) {
121120
return Error::Ok;
122121
}
123122

123+
// Device and layout utility functions
124124
int32_t aoti_torch_device_type_cpu() {
125125
// Let's say cpu is 0 for ET as well
126126
return 0;
@@ -132,60 +132,23 @@ __attribute__((__visibility__("default"))) int32_t aoti_torch_layout_strided() {
132132
return 0;
133133
}
134134

135-
__attribute__((__visibility__("default"))) int32_t
136-
aoti_torch_device_type_cuda() {
137-
// Let's say cuda is 1 for ET as well
138-
return 1;
139-
}
140-
141135
// Dtype constants - these return the PyTorch dtype codes
142136
// Currently only float32 is supported, but using robust enum-based approach
143137
__attribute__((__visibility__("default"))) int32_t aoti_torch_dtype_float32() {
144-
return static_cast<int32_t>(SupportedDTypes::FLOAT32);
138+
return 6; // PyTorch's float32 dtype code
145139
}
146140

147-
// Future dtype support (commented out for now):
148-
// __attribute__((__visibility__("default"))) int32_t aoti_torch_dtype_bool() {
149-
// return static_cast<int32_t>(SupportedDTypes::BOOL);
150-
// }
151-
//
152-
// __attribute__((__visibility__("default"))) int32_t aoti_torch_dtype_uint8() {
153-
// return static_cast<int32_t>(SupportedDTypes::UINT8);
154-
// }
155-
//
156-
// __attribute__((__visibility__("default"))) int32_t aoti_torch_dtype_int8() {
157-
// return static_cast<int32_t>(SupportedDTypes::INT8);
158-
// }
159-
//
160-
// __attribute__((__visibility__("default"))) int32_t aoti_torch_dtype_int16() {
161-
// return static_cast<int32_t>(SupportedDTypes::INT16);
162-
// }
163-
//
164-
// __attribute__((__visibility__("default"))) int32_t aoti_torch_dtype_int32() {
165-
// return static_cast<int32_t>(SupportedDTypes::INT32);
166-
// }
167-
//
168-
// __attribute__((__visibility__("default"))) int32_t aoti_torch_dtype_int64() {
169-
// return static_cast<int32_t>(SupportedDTypes::INT64);
170-
// }
171-
//
172-
// __attribute__((__visibility__("default"))) int32_t aoti_torch_dtype_float16() {
173-
// return static_cast<int32_t>(SupportedDTypes::FLOAT16);
174-
// }
175-
//
176-
// __attribute__((__visibility__("default"))) int32_t aoti_torch_dtype_float64() {
177-
// return static_cast<int32_t>(SupportedDTypes::FLOAT64);
178-
// }
179-
//
180-
// __attribute__((__visibility__("default"))) int32_t aoti_torch_dtype_bfloat16() {
181-
// return static_cast<int32_t>(SupportedDTypes::BFLOAT16);
182-
// }
183-
141+
// Cleanup functions
184142
void cleanup_tensor_metadata() {
185143
tensor_to_sizes.clear();
186144
tensor_to_strides.clear();
187145
}
188146

147+
void cleanup_aoti_tensor_output() {
148+
// Clean up any tensor output related resources
149+
// For now this is a no-op, but can be extended if needed
150+
}
151+
189152
} // extern "C"
190153

191154
} // namespace aoti

backends/aoti/runtime/shims/tensor_attribute.h renamed to backends/aoti/common_shims.h

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,16 +9,30 @@
99
#pragma once
1010

1111
#include <cuda_runtime.h>
12+
#include <executorch/backends/aoti/utils.h>
13+
#include <executorch/extension/tensor/tensor.h>
14+
#include <executorch/runtime/core/error.h>
15+
#include <executorch/runtime/core/exec_aten/exec_aten.h>
16+
#include <cstdint>
1217
#include <unordered_map>
1318
#include <vector>
14-
#include "types.h"
1519

1620
namespace executorch {
1721
namespace backends {
1822
namespace aoti {
1923

24+
// Common using declarations for ExecutorTorch types
25+
using executorch::runtime::Error;
26+
using executorch::runtime::etensor::Tensor;
27+
2028
extern "C" {
2129

30+
// Common AOTI type aliases
31+
// Note: AOTITensorHandle is aliased to Tensor* for ExecutorTorch compatibility
32+
using AOTITensorHandle = Tensor*;
33+
using AOTIRuntimeError = Error;
34+
using AOTITorchError = Error;
35+
2236
// Global storage for tensor metadata
2337
extern std::unordered_map<Tensor*, std::vector<int64_t>> tensor_to_sizes;
2438
extern std::unordered_map<Tensor*, std::vector<int64_t>> tensor_to_strides;
@@ -48,10 +62,6 @@ AOTITorchError aoti_torch_get_storage_size(
4862
AOTITensorHandle tensor,
4963
int64_t* ret_size);
5064

51-
AOTITorchError aoti_torch_get_device_type(
52-
AOTITensorHandle tensor,
53-
int32_t* ret_device_type);
54-
5565
AOTITorchError aoti_torch_get_device_index(
5666
AOTITensorHandle tensor,
5767
int32_t* ret_device_index);
@@ -60,16 +70,16 @@ AOTITorchError aoti_torch_get_dim(AOTITensorHandle tensor, int64_t* ret_dim);
6070

6171
// Utility functions for device and layout information
6272
int32_t aoti_torch_device_type_cpu();
63-
int32_t aoti_torch_device_type_cuda();
6473
int32_t aoti_torch_layout_strided();
6574
int32_t aoti_torch_dtype_float32();
6675

6776
// Autograd mode functions
6877
int32_t aoti_torch_grad_mode_is_enabled();
6978
void aoti_torch_grad_mode_set_enabled(bool enabled);
7079

71-
// Cleanup function for clearing global state
80+
// Cleanup functions for clearing global state
7281
void cleanup_tensor_metadata();
82+
void cleanup_aoti_tensor_output();
7383

7484
} // extern "C"
7585

backends/aoti/cuda/CMakeLists.txt

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
#
7+
# Build AOTI CUDA backend for runtime.
8+
#
9+
# ### Editing this file ###
10+
#
11+
# This file should be formatted with
12+
# ~~~
13+
# cmake-format -i CMakeLists.txt
14+
# ~~~
15+
# It should also be cmake-lint clean.
16+
#
17+
18+
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
19+
20+
# Source root directory for executorch.
21+
if(NOT EXECUTORCH_ROOT)
22+
set(EXECUTORCH_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/../../..)
23+
endif()
24+
25+
find_package(CUDAToolkit REQUIRED)
26+
27+
# Use ExecutorTorch's standard way to find PyTorch libraries for AOTI
28+
include(${EXECUTORCH_ROOT}/tools/cmake/Utils.cmake)
29+
find_package_torch()
30+
31+
# CUDA-specific AOTI functionality
32+
set(_aoti_cuda_sources
33+
runtime/cuda_backend.cpp
34+
runtime/shims/memory.cpp
35+
runtime/shims/tensor_attribute.cpp
36+
runtime/utils.cpp)
37+
add_library(aoti_cuda STATIC ${_aoti_cuda_sources})
38+
target_include_directories(
39+
aoti_cuda
40+
PUBLIC
41+
${CUDAToolkit_INCLUDE_DIRS}
42+
$<BUILD_INTERFACE:${EXECUTORCH_ROOT}>
43+
$<INSTALL_INTERFACE:include>
44+
# PyTorch AOTI headers from ExecutorTorch's torch detection
45+
${TORCH_INCLUDE_DIRS}
46+
)
47+
target_compile_options(aoti_cuda PUBLIC -fexceptions -frtti -fPIC)
48+
# Ensure symbols are exported properly
49+
target_link_options(aoti_cuda PUBLIC -Wl,--export-dynamic)
50+
51+
# Link against CUDA::cudart, common AOTI library, and PyTorch CUDA libraries
52+
target_link_libraries(
53+
aoti_cuda
54+
PUBLIC
55+
aoti_common
56+
CUDA::cudart
57+
${CMAKE_DL_LIBS}
58+
# Link PyTorch libraries for AOTI CUDA functions
59+
${TORCH_LIBRARIES}
60+
)
61+
# If you need other CUDA libraries, link them similarly:
62+
# target_link_libraries(aoti_cuda PUBLIC CUDA::cublas CUDA::cufft ...)
63+
executorch_target_link_options_shared_lib(aoti_cuda)
64+
65+
66+
install(
67+
TARGETS aoti_cuda
68+
EXPORT ExecuTorchTargets
69+
DESTINATION lib
70+
)

backends/aoti/cuda/TARGETS

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
load("targets.bzl", "define_common_targets")
2+
3+
define_common_targets()

0 commit comments

Comments (0)