Skip to content

Commit eae701c

Browse files
mikaylagawareckipytorchmergebot
authored andcommitted
Add scaffolding for StableIValue FC/BC (no PoC) (pytorch#164332)
1. Add `extension_build_version` and `is_internal` to `FromImpl`/`ToImpl` (this will be useful for future if we need to break the BC of any type) pytorch#163832 has the PoC of how we would actually use this system 2. Add `aoti_torch_library_impl_v2` that takes in an additional `extension_build_version` argument, updates callsite in `torch/csrc/stable/library.h` to always pass `TORCH_ABI_VERSION` for this argument 3. Add `extension_build_version` to `from_ivalue` and `to_ivalue` and update all callsites 4. Add a private `_from` and `_to` that pass `is_internal=True` to `FromImpl`/`ToImpl`, making it easier to reason about what is being called from libtorch-land / extension-land **Note: This PR does not include a linter that tells the user to update from/to if changing the ABI of a type in headeronly, which I intend to do in pytorch#163998 Pull Request resolved: pytorch#164332 Approved by: https://github.com/janeyx99 ghstack dependencies: pytorch#164356, pytorch#166373, pytorch#163683
1 parent 8f51556 commit eae701c

File tree

3 files changed

+188
-54
lines changed

3 files changed

+188
-54
lines changed

torch/csrc/shim_common.cpp

Lines changed: 82 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -12,33 +12,41 @@
1212

1313
static StableIValue from_ivalue(
1414
const c10::TypePtr& type,
15-
const c10::IValue& ivalue) {
15+
const c10::IValue& ivalue,
16+
uint64_t extension_build_version) {
1617
switch (type->kind()) {
1718
case c10::TypeKind::TensorType: {
1819
AtenTensorHandle ath = torch::aot_inductor::new_tensor_handle(
1920
std::move(const_cast<at::Tensor&>(ivalue.toTensor())));
20-
return torch::stable::detail::from(ath);
21+
return torch::stable::detail::_from(ath, extension_build_version);
2122
}
2223
case c10::TypeKind::IntType: {
23-
return torch::stable::detail::from(ivalue.toInt());
24+
return torch::stable::detail::_from(
25+
ivalue.toInt(), extension_build_version);
2426
}
2527
case c10::TypeKind::FloatType: {
26-
return torch::stable::detail::from(ivalue.toDouble());
28+
return torch::stable::detail::_from(
29+
ivalue.toDouble(), extension_build_version);
2730
}
2831
case c10::TypeKind::BoolType: {
29-
return torch::stable::detail::from(ivalue.toBool());
32+
return torch::stable::detail::_from(
33+
ivalue.toBool(), extension_build_version);
3034
}
3135
case c10::TypeKind::ScalarTypeType: {
32-
return torch::stable::detail::from(ivalue.toScalarType());
36+
return torch::stable::detail::_from(
37+
ivalue.toScalarType(), extension_build_version);
3338
}
3439
case c10::TypeKind::DeviceObjType: {
35-
return torch::stable::detail::from(ivalue.toDevice());
40+
return torch::stable::detail::_from(
41+
ivalue.toDevice(), extension_build_version);
3642
}
3743
case c10::TypeKind::LayoutType: {
38-
return torch::stable::detail::from(ivalue.toLayout());
44+
return torch::stable::detail::_from(
45+
ivalue.toLayout(), extension_build_version);
3946
}
4047
case c10::TypeKind::MemoryFormatType: {
41-
return torch::stable::detail::from(ivalue.toMemoryFormat());
48+
return torch::stable::detail::_from(
49+
ivalue.toMemoryFormat(), extension_build_version);
4250
}
4351
case c10::TypeKind::OptionalType: {
4452
auto inner_type = type->castRaw<at::OptionalType>()->getElementType();
@@ -56,10 +64,12 @@ static StableIValue from_ivalue(
5664
// be kept in sync with torch::stable::detail::from<std::optional<T>>
5765
// function in torch/csrc/stable/stableivalue_conversions.h
5866
if (ivalue.isNone()) {
59-
return torch::stable::detail::from(std::nullopt);
67+
return torch::stable::detail::_from(
68+
std::nullopt, extension_build_version);
6069
}
61-
StableIValue* sivp = new StableIValue(from_ivalue(inner_type, ivalue));
62-
return torch::stable::detail::from(sivp);
70+
StableIValue* sivp = new StableIValue(
71+
from_ivalue(inner_type, ivalue, extension_build_version));
72+
return torch::stable::detail::_from(sivp, extension_build_version);
6373
}
6474
default: {
6575
TORCH_CHECK(
@@ -72,36 +82,43 @@ static StableIValue from_ivalue(
7282

7383
static c10::IValue to_ivalue(
7484
const c10::TypePtr& type,
75-
const StableIValue stable_ivalue) {
85+
const StableIValue stable_ivalue,
86+
uint64_t extension_build_version) {
7687
switch (type->kind()) {
7788
case c10::TypeKind::TensorType: {
7889
auto ret_raiiath = torch::aot_inductor::RAIIAtenTensorHandle(
79-
torch::stable::detail::to<AtenTensorHandle>(stable_ivalue));
90+
torch::stable::detail::_to<AtenTensorHandle>(
91+
stable_ivalue, extension_build_version));
8092
return (c10::IValue(*torch::aot_inductor::tensor_handle_to_tensor_pointer(
8193
ret_raiiath.get())));
8294
}
8395
case c10::TypeKind::IntType: {
84-
return c10::IValue(torch::stable::detail::to<int64_t>(stable_ivalue));
96+
return c10::IValue(torch::stable::detail::_to<int64_t>(
97+
stable_ivalue, extension_build_version));
8598
}
8699
case c10::TypeKind::FloatType: {
87-
return c10::IValue(torch::stable::detail::to<double>(stable_ivalue));
100+
return c10::IValue(torch::stable::detail::_to<double>(
101+
stable_ivalue, extension_build_version));
88102
}
89103
case c10::TypeKind::BoolType: {
90-
return c10::IValue(torch::stable::detail::to<bool>(stable_ivalue));
104+
return c10::IValue(torch::stable::detail::_to<bool>(
105+
stable_ivalue, extension_build_version));
91106
}
92107
case c10::TypeKind::ScalarTypeType: {
93-
return c10::IValue(
94-
torch::stable::detail::to<c10::ScalarType>(stable_ivalue));
108+
return c10::IValue(torch::stable::detail::_to<c10::ScalarType>(
109+
stable_ivalue, extension_build_version));
95110
}
96111
case c10::TypeKind::DeviceObjType: {
97-
return c10::IValue(torch::stable::detail::to<c10::Device>(stable_ivalue));
112+
return c10::IValue(torch::stable::detail::_to<c10::Device>(
113+
stable_ivalue, extension_build_version));
98114
}
99115
case c10::TypeKind::LayoutType: {
100-
return c10::IValue(torch::stable::detail::to<c10::Layout>(stable_ivalue));
116+
return c10::IValue(torch::stable::detail::_to<c10::Layout>(
117+
stable_ivalue, extension_build_version));
101118
}
102119
case c10::TypeKind::MemoryFormatType: {
103-
return c10::IValue(
104-
torch::stable::detail::to<c10::MemoryFormat>(stable_ivalue));
120+
return c10::IValue(torch::stable::detail::_to<c10::MemoryFormat>(
121+
stable_ivalue, extension_build_version));
105122
}
106123
case c10::TypeKind::OptionalType: {
107124
auto inner_type = type->castRaw<at::OptionalType>()->getElementType();
@@ -116,13 +133,15 @@ static c10::IValue to_ivalue(
116133
//
117134
// BUT we do NOT have that type inner_type::t readily available, so we
118135
// will manually unwrap and recursively call. This implementation MUST
119-
// be kept in sync with the torch::stable::detail::to<T> function in
120-
// torch/csrc/stable/stableivalue_conversions.h
121-
if (stable_ivalue == torch::stable::detail::from(std::nullopt)) {
136+
// be kept in sync with the torch::stable::detail::_to<T> function in
137+
// torch/csrc/stable/library.h
138+
if (stable_ivalue ==
139+
torch::stable::detail::_from(std::nullopt, extension_build_version)) {
122140
return c10::IValue();
123141
}
124-
auto sivp = torch::stable::detail::to<StableIValue*>(stable_ivalue);
125-
auto ival = to_ivalue(inner_type, *sivp);
142+
auto sivp = torch::stable::detail::_to<StableIValue*>(
143+
stable_ivalue, extension_build_version);
144+
auto ival = to_ivalue(inner_type, *sivp, extension_build_version);
126145
delete sivp;
127146
return ival;
128147
}
@@ -137,8 +156,10 @@ static c10::IValue to_ivalue(
137156

138157
class StableIValueBoxedKernel : public c10::OperatorKernel {
139158
public:
140-
StableIValueBoxedKernel(void (*fn)(StableIValue*, uint64_t, uint64_t))
141-
: fn_(fn) {}
159+
StableIValueBoxedKernel(
160+
void (*fn)(StableIValue*, uint64_t, uint64_t),
161+
uint64_t extension_build_version)
162+
: fn_(fn), extension_build_version_(extension_build_version) {}
142163

143164
void operator()(
144165
const c10::OperatorHandle& op,
@@ -154,7 +175,8 @@ class StableIValueBoxedKernel : public c10::OperatorKernel {
154175
for (const auto idx : c10::irange(num_arguments)) {
155176
const auto ministack_idx = num_arguments - idx - 1;
156177
const c10::TypePtr& arg_type = schema.arguments()[ministack_idx].type();
157-
ministack[ministack_idx] = from_ivalue(arg_type, torch::jit::pop(stack));
178+
ministack[ministack_idx] = from_ivalue(
179+
arg_type, torch::jit::pop(stack), extension_build_version_);
158180
}
159181

160182
// boxed function is going to take a stack of StableIValues, cast them to
@@ -165,12 +187,14 @@ class StableIValueBoxedKernel : public c10::OperatorKernel {
165187
// IValue from StableIValue
166188
for (size_t idx = 0; idx < num_returns; idx++) {
167189
const c10::TypePtr& ret_type = schema.returns()[idx].type();
168-
torch::jit::push(stack, to_ivalue(ret_type, ministack[idx]));
190+
torch::jit::push(
191+
stack, to_ivalue(ret_type, ministack[idx], extension_build_version_));
169192
}
170193
}
171194

172195
private:
173196
void (*fn_)(StableIValue*, uint64_t, uint64_t);
197+
uint64_t extension_build_version_;
174198
};
175199

176200
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_library_impl(
@@ -181,7 +205,23 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_library_impl(
181205
reinterpret_cast<torch::Library*>(self)->impl(
182206
name,
183207
torch::CppFunction::makeFromBoxedFunctor(
184-
std::make_unique<StableIValueBoxedKernel>(fn)));
208+
std::make_unique<StableIValueBoxedKernel>(fn, TORCH_ABI_VERSION)));
209+
});
210+
}
211+
212+
// Version-aware variant of aoti_torch_library_impl that takes an
213+
// extension_build_version parameter for backward compatibility
214+
AOTI_TORCH_EXPORT AOTITorchError torch_library_impl(
215+
TorchLibraryHandle self,
216+
const char* name,
217+
void (*fn)(StableIValue*, uint64_t, uint64_t),
218+
uint64_t extension_build_version) {
219+
AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
220+
reinterpret_cast<torch::Library*>(self)->impl(
221+
name,
222+
torch::CppFunction::makeFromBoxedFunctor(
223+
std::make_unique<StableIValueBoxedKernel>(
224+
fn, extension_build_version)));
185225
});
186226
}
187227

@@ -204,7 +244,8 @@ AOTITorchError aoti_torch_call_dispatcher(
204244
for (const auto idx : c10::irange(num_arguments)) {
205245
auto stable_ivalue = stack[idx];
206246
auto arg_type = schema.arguments()[idx].type();
207-
torch::jit::push(ivalue_stack, to_ivalue(arg_type, stable_ivalue));
247+
torch::jit::push(
248+
ivalue_stack, to_ivalue(arg_type, stable_ivalue, TORCH_ABI_VERSION));
208249
}
209250

210251
op.callBoxed(ivalue_stack);
@@ -214,7 +255,8 @@ AOTITorchError aoti_torch_call_dispatcher(
214255
for (const auto idx : c10::irange(num_returns)) {
215256
const auto stack_idx = num_returns - idx - 1;
216257
const c10::TypePtr& ret_type = schema.returns()[idx].type();
217-
stack[stack_idx] = from_ivalue(ret_type, torch::jit::pop(ivalue_stack));
258+
stack[stack_idx] = from_ivalue(
259+
ret_type, torch::jit::pop(ivalue_stack), TORCH_ABI_VERSION);
218260
}
219261
});
220262
}
@@ -355,7 +397,9 @@ AOTI_TORCH_EXPORT AOTITorchError torch_call_dispatcher(
355397
for (const auto idx : c10::irange(num_arguments)) {
356398
auto stable_ivalue = stack[idx];
357399
auto arg_type = schema.arguments()[idx].type();
358-
torch::jit::push(ivalue_stack, to_ivalue(arg_type, stable_ivalue));
400+
torch::jit::push(
401+
ivalue_stack,
402+
to_ivalue(arg_type, stable_ivalue, extension_build_version));
359403
}
360404
}
361405

@@ -366,7 +410,8 @@ AOTI_TORCH_EXPORT AOTITorchError torch_call_dispatcher(
366410
for (const auto idx : c10::irange(num_returns)) {
367411
const auto stack_idx = num_returns - idx - 1;
368412
const c10::TypePtr& ret_type = schema.returns()[idx].type();
369-
stack[stack_idx] = from_ivalue(ret_type, torch::jit::pop(ivalue_stack));
413+
stack[stack_idx] = from_ivalue(
414+
ret_type, torch::jit::pop(ivalue_stack), extension_build_version);
370415
}
371416
});
372417
}

torch/csrc/stable/library.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,14 @@
44
// code for better UX.
55

66
#include <torch/csrc/inductor/aoti_torch/c/shim.h>
7+
#include <torch/csrc/stable/c/shim.h>
78
#include <torch/headeronly/macros/Macros.h>
89

910
// Technically, this file doesn't use anything from stableivalue_conversions.h,
1011
// but we need to include it here as the contents of stableivalue_conversions.h
1112
// used to live here and so we need to expose them for backwards compatibility.
1213
#include <torch/csrc/stable/stableivalue_conversions.h>
14+
#include <torch/csrc/stable/version.h>
1315

1416
HIDDEN_NAMESPACE_BEGIN(torch, stable, detail)
1517

@@ -81,7 +83,11 @@ class StableLibrary final {
8183
StableLibrary& impl(
8284
const char* name,
8385
void (*fn)(StableIValue*, uint64_t, uint64_t)) {
86+
#if TORCH_FEATURE_VERSION >= TORCH_VERSION_2_10_0
87+
torch_library_impl(lib_, name, fn, TORCH_ABI_VERSION);
88+
#else
8489
aoti_torch_library_impl(lib_, name, fn);
90+
#endif
8591
return *this;
8692
}
8793

0 commit comments

Comments
 (0)