
Commit 9976b77

Cherry-pick LibTorch Stable ABI documentation (pytorch#167112 pytorch#166661 pytorch#163899) (pytorch#167323)
* [BE] Refresh documentation for stable ABI / API (pytorch#163899)
  Pull Request resolved: pytorch#163899
  Approved by: https://github.com/janeyx99
* Document LibTorch ABI more, add README to headeronly (pytorch#166661)
  Pull Request resolved: pytorch#166661
  Approved by: https://github.com/mikaylagawarecki, https://github.com/albanD
* Add guidance on how to migrate kernels to the libtorch stable ABI (pytorch#167112)
  Pull Request resolved: pytorch#167112
  Approved by: https://github.com/janeyx99

Co-authored-by: Jane Xu <[email protected]>
1 parent e6bcbbe commit 9976b77

2 files changed: 207 additions & 9 deletions


docs/source/notes/libtorch_stable_abi.md

Lines changed: 200 additions & 9 deletions
@@ -1,6 +1,173 @@
# LibTorch Stable ABI

## Overview
The LibTorch Stable ABI (Application Binary Interface) provides a limited interface for extending PyTorch functionality without being tightly coupled to specific PyTorch versions. This enables the development of custom operators and extensions that remain compatible across PyTorch releases. This limited set of APIs is not intended to replace existing LibTorch, but rather to provide a stable foundation for a majority of custom extension use cases. If there is any API you would like to see added to the stable ABI, please file a request through a [new issue on the PyTorch repo](https://github.com/pytorch/pytorch/issues).
The limited stable ABI consists of three main components:
1. **Stable C headers** - Low-level C API implemented by libtorch (primarily `torch/csrc/inductor/aoti_torch/c/shim.h`)
2. **Header-only C++ library** - Standalone utilities implemented entirely in headers, with no dependence on libtorch (`torch/headeronly/*`)
3. **Stable C++ wrappers** - High-level C++ convenience wrappers (`torch/csrc/stable/*`)
We discuss each of these in detail below.
### `torch/headeronly`
The inlined C++ headers living in [`torch/headeronly`](https://github.com/pytorch/pytorch/tree/main/torch/headeronly) are completely decoupled from LibTorch. The headers consist of certain utilities that might be familiar to custom extension writers. For example, the `c10::ScalarType` enum lives here as `torch::headeronly::ScalarType`, and a libtorch-independent version of `TORCH_CHECK` is available as `STD_TORCH_CHECK`. You can trust all APIs in the `torch::headeronly` namespace to not depend on `libtorch.so`. These APIs are also globally listed in [torch/header_only_apis.txt](https://github.com/pytorch/pytorch/blob/main/torch/header_only_apis.txt).
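
For instance, a header-only helper might look like the following minimal sketch (the helper name `is_float_dtype` is hypothetical; the includes mirror the ones used in the migration example later in this note):

```cpp
#include <torch/headeronly/core/ScalarType.h>
#include <torch/headeronly/macros/Macros.h>

// Everything below compiles without linking against libtorch.so.
bool is_float_dtype(torch::headeronly::ScalarType dtype) {
  // STD_TORCH_CHECK throws std::runtime_error rather than relying on c10::Error.
  STD_TORCH_CHECK(dtype != torch::headeronly::ScalarType::Undefined, "dtype must be defined");
  return dtype == torch::headeronly::kFloat;
}
```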
### `torch/csrc/stable`
This is a set of inlined C++ headers that provide wrappers around the C API, handling the rough edges discussed below.
It consists of:
- `torch/csrc/stable/library.h`: Provides a stable version of `TORCH_LIBRARY` and similar macros.
- `torch/csrc/stable/tensor_struct.h`: Provides `torch::stable::Tensor`, a stable version of `at::Tensor`.
- `torch/csrc/stable/ops.h`: Provides a stable interface for calling ATen ops from `native_functions.yaml`.
- `torch/csrc/stable/accelerator.h`: Provides a stable interface for device-generic objects and APIs (e.g. `getCurrentStream`, `DeviceGuard`).
We are continuing to improve coverage in our `torch/csrc/stable` APIs. Please file an issue if you'd like to see support for particular APIs in your custom extension.
### Stable C headers
The stable C headers started by AOTInductor form the foundation of the stable ABI. Presently, the available C headers include:
- [torch/csrc/inductor/aoti_torch/c/shim.h](https://github.com/pytorch/pytorch/blob/main/torch/csrc/inductor/aoti_torch/c/shim.h): Includes C-style shim APIs for commonly used functionality regarding Tensors, dtypes, CUDA, and the like.
- [torch/csrc/inductor/aoti_torch/generated/c_shim_aten.h](https://github.com/pytorch/pytorch/blob/main/torch/csrc/inductor/aoti_torch/generated/c_shim_aten.h): Includes C-style shim APIs for ATen ops from `native_functions.yaml` (e.g. `aoti_torch_aten_new_empty`).
- [torch/csrc/inductor/aoti_torch/generated/c_shim_*.h](https://github.com/pytorch/pytorch/blob/main/torch/csrc/inductor/aoti_torch/generated): Includes C-style shim APIs for specific backend kernels dispatched from `native_functions.yaml` (e.g. `aoti_torch_cuda_pad`). These APIs should only be used for the specific backend they are named after (e.g. `aoti_torch_cuda_pad` should only be used within CUDA kernels), as they opt out of the dispatcher.
- [torch/csrc/stable/c/shim.h](https://github.com/pytorch/pytorch/blob/main/torch/csrc/stable/c/shim.h): We are building out more ABIs to logically live in `torch/csrc/stable/c` instead of continuing the AOTI naming that no longer makes sense for our general use case.
These headers are promised to be ABI stable across releases and adhere to a stronger backwards compatibility policy than LibTorch. Specifically, we promise not to modify them for at least 2 years after they are released. However, this is **use at your own risk**. For example, users must handle the memory lifecycle of objects returned by certain APIs. Further, the stack-based APIs discussed below which allow the user to call into the PyTorch dispatcher do not provide strong guarantees on forward and backward compatibility of the underlying op that is called.
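
To make the "use at your own risk" point concrete, below is a minimal sketch of the calling convention these headers follow, assuming the usual shim pattern of returning an `AOTITorchError` with `AOTI_TORCH_SUCCESS` on success (the helper `release_tensor` is hypothetical; `aoti_torch_delete_tensor_object` is one of the APIs in `shim.h`):

```cpp
#include <torch/csrc/inductor/aoti_torch/c/shim.h>
#include <torch/headeronly/macros/Macros.h>

// The caller, not libtorch, is responsible for freeing handles it owns.
void release_tensor(AtenTensorHandle handle) {
  AOTITorchError err = aoti_torch_delete_tensor_object(handle);
  // Every shim call returns an error code that the caller must check.
  STD_TORCH_CHECK(err == AOTI_TORCH_SUCCESS, "aoti_torch_delete_tensor_object failed");
}
```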
Unless you absolutely need the C API, we recommend the high-level C++ API in `torch/csrc/stable`, which handles the rough edges of the C API for the user.
## Migrating your kernel to the LibTorch stable ABI
If you'd like your kernel to be ABI stable with LibTorch, meaning you'd like the ability to build for one version and run on another, your kernel must only use the limited stable ABI. The following section goes through some steps of migrating an existing kernel and the APIs we imagine you would need to swap over.
Firstly, instead of registering kernels through `TORCH_LIBRARY`, LibTorch ABI stable kernels must be registered via `STABLE_TORCH_LIBRARY`. Note that, for the time being, implementations registered via `STABLE_TORCH_LIBRARY` must be boxed, unlike `TORCH_LIBRARY`. See the simple example below or our docs on [Stack-based APIs](stack-based-apis) for more details. For kernels that are registered via `pybind`, it would be useful to first migrate them to `TORCH_LIBRARY` before moving to the stable ABI.
While previously your kernels might have included APIs from `<torch/*.h>` (for example, `<torch/all.h>`), they are now limited to including from the 3 categories of headers mentioned above (`torch/csrc/stable/*.h`, `torch/headeronly/*.h` and the stable C headers). This means that your extension should no longer use any utilities from the `at::` or `c10::` namespaces but instead use their replacements in `torch::stable` and `torch::headeronly`. To provide a couple of examples of the necessary migrations:
- all uses of `at::Tensor` must be replaced with `torch::stable::Tensor`
- all uses of `TORCH_CHECK` must be replaced with `STD_TORCH_CHECK`
- all uses of `at::kCUDA` must be replaced with `torch::headeronly::kCUDA` etc.
- native functions such as `at::pad` must be replaced with `torch::stable::pad`
- native functions that are called as Tensor methods (e.g., `Tensor.pad`) must be replaced with the ATen variant through `torch::stable::pad`.
As mentioned above, the LibTorch stable ABI is still under development. If there is any API or feature you would like to see added to the stable ABI/`torch::headeronly`/`torch::stable`, please file a request through a [new issue on the PyTorch repo](https://github.com/pytorch/pytorch/issues).
Below is a simple example of migrating an existing kernel that uses `TORCH_LIBRARY` to the stable ABI (`STABLE_TORCH_LIBRARY`). For a larger end-to-end example, you can take a look at the FA3 repository, specifically the diff between [`flash_api.cpp`](https://github.com/Dao-AILab/flash-attention/blob/ad70a007e6287d4f7e766f94bcf2f9a813f20f6b/hopper/flash_api.cpp#L1) and the stable variant [`flash_api_stable.cpp`](https://github.com/Dao-AILab/flash-attention/blob/ad70a007e6287d4f7e766f94bcf2f9a813f20f6b/hopper/flash_api_stable.cpp#L1).
### Original Version with `TORCH_LIBRARY`
```cpp
// original_kernel.cpp - Using TORCH_LIBRARY (not stable ABI)
#include <torch/torch.h>
#include <ATen/ATen.h>

namespace myops {

// Simple kernel that adds a scalar value to each element of a tensor
at::Tensor add_scalar(const at::Tensor& input, double scalar) {
  TORCH_CHECK(input.scalar_type() == at::kFloat, "Input must be float32");

  return input.add(scalar);
}

// Register the operator
TORCH_LIBRARY(myops, m) {
  m.def("add_scalar(Tensor input, float scalar) -> Tensor", &add_scalar);
}

// Register the implementation
TORCH_LIBRARY_IMPL(myops, CompositeExplicitAutograd, m) {
  m.impl("add_scalar", &add_scalar);
}

} // namespace myops
```
### Migrated Version with `STABLE_TORCH_LIBRARY`
```cpp
// stable_kernel.cpp - Using STABLE_TORCH_LIBRARY (stable ABI)

// (1) Don't include <torch/torch.h> or <ATen/ATen.h>;
//     only include APIs from torch/csrc/stable, torch/headeronly and C-shims
#include <torch/csrc/stable/library.h>
#include <torch/csrc/stable/tensor_struct.h>
#include <torch/csrc/stable/ops.h>
#include <torch/csrc/stable/stableivalue_conversions.h>
#include <torch/headeronly/core/ScalarType.h>
#include <torch/headeronly/macros/Macros.h>

namespace myops {

// Simple kernel that adds a scalar value to each element of a tensor
torch::stable::Tensor add_scalar(const torch::stable::Tensor& input, double scalar) {
  // (2) use STD_TORCH_CHECK instead of TORCH_CHECK
  STD_TORCH_CHECK(
      // (3) use torch::headeronly::kFloat instead of at::kFloat
      input.scalar_type() == torch::headeronly::kFloat,
      "Input must be float32");

  // (4) Use the stable ops namespace instead of input.add
  return torch::stable::add(input, scalar);
}

// (5) Add boxed wrapper required for STABLE_TORCH_LIBRARY
void boxed_add_scalar(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
  // Extract arguments from the stack using `to<T>`
  auto input = to<torch::stable::Tensor>(stack[0]);
  auto scalar = to<double>(stack[1]);

  // Call the actual kernel
  auto result = add_scalar(input, scalar);

  // Put the result back on the stack using `from()`
  // Stack slot 0 now holds the return value
  stack[0] = from(result);
}

// (6) Register the operator using STABLE_TORCH_LIBRARY
STABLE_TORCH_LIBRARY(myops, m) {
  m.def("add_scalar(Tensor input, float scalar) -> Tensor", &boxed_add_scalar);
}

// (7) Register the implementation using STABLE_TORCH_LIBRARY_IMPL
STABLE_TORCH_LIBRARY_IMPL(myops, CompositeExplicitAutograd, m) {
  m.impl("add_scalar", &boxed_add_scalar);
}

} // namespace myops
```
## How are objects passed across the ABI boundary when interacting with the dispatcher?
When interacting with the dispatcher via the stable APIs (``STABLE_TORCH_LIBRARY`` etc.), we use a boxed convention. Arguments and returns are represented as a stack of ``StableIValue``, which correlates with a `torch::jit::stack` of IValues. We discuss the following below:
1. StableIValue Conversions
2. StableIValue stack Conventions
3. Stable APIs that interact with the dispatcher
### StableIValue Conversions
We provide utilities for users to convert objects to and from StableIValues with the `to` and `from` APIs in `torch/csrc/stable/stableivalue_conversions.h`. We document the stable custom extension representation, StableIValue representation, and libtorch representation of each type below. Our confidently supported types are the ones in the table that have completed rows. You can rely on this subset for proper ABI stability, meaning that you can call `to<T_custom_ext>(arg/ret)` or `from(T)` on these types.
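
For example, a round trip through `StableIValue` for two supported types might look like the following sketch (the function `roundtrip` is hypothetical):

```cpp
#include <torch/csrc/stable/stableivalue_conversions.h>
#include <torch/csrc/stable/tensor_struct.h>

void roundtrip(const torch::stable::Tensor& t) {
  StableIValue boxed_tensor = from(t);    // custom extension type -> StableIValue
  StableIValue boxed_scalar = from(2.5);  // double -> StableIValue

  auto t2 = to<torch::stable::Tensor>(boxed_tensor);  // StableIValue -> Tensor
  auto d = to<double>(boxed_scalar);                   // StableIValue -> double
  (void)t2;
  (void)d;
}
```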
For a limited set of use cases, we also implicitly support any literal type that is representable within 64 bits as StableIValues, as the default reinterpret_cast will succeed. (For example: c10::Device.) These types are currently ABI-stable on best effort but might break in the future and thus should be used for short term testing only.
You can always work with StableIValue abstractions in your custom kernel, even for types such as `c10::Device` that have no defined custom extension representation, by not introspecting into the StableIValue. For example, a custom operator can take a StableIValue device as an argument and directly pass it through to an aten operator with `aoti_torch_call_dispatcher`.
1. type in custom extension: type used within the end user custom library.
2. StableIValue representation: a stable conversion of the type that liaises between the custom extension and libtorch.so in an ABI-stable manner.
3. type in libtorch: type used within libtorch.so (or any code binary locked with libtorch).
@@ -31,16 +198,10 @@ This note will eventually contain more details on how to use the APIs in torch/c
| ? | ? | c10::SymBool | SymBool |
| ? | ? | at::QScheme | QScheme |


### Stack Conventions

There are two invariants for the stack:

1. The stack is populated left to right.
   a. For example, a stack representing arguments `arg0`, `arg1`, and `arg2` will have `arg0` at index 0, `arg1` at index 1, and `arg2` at index 2.
@@ -49,3 +210,33 @@ You can always work with StableIValue abstractions in your custom kernel for typ
2. The stack always has ownership of the objects it holds.
   a. When calling a stack-based API, you must give owning references to the calling stack and steal references from the returned stack.
   b. When registering your function to be called with a stack, you must steal references from your argument stack and push onto the stack new references.
(stack-based-apis)=
### Stack-based APIs
The above is relevant in two places:
1. `STABLE_TORCH_LIBRARY`
Unlike `TORCH_LIBRARY`, the dispatcher expects kernels registered via `STABLE_TORCH_LIBRARY` to be boxed. This means they must have the signature `(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) -> void`. We plan to eventually abstract away the need for manual boxing, but, for the time being, please use `from` and `to`.
```cpp
Tensor my_amax_vec(Tensor t) {
  std::vector<int64_t> v = {0,1};
  return amax(t, v, false);
}

void boxed_my_amax_vec(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
  auto res = my_amax_vec(to<Tensor>(stack[0]));
  stack[0] = from(res);
}
```
2. `aoti_torch_call_dispatcher`
This API allows you to call the PyTorch dispatcher from C/C++ code. It has the following signature:
```cpp
aoti_torch_call_dispatcher(const char* opName, const char* overloadName, StableIValue* stack);
```
`aoti_torch_call_dispatcher` will call the op overload defined by a given `opName`, `overloadName`, and a stack of
StableIValues. This call will populate any return values of the op into the stack in their StableIValue form,
with `ret0` at index 0, `ret1` at index 1, and so on.
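
For instance, here is a hedged sketch of calling a single-argument, single-return op through the dispatcher while following the stack conventions above (the choice of `aten::abs` with an empty overload name is an assumption for illustration):

```cpp
torch::stable::Tensor call_abs(const torch::stable::Tensor& t) {
  StableIValue stack[1];
  stack[0] = from(t);  // give an owning reference to the calling stack
  aoti_torch_call_dispatcher("aten::abs", "", stack);
  // The single return value now lives at index 0; steal the reference back.
  return to<torch::stable::Tensor>(stack[0]);
}
```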

torch/headeronly/README.md

Lines changed: 7 additions & 0 deletions
@@ -0,0 +1,7 @@
## torch/headeronly
The inlined C++ headers in the `torch::headeronly` namespace living in this subdirectory are completely decoupled from LibTorch. These APIs are also globally listed in [torch/header_only_apis.txt](https://github.com/pytorch/pytorch/blob/main/torch/header_only_apis.txt).
There are two types of LibTorch-independent header-only APIs:
1. OG header-only. Originally header-only APIs, such as `ScalarType`, `Half`, `BFloat16`, have always been implemented in headers only. For them to move into torch/headeronly only required a code migration, a copy-pasta, if you will.
2. Made to be header-only. There are also APIs that were NOT header-only that we made to be header-only. One example of such an API is `STD_TORCH_CHECK`, which was derived from `TORCH_CHECK`. `STD_TORCH_CHECK` throws a `std::runtime_error` instead of relying on `c10::Error`, which depends on libtorch.so. As a result, `STD_TORCH_CHECK` does not have the full `TORCH_CHECK` functionality that displays a fanciful traceback when the check is not met. We intentionally maintain the design that functions that do different things should be explicitly named differently.
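
As a minimal sketch (the helper `check_nonempty` is hypothetical, and the include path assumes `STD_TORCH_CHECK` is available via the headeronly macros header, as in the stable ABI note):

```cpp
#include <torch/headeronly/macros/Macros.h>

// Throws std::runtime_error (not c10::Error) when the condition fails,
// so this compiles and runs without linking against libtorch.so.
void check_nonempty(int64_t numel) {
  STD_TORCH_CHECK(numel > 0, "expected a non-empty tensor");
}
```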
