Skip to content

Commit fabbda6

Browse files
authored
common library for et-aoti-driven operators
Differential Revision: D83003496 Pull Request resolved: #14492
1 parent 4affee3 commit fabbda6

File tree

12 files changed

+973
-0
lines changed

12 files changed

+973
-0
lines changed

backends/aoti/CMakeLists.txt

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
#
# Build AOTI backend for runtime.
#
# ### Editing this file ###
#
# This file should be formatted with
# ~~~
# cmake-format -i CMakeLists.txt
# ~~~
# It should also be cmake-lint clean.
#
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)

# Locate the executorch source tree when the parent build did not provide it.
if(NOT EXECUTORCH_ROOT)
  set(EXECUTORCH_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/../..)
endif()

# ExecuTorch's standard helper for locating the PyTorch installation, which
# supplies the AOTI headers and libraries referenced below.
include(${EXECUTORCH_ROOT}/tools/cmake/Utils.cmake)
find_package_torch()

# Static library bundling all AOTI common components.
set(_aoti_common_sources aoti_model_container.cpp common_shims.cpp)
add_library(aoti_common STATIC ${_aoti_common_sources})
target_include_directories(
  aoti_common
  PUBLIC $<BUILD_INTERFACE:${EXECUTORCH_ROOT}> $<INSTALL_INTERFACE:include>
         # PyTorch AOTI headers from ExecuTorch's torch detection
         ${TORCH_INCLUDE_DIRS}
)
target_compile_options(aoti_common PUBLIC -fexceptions -frtti -fPIC)
# Export symbols so AOTI shims can be resolved at runtime.
# NOTE(review): --export-dynamic is a GNU-ld flag; confirm intended toolchains
# before building this target on non-GNU linkers (e.g. macOS ld64).
target_link_options(aoti_common PUBLIC -Wl,--export-dynamic)

# Link against PyTorch libraries and standard libraries.
target_link_libraries(
  aoti_common
  PUBLIC extension_tensor ${CMAKE_DL_LIBS}
         # PyTorch libraries providing the AOTI runtime symbols
         ${TORCH_LIBRARIES}
)
executorch_target_link_options_shared_lib(aoti_common)

install(
  TARGETS aoti_common
  EXPORT ExecuTorchTargets
  DESTINATION lib
)

backends/aoti/README.md

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
# AOTI Common Library
2+
3+
This directory contains **common library components** for AOTI (Ahead-of-Time Inductor) driven backends in ExecuTorch, **not a standalone backend**.
4+
5+
## Purpose
6+
7+
The code in this directory provides shared functionality and utilities that are used by actual AOTI-driven backends such as:
8+
9+
- **CUDA backend** - Uses AOTI for GPU acceleration
10+
- Other AOTI-powered backends
11+
12+
## Components
13+
14+
- **`common_shims.cpp/h`** - Common shim functions that bridge ExecuTorch tensor operations with AOTI requirements
15+
- **`aoti_model_container.cpp/h`** - Model container functionality for AOTI models
16+
- **`utils.h`** - Utility functions and type definitions
17+
- **`tests/`** - Unit tests for the common functionality
18+
19+
## Usage
20+
21+
This library is intended to be used as a dependency by actual AOTI backend implementations. It is not a backend that can be used directly for model execution.
22+
23+
For example backend implementations that use this common library, see:
24+
- `executorch/backends/cuda/` - CUDA AOTI backend
25+
26+
## Building
27+
28+
The common library components are built as part of the AOTI backend build process. See the `TARGETS` file for build configurations.

backends/aoti/TARGETS

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
# Buck build definitions for the AOTI common library.
# Load labels must be package-relative (":file.bzl") or fully qualified
# ("//pkg:file.bzl"); a bare filename is not a valid label.
load(":targets.bzl", "define_common_targets")

define_common_targets()
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <executorch/backends/aoti/aoti_model_container.h>

namespace executorch {
namespace backends {
namespace aoti {

extern "C" {

// Definitions for the global AOT Inductor model-container entry points
// declared in aoti_model_container.h. Each pointer starts out null and is
// expected to be populated at runtime by loading the corresponding symbol
// from the model's shared library.
AOTInductorModelContainerCreateWithDeviceFunc
    AOTInductorModelContainerCreateWithDevice = nullptr;
AOTInductorModelContainerDeleteFunc AOTInductorModelContainerDelete = nullptr;
AOTInductorModelContainerGetNumInputsFunc
    AOTInductorModelContainerGetNumInputs = nullptr;
AOTInductorModelContainerGetNumOutputsFunc
    AOTInductorModelContainerGetNumOutputs = nullptr;
AOTInductorModelContainerRunFunc AOTInductorModelContainerRun = nullptr;

} // extern "C"

} // namespace aoti
} // namespace backends
} // namespace executorch
Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#pragma once

#include <executorch/runtime/core/error.h>
#include <executorch/runtime/core/evalue.h>

namespace executorch {
namespace backends {
namespace aoti {

using executorch::runtime::Error;
using executorch::runtime::etensor::Tensor;

extern "C" {

// AOTI calls report failure through ExecuTorch's Error enum.
using AOTIRuntimeError = Error;

// Opaque handle types for the AOT Inductor model container ABI. The
// container type is only ever used through a pointer; stream and proxy
// executor handles are untyped at this layer.
struct AOTInductorModelContainerOpaque;
using AOTInductorModelContainerHandle = AOTInductorModelContainerOpaque*;
using AOTInductorStreamHandle = void*;
using AOTIProxyExecutorHandle = void*;

// Signatures of the container entry points exported by an AOTI-compiled
// model library. Creates a container holding `num_models` models for the
// device described by `device_str`; `cubin_dir` locates precompiled kernels.
using AOTInductorModelContainerCreateWithDeviceFunc = AOTIRuntimeError (*)(
    AOTInductorModelContainerHandle* container_handle,
    size_t num_models,
    const char* device_str,
    const char* cubin_dir);

// Destroys a container previously created by the function above.
using AOTInductorModelContainerDeleteFunc =
    AOTIRuntimeError (*)(AOTInductorModelContainerHandle container_handle);

// Queries the number of model inputs.
using AOTInductorModelContainerGetNumInputsFunc = AOTIRuntimeError (*)(
    AOTInductorModelContainerHandle container_handle,
    size_t* num_inputs);

// Queries the number of model outputs.
using AOTInductorModelContainerGetNumOutputsFunc = AOTIRuntimeError (*)(
    AOTInductorModelContainerHandle container_handle,
    size_t* num_outputs);

// Executes the model. Ownership notes mirror the AOTI contract: the element
// handles are stolen (inputs by the callee, outputs by the caller), while
// both arrays themselves are merely borrowed.
using AOTInductorModelContainerRunFunc = AOTIRuntimeError (*)(
    AOTInductorModelContainerHandle container_handle,
    Tensor** input_handles, // array of input Tensor*; handles
                            // are stolen; the array itself is borrowed
    size_t num_inputs,
    Tensor** output_handles, // array for writing output Tensor*; handles
                             // will be stolen by the caller; the array itself
                             // is borrowed
    size_t n_outputs,
    AOTInductorStreamHandle stream_handle,
    AOTIProxyExecutorHandle proxy_executor_handle);

// Global function pointers, defined in aoti_model_container.cpp and loaded
// dynamically from the model's shared library at runtime.
extern AOTInductorModelContainerCreateWithDeviceFunc
    AOTInductorModelContainerCreateWithDevice;
extern AOTInductorModelContainerDeleteFunc AOTInductorModelContainerDelete;
extern AOTInductorModelContainerGetNumInputsFunc
    AOTInductorModelContainerGetNumInputs;
extern AOTInductorModelContainerGetNumOutputsFunc
    AOTInductorModelContainerGetNumOutputs;
extern AOTInductorModelContainerRunFunc AOTInductorModelContainerRun;

} // extern "C"

// Per-delegate state: the dlopen'd model library and the container created
// from it.
struct AOTIDelegateHandle {
  void* so_handle;
  AOTInductorModelContainerHandle container_handle;
};

} // namespace aoti
} // namespace backends
} // namespace executorch

backends/aoti/common_shims.cpp

Lines changed: 145 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,145 @@
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <executorch/backends/aoti/common_shims.h>
#include <executorch/runtime/platform/log.h>
#include <cstdint>
#include <stdexcept> // std::runtime_error (was relied on transitively)
#include <unordered_map>
#include <utility>
#include <vector>

namespace executorch {
namespace backends {
namespace aoti {

namespace internal {
// Lazily-populated caches mapping a tensor to int64_t copies of its sizes and
// strides. AOTI-generated code expects stable `int64_t*` arrays, so the
// widened copies are materialized once per tensor and kept here.
// NOTE(review): entries are keyed by raw Tensor* and never invalidated on
// resize/destruction; callers are expected to call cleanup_tensor_metadata()
// between executions — confirm against the backend's lifecycle.
std::unordered_map<Tensor*, std::vector<int64_t>> tensor_to_sizes;
std::unordered_map<Tensor*, std::vector<int64_t>> tensor_to_strides;
} // namespace internal

namespace {
// Returns a stable int64_t* for `src` (the tensor's sizes or strides),
// widening and caching it in `cache` on first use. For 0-D tensors an empty
// vector's data() may be nullptr, so a static zero placeholder is returned
// instead to keep the pointer valid.
template <typename SrcArray>
int64_t* cached_int64_array(
    std::unordered_map<Tensor*, std::vector<int64_t>>& cache,
    Tensor* tensor,
    const SrcArray& src) {
  auto it = cache.find(tensor);
  if (it == cache.end()) {
    std::vector<int64_t> widened(tensor->dim());
    for (int i = 0; i < tensor->dim(); i++) {
      widened[i] = src[i];
    }
    it = cache.emplace(tensor, std::move(widened)).first;
  }
  if (it->second.empty()) {
    static int64_t zero_dim_placeholder = 0;
    return &zero_dim_placeholder;
  }
  return it->second.data();
}
} // namespace

extern "C" {

// Autograd mode functions

// Autograd is never available in the ExecuTorch runtime.
int32_t aoti_torch_grad_mode_is_enabled() {
  return false;
}

// Rejects any attempt to turn autograd on; disabling is a no-op.
void aoti_torch_grad_mode_set_enabled(bool enabled) {
  if (enabled) {
    throw std::runtime_error("Cannot enable autograd");
  }
}

// Tensor attribute operations

// Writes the tensor's mutable data pointer to *ret_data_ptr.
AOTITorchError aoti_torch_get_data_ptr(Tensor* tensor, void** ret_data_ptr) {
  *ret_data_ptr = tensor->mutable_data_ptr();
  return Error::Ok;
}

// Storage offset is always 0 in ET, so the tensor argument is unused.
AOTITorchError aoti_torch_get_storage_offset(
    Tensor* tensor,
    int64_t* ret_storage_offset) {
  (void)tensor;
  *ret_storage_offset = 0;
  return Error::Ok;
}

// Writes a pointer to a cached int64_t stride array for `tensor`; the
// pointer stays valid until cleanup_tensor_metadata() is called.
AOTITorchError aoti_torch_get_strides(Tensor* tensor, int64_t** ret_strides) {
  *ret_strides =
      cached_int64_array(internal::tensor_to_strides, tensor, tensor->strides());
  return Error::Ok;
}

// Writes the tensor's scalar type as a PyTorch dtype code.
AOTITorchError aoti_torch_get_dtype(Tensor* tensor, int32_t* ret_dtype) {
  *ret_dtype = static_cast<int32_t>(tensor->scalar_type());
  return Error::Ok;
}

// Writes a pointer to a cached int64_t size array for `tensor`; the pointer
// stays valid until cleanup_tensor_metadata() is called.
AOTITorchError aoti_torch_get_sizes(Tensor* tensor, int64_t** ret_sizes) {
  *ret_sizes =
      cached_int64_array(internal::tensor_to_sizes, tensor, tensor->sizes());
  return Error::Ok;
}

// Assumes all tensors used by AOTI live on device index 0 (e.g. CUDA:0), so
// the tensor argument is unused.
AOTITorchError aoti_torch_get_device_index(
    Tensor* tensor,
    int32_t* ret_device_index) {
  (void)tensor;
  *ret_device_index = 0;
  return Error::Ok;
}

// Writes the tensor's rank (number of dimensions).
AOTITorchError aoti_torch_get_dim(Tensor* tensor, int64_t* ret_dim) {
  *ret_dim = static_cast<int64_t>(tensor->dim());
  return Error::Ok;
}

// Device and layout utility functions

// CPU device type code; 0 for ET as well.
int32_t aoti_torch_device_type_cpu() {
  return 0;
}

// ET only supports strided layout, so this is always 0 (at::Layout::Strided).
int32_t aoti_torch_layout_strided() {
  return 0;
}

// Dtype constants - these return the PyTorch dtype codes.
// Currently only float32 is supported.
int32_t aoti_torch_dtype_float32() {
  return 6; // PyTorch's float32 dtype code
}

// Cleanup functions

// Drops all cached size/stride arrays; any pointers previously returned by
// aoti_torch_get_sizes/strides become invalid.
void cleanup_tensor_metadata() {
  internal::tensor_to_sizes.clear();
  internal::tensor_to_strides.clear();
}

} // extern "C"

} // namespace aoti
} // namespace backends
} // namespace executorch

0 commit comments

Comments
 (0)