
Commit d597072

Gasoonjia authored and facebook-github-bot committed
common library for et-aoti-driven operators
Summary: This diff introduces common functions for all AOTI-driven backends under ExecuTorch, such as CUDA and MPS. It contains two major function families: container functions for holding and running AOTI programs, and common shim layers for the AOTI library. Note that functions living here should be backend-agnostic; backend-specific functions should live inside each backend's own directory. Differential Revision: D83003496
1 parent: d610fdf

File tree

12 files changed: +1003 −0 lines

backends/aoti/CMakeLists.txt

Lines changed: 54 additions & 0 deletions
@@ -0,0 +1,54 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
#
# Build AOTI backend for runtime.
#
# ### Editing this file ###
#
# This file should be formatted with
# ~~~
# cmake-format -i CMakeLists.txt
# ~~~
# It should also be cmake-lint clean.
#
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)

# Source root directory for executorch.
if(NOT EXECUTORCH_ROOT)
  set(EXECUTORCH_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/../..)
endif()

# Use ExecuTorch's standard way to find PyTorch libraries for AOTI
include(${EXECUTORCH_ROOT}/tools/cmake/Utils.cmake)
find_package_torch()

# Common AOTI functionality - combines all AOTI common components
set(_aoti_common_sources aoti_model_container.cpp common_shims.cpp)
add_library(aoti_common STATIC ${_aoti_common_sources})
target_include_directories(
  aoti_common
  PUBLIC $<BUILD_INTERFACE:${EXECUTORCH_ROOT}> $<INSTALL_INTERFACE:include>
         # PyTorch AOTI headers from ExecuTorch's torch detection
         ${TORCH_INCLUDE_DIRS}
)
target_compile_options(aoti_common PUBLIC -fexceptions -frtti -fPIC)
# Ensure symbols are exported properly
target_link_options(aoti_common PUBLIC -Wl,--export-dynamic)

# Link against PyTorch libraries and standard libraries
target_link_libraries(
  aoti_common
  PUBLIC extension_tensor ${CMAKE_DL_LIBS}
         # Link PyTorch libraries for AOTI functions
         ${TORCH_LIBRARIES}
)
executorch_target_link_options_shared_lib(aoti_common)

install(
  TARGETS aoti_common
  EXPORT ExecuTorchTargets
  DESTINATION lib
)

backends/aoti/README.md

Lines changed: 28 additions & 0 deletions
@@ -0,0 +1,28 @@
# AOTI Common Library

This directory contains **common library components** for AOTI (Ahead-of-Time Inductor) driven backends in ExecuTorch, **not a standalone backend**.

## Purpose

The code in this directory provides shared functionality and utilities that are used by actual AOTI-driven backends such as:

- **CUDA backend** - Uses AOTI for GPU acceleration
- Other AOTI-powered backends

## Components

- **`common_shims.cpp/h`** - Common shim functions that bridge ExecuTorch tensor operations with AOTI requirements
- **`aoti_model_container.cpp/h`** - Model container functionality for AOTI models
- **`utils.h`** - Utility functions and type definitions
- **`tests/`** - Unit tests for the common functionality

## Usage

This library is intended to be used as a dependency by actual AOTI backend implementations. It is not a backend that can be used directly for model execution.

For example backend implementations that use this common library, see:
- `executorch/backends/cuda/` - CUDA AOTI backend

## Building

The common library components are built as part of the AOTI backend build process. See the `TARGETS` file for build configurations.

backends/aoti/TARGETS

Lines changed: 3 additions & 0 deletions
@@ -0,0 +1,3 @@
load("targets.bzl", "define_common_targets")

define_common_targets()
backends/aoti/aoti_model_container.cpp

Lines changed: 32 additions & 0 deletions
@@ -0,0 +1,32 @@
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <executorch/backends/aoti/aoti_model_container.h>

namespace executorch {
namespace backends {
namespace aoti {

extern "C" {

// Global function pointers for AOT Inductor model container operations
// These will be loaded dynamically from the shared library
AOTInductorModelContainerCreateWithDeviceFunc
    AOTInductorModelContainerCreateWithDevice = nullptr;
AOTInductorModelContainerDeleteFunc AOTInductorModelContainerDelete = nullptr;
AOTInductorModelContainerGetNumInputsFunc
    AOTInductorModelContainerGetNumInputs = nullptr;
AOTInductorModelContainerGetNumOutputsFunc
    AOTInductorModelContainerGetNumOutputs = nullptr;
AOTInductorModelContainerRunFunc AOTInductorModelContainerRun = nullptr;

} // extern "C"

} // namespace aoti
} // namespace backends
} // namespace executorch
backends/aoti/aoti_model_container.h

Lines changed: 84 additions & 0 deletions
@@ -0,0 +1,84 @@
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#pragma once

#include <executorch/extension/tensor/tensor.h>
#include <executorch/runtime/core/error.h>

namespace executorch {
namespace backends {
namespace aoti {

using executorch::runtime::Error;
using executorch::runtime::etensor::Tensor;

extern "C" {

// Type definitions
using AOTITensorHandle = Tensor*;
using AOTIRuntimeError = Error;

// Forward declarations for AOT Inductor model container
struct AOTInductorModelContainerOpaque;
using AOTInductorModelContainerHandle = AOTInductorModelContainerOpaque*;
using AOTInductorStreamHandle = void*;
using AOTIProxyExecutorHandle = void*;

// Function pointer types for AOT Inductor model container operations
using AOTInductorModelContainerCreateWithDeviceFunc = AOTIRuntimeError (*)(
    AOTInductorModelContainerHandle* container_handle,
    size_t num_models,
    const char* device_str,
    const char* cubin_dir);

using AOTInductorModelContainerDeleteFunc =
    AOTIRuntimeError (*)(AOTInductorModelContainerHandle container_handle);

using AOTInductorModelContainerGetNumInputsFunc = AOTIRuntimeError (*)(
    AOTInductorModelContainerHandle container_handle,
    size_t* num_inputs);

using AOTInductorModelContainerGetNumOutputsFunc = AOTIRuntimeError (*)(
    AOTInductorModelContainerHandle container_handle,
    size_t* num_outputs);

using AOTInductorModelContainerRunFunc = AOTIRuntimeError (*)(
    AOTInductorModelContainerHandle container_handle,
    AOTITensorHandle* input_handles, // array of input AOTITensorHandle; handles
                                     // are stolen; the array itself is borrowed
    size_t num_inputs,
    AOTITensorHandle*
        output_handles, // array for writing output AOTITensorHandle; handles
                        // will be stolen by the caller; the array itself is
                        // borrowed
    size_t n_outputs,
    AOTInductorStreamHandle stream_handle,
    AOTIProxyExecutorHandle proxy_executor_handle);

// Global function pointers (will be loaded dynamically)
extern AOTInductorModelContainerCreateWithDeviceFunc
    AOTInductorModelContainerCreateWithDevice;
extern AOTInductorModelContainerDeleteFunc AOTInductorModelContainerDelete;
extern AOTInductorModelContainerGetNumInputsFunc
    AOTInductorModelContainerGetNumInputs;
extern AOTInductorModelContainerGetNumOutputsFunc
    AOTInductorModelContainerGetNumOutputs;
extern AOTInductorModelContainerRunFunc AOTInductorModelContainerRun;

} // extern "C"

// AOTI Delegate Handle structure
struct AOTIDelegateHandle {
  void* so_handle;
  AOTInductorModelContainerHandle container_handle;
};

} // namespace aoti
} // namespace backends
} // namespace executorch

backends/aoti/common_shims.cpp

Lines changed: 160 additions & 0 deletions
@@ -0,0 +1,160 @@
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <executorch/backends/aoti/common_shims.h>
#include <executorch/runtime/platform/log.h>
#include <cstdint>

namespace executorch {
namespace backends {
namespace aoti {

namespace internal {
// Global storage for tensor metadata
std::unordered_map<Tensor*, std::vector<int64_t>> tensor_to_sizes;
std::unordered_map<Tensor*, std::vector<int64_t>> tensor_to_strides;
} // namespace internal

extern "C" {

// Autograd mode functions
int32_t aoti_torch_grad_mode_is_enabled() {
  // No autograd ever
  return false;
}

void aoti_torch_grad_mode_set_enabled(bool enabled) {
  if (enabled) {
    throw std::runtime_error("Cannot enable autograd");
  }
}

// Tensor attribute operations
AOTITorchError aoti_torch_get_data_ptr(
    AOTITensorHandle tensor,
    void** ret_data_ptr) {
  *ret_data_ptr = tensor->mutable_data_ptr();
  return Error::Ok;
}

AOTITorchError aoti_torch_get_storage_offset(
    AOTITensorHandle tensor,
    int64_t* ret_storage_offset) {
  // Storage offset is always 0 in ET
  *ret_storage_offset = 0;

  return Error::Ok;
}

AOTITorchError aoti_torch_get_strides(
    AOTITensorHandle tensor,
    int64_t** ret_strides) {
  auto it = internal::tensor_to_strides.find(tensor);
  if (it == internal::tensor_to_strides.end()) {
    std::vector<int64_t> strides(tensor->dim());
    auto tensor_strides = tensor->strides();
    for (int i = 0; i < tensor->dim(); i++) {
      strides[i] = tensor_strides[i];
    }
    it = internal::tensor_to_strides.emplace(tensor, std::move(strides)).first;
  }

  // For 0D tensors, data() returns nullptr on empty vectors, but we need to
  // return a valid pointer
  if (it->second.empty()) {
    static int64_t empty_strides_placeholder = 0;
    *ret_strides = &empty_strides_placeholder;
  } else {
    *ret_strides = it->second.data();
  }

  return Error::Ok;
}

AOTITorchError aoti_torch_get_dtype(
    AOTITensorHandle tensor,
    int32_t* ret_dtype) {
  *ret_dtype = static_cast<int32_t>(tensor->scalar_type());

  return Error::Ok;
}

AOTITorchError aoti_torch_get_sizes(
    AOTITensorHandle tensor,
    int64_t** ret_sizes) {
  auto it = internal::tensor_to_sizes.find(tensor);
  if (it == internal::tensor_to_sizes.end()) {
    std::vector<int64_t> sizes(tensor->dim());
    auto tensor_sizes = tensor->sizes();
    for (int i = 0; i < tensor->dim(); i++) {
      sizes[i] = tensor_sizes[i];
    }
    it = internal::tensor_to_sizes.emplace(tensor, std::move(sizes)).first;
  }

  // For 0D tensors, data() returns nullptr on empty vectors, but we need to
  // return a valid pointer
  if (it->second.empty()) {
    static int64_t empty_sizes_placeholder = 0;
    *ret_sizes = &empty_sizes_placeholder;
  } else {
    *ret_sizes = it->second.data();
  }

  return Error::Ok;
}

AOTITorchError aoti_torch_get_storage_size(
    AOTITensorHandle tensor,
    int64_t* ret_size) {
  return Error::NotSupported;
}

AOTITorchError aoti_torch_get_device_index(
    AOTITensorHandle tensor,
    int32_t* ret_device_index) {
  // Assume all tensors used by AOTI are on device index 0 (e.g. CUDA:0)
  *ret_device_index = 0;
  return Error::Ok;
}

AOTITorchError aoti_torch_get_dim(AOTITensorHandle tensor, int64_t* ret_dim) {
  *ret_dim = static_cast<int64_t>(tensor->dim());
  return Error::Ok;
}

// Device and layout utility functions
int32_t aoti_torch_device_type_cpu() {
  // CPU is device type 0 in ET as well
  return 0;
}

__attribute__((__visibility__("default"))) int32_t aoti_torch_layout_strided() {
  // ET only supports the strided layout; the return value is always 0, a.k.a.
  // at::Layout::Strided
  return 0;
}

// Dtype constants - these return the PyTorch dtype codes
// Currently only float32 is supported, but using robust enum-based approach
__attribute__((__visibility__("default"))) int32_t aoti_torch_dtype_float32() {
  return 6; // PyTorch's float32 dtype code
}

// Cleanup functions
void cleanup_tensor_metadata() {
  internal::tensor_to_sizes.clear();
  internal::tensor_to_strides.clear();
}

} // extern "C"

} // namespace aoti
} // namespace backends
} // namespace executorch
