1212
1313static StableIValue from_ivalue (
1414 const c10::TypePtr& type,
15- const c10::IValue& ivalue) {
15+ const c10::IValue& ivalue,
16+ uint64_t extension_build_version) {
1617 switch (type->kind ()) {
1718 case c10::TypeKind::TensorType: {
1819 AtenTensorHandle ath = torch::aot_inductor::new_tensor_handle (
1920 std::move (const_cast <at::Tensor&>(ivalue.toTensor ())));
20- return torch::stable::detail::from (ath);
21+ return torch::stable::detail::_from (ath, extension_build_version );
2122 }
2223 case c10::TypeKind::IntType: {
23- return torch::stable::detail::from (ivalue.toInt ());
24+ return torch::stable::detail::_from (
25+ ivalue.toInt (), extension_build_version);
2426 }
2527 case c10::TypeKind::FloatType: {
26- return torch::stable::detail::from (ivalue.toDouble ());
28+ return torch::stable::detail::_from (
29+ ivalue.toDouble (), extension_build_version);
2730 }
2831 case c10::TypeKind::BoolType: {
29- return torch::stable::detail::from (ivalue.toBool ());
32+ return torch::stable::detail::_from (
33+ ivalue.toBool (), extension_build_version);
3034 }
3135 case c10::TypeKind::ScalarTypeType: {
32- return torch::stable::detail::from (ivalue.toScalarType ());
36+ return torch::stable::detail::_from (
37+ ivalue.toScalarType (), extension_build_version);
3338 }
3439 case c10::TypeKind::DeviceObjType: {
35- return torch::stable::detail::from (ivalue.toDevice ());
40+ return torch::stable::detail::_from (
41+ ivalue.toDevice (), extension_build_version);
3642 }
3743 case c10::TypeKind::LayoutType: {
38- return torch::stable::detail::from (ivalue.toLayout ());
44+ return torch::stable::detail::_from (
45+ ivalue.toLayout (), extension_build_version);
3946 }
4047 case c10::TypeKind::MemoryFormatType: {
41- return torch::stable::detail::from (ivalue.toMemoryFormat ());
48+ return torch::stable::detail::_from (
49+ ivalue.toMemoryFormat (), extension_build_version);
4250 }
4351 case c10::TypeKind::OptionalType: {
4452 auto inner_type = type->castRaw <at::OptionalType>()->getElementType ();
@@ -56,10 +64,12 @@ static StableIValue from_ivalue(
5664 // be kept in sync with torch::stable::detail::from<std::optional<T>>
5765 // function in torch/csrc/stable/stableivalue_conversions.h
5866 if (ivalue.isNone ()) {
59- return torch::stable::detail::from (std::nullopt );
67+ return torch::stable::detail::_from (
68+ std::nullopt , extension_build_version);
6069 }
61- StableIValue* sivp = new StableIValue (from_ivalue (inner_type, ivalue));
62- return torch::stable::detail::from (sivp);
70+ StableIValue* sivp = new StableIValue (
71+ from_ivalue (inner_type, ivalue, extension_build_version));
72+ return torch::stable::detail::_from (sivp, extension_build_version);
6373 }
6474 default : {
6575 TORCH_CHECK (
@@ -72,36 +82,43 @@ static StableIValue from_ivalue(
7282
7383static c10::IValue to_ivalue (
7484 const c10::TypePtr& type,
75- const StableIValue stable_ivalue) {
85+ const StableIValue stable_ivalue,
86+ uint64_t extension_build_version) {
7687 switch (type->kind ()) {
7788 case c10::TypeKind::TensorType: {
7889 auto ret_raiiath = torch::aot_inductor::RAIIAtenTensorHandle (
79- torch::stable::detail::to<AtenTensorHandle>(stable_ivalue));
90+ torch::stable::detail::_to<AtenTensorHandle>(
91+ stable_ivalue, extension_build_version));
8092 return (c10::IValue (*torch::aot_inductor::tensor_handle_to_tensor_pointer (
8193 ret_raiiath.get ())));
8294 }
8395 case c10::TypeKind::IntType: {
84- return c10::IValue (torch::stable::detail::to<int64_t >(stable_ivalue));
96+ return c10::IValue (torch::stable::detail::_to<int64_t >(
97+ stable_ivalue, extension_build_version));
8598 }
8699 case c10::TypeKind::FloatType: {
87- return c10::IValue (torch::stable::detail::to<double >(stable_ivalue));
100+ return c10::IValue (torch::stable::detail::_to<double >(
101+ stable_ivalue, extension_build_version));
88102 }
89103 case c10::TypeKind::BoolType: {
90- return c10::IValue (torch::stable::detail::to<bool >(stable_ivalue));
104+ return c10::IValue (torch::stable::detail::_to<bool >(
105+ stable_ivalue, extension_build_version));
91106 }
92107 case c10::TypeKind::ScalarTypeType: {
93- return c10::IValue (
94- torch::stable::detail::to<c10::ScalarType>( stable_ivalue));
108+ return c10::IValue (torch::stable::detail::_to<c10::ScalarType>(
109+ stable_ivalue, extension_build_version ));
95110 }
96111 case c10::TypeKind::DeviceObjType: {
97- return c10::IValue (torch::stable::detail::to<c10::Device>(stable_ivalue));
112+ return c10::IValue (torch::stable::detail::_to<c10::Device>(
113+ stable_ivalue, extension_build_version));
98114 }
99115 case c10::TypeKind::LayoutType: {
100- return c10::IValue (torch::stable::detail::to<c10::Layout>(stable_ivalue));
116+ return c10::IValue (torch::stable::detail::_to<c10::Layout>(
117+ stable_ivalue, extension_build_version));
101118 }
102119 case c10::TypeKind::MemoryFormatType: {
103- return c10::IValue (
104- torch::stable::detail::to<c10::MemoryFormat>( stable_ivalue));
120+ return c10::IValue (torch::stable::detail::_to<c10::MemoryFormat>(
121+ stable_ivalue, extension_build_version ));
105122 }
106123 case c10::TypeKind::OptionalType: {
107124 auto inner_type = type->castRaw <at::OptionalType>()->getElementType ();
@@ -116,13 +133,15 @@ static c10::IValue to_ivalue(
116133 //
117134 // BUT we do NOT have that type inner_type::t readily available, so we
118135 // will manually unwrap and recursively call. This implementation MUST
119- // be kept in sync with the torch::stable::detail::to<T> function in
120- // torch/csrc/stable/stableivalue_conversions.h
121- if (stable_ivalue == torch::stable::detail::from (std::nullopt )) {
136+ // be kept in sync with the torch::stable::detail::_to<T> function in
137+ // torch/csrc/stable/library.h
138+ if (stable_ivalue ==
139+ torch::stable::detail::_from (std::nullopt , extension_build_version)) {
122140 return c10::IValue ();
123141 }
124- auto sivp = torch::stable::detail::to<StableIValue*>(stable_ivalue);
125- auto ival = to_ivalue (inner_type, *sivp);
142+ auto sivp = torch::stable::detail::_to<StableIValue*>(
143+ stable_ivalue, extension_build_version);
144+ auto ival = to_ivalue (inner_type, *sivp, extension_build_version);
126145 delete sivp;
127146 return ival;
128147 }
@@ -137,8 +156,10 @@ static c10::IValue to_ivalue(
137156
138157class StableIValueBoxedKernel : public c10 ::OperatorKernel {
139158 public:
140- StableIValueBoxedKernel (void (*fn)(StableIValue*, uint64_t , uint64_t ))
141- : fn_(fn) {}
159+ StableIValueBoxedKernel (
160+ void (*fn)(StableIValue*, uint64_t , uint64_t ),
161+ uint64_t extension_build_version)
162+ : fn_(fn), extension_build_version_(extension_build_version) {}
142163
143164 void operator ()(
144165 const c10::OperatorHandle& op,
@@ -154,7 +175,8 @@ class StableIValueBoxedKernel : public c10::OperatorKernel {
154175 for (const auto idx : c10::irange (num_arguments)) {
155176 const auto ministack_idx = num_arguments - idx - 1 ;
156177 const c10::TypePtr& arg_type = schema.arguments ()[ministack_idx].type ();
157- ministack[ministack_idx] = from_ivalue (arg_type, torch::jit::pop (stack));
178+ ministack[ministack_idx] = from_ivalue (
179+ arg_type, torch::jit::pop (stack), extension_build_version_);
158180 }
159181
160182 // boxed function is going to take a stack of StableIValues, cast them to
@@ -165,12 +187,14 @@ class StableIValueBoxedKernel : public c10::OperatorKernel {
165187 // IValue from StableIValue
166188 for (size_t idx = 0 ; idx < num_returns; idx++) {
167189 const c10::TypePtr& ret_type = schema.returns ()[idx].type ();
168- torch::jit::push (stack, to_ivalue (ret_type, ministack[idx]));
190+ torch::jit::push (
191+ stack, to_ivalue (ret_type, ministack[idx], extension_build_version_));
169192 }
170193 }
171194
172195 private:
173196 void (*fn_)(StableIValue*, uint64_t , uint64_t );
197+ uint64_t extension_build_version_;
174198};
175199
176200AOTI_TORCH_EXPORT AOTITorchError aoti_torch_library_impl (
@@ -181,7 +205,23 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_library_impl(
181205 reinterpret_cast <torch::Library*>(self)->impl (
182206 name,
183207 torch::CppFunction::makeFromBoxedFunctor (
184- std::make_unique<StableIValueBoxedKernel>(fn)));
208+ std::make_unique<StableIValueBoxedKernel>(fn, TORCH_ABI_VERSION)));
209+ });
210+ }
211+
212+ // Version-aware variant of aoti_torch_library_impl that takes an
213+ // extension_build_version parameter for backward compatibility
214+ AOTI_TORCH_EXPORT AOTITorchError torch_library_impl (
215+ TorchLibraryHandle self,
216+ const char * name,
217+ void (*fn)(StableIValue*, uint64_t , uint64_t ),
218+ uint64_t extension_build_version) {
219+ AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE ({
220+ reinterpret_cast <torch::Library*>(self)->impl (
221+ name,
222+ torch::CppFunction::makeFromBoxedFunctor (
223+ std::make_unique<StableIValueBoxedKernel>(
224+ fn, extension_build_version)));
185225 });
186226}
187227
@@ -204,7 +244,8 @@ AOTITorchError aoti_torch_call_dispatcher(
204244 for (const auto idx : c10::irange (num_arguments)) {
205245 auto stable_ivalue = stack[idx];
206246 auto arg_type = schema.arguments ()[idx].type ();
207- torch::jit::push (ivalue_stack, to_ivalue (arg_type, stable_ivalue));
247+ torch::jit::push (
248+ ivalue_stack, to_ivalue (arg_type, stable_ivalue, TORCH_ABI_VERSION));
208249 }
209250
210251 op.callBoxed (ivalue_stack);
@@ -214,7 +255,8 @@ AOTITorchError aoti_torch_call_dispatcher(
214255 for (const auto idx : c10::irange (num_returns)) {
215256 const auto stack_idx = num_returns - idx - 1 ;
216257 const c10::TypePtr& ret_type = schema.returns ()[idx].type ();
217- stack[stack_idx] = from_ivalue (ret_type, torch::jit::pop (ivalue_stack));
258+ stack[stack_idx] = from_ivalue (
259+ ret_type, torch::jit::pop (ivalue_stack), TORCH_ABI_VERSION);
218260 }
219261 });
220262}
@@ -355,7 +397,9 @@ AOTI_TORCH_EXPORT AOTITorchError torch_call_dispatcher(
355397 for (const auto idx : c10::irange (num_arguments)) {
356398 auto stable_ivalue = stack[idx];
357399 auto arg_type = schema.arguments ()[idx].type ();
358- torch::jit::push (ivalue_stack, to_ivalue (arg_type, stable_ivalue));
400+ torch::jit::push (
401+ ivalue_stack,
402+ to_ivalue (arg_type, stable_ivalue, extension_build_version));
359403 }
360404 }
361405
@@ -366,7 +410,8 @@ AOTI_TORCH_EXPORT AOTITorchError torch_call_dispatcher(
366410 for (const auto idx : c10::irange (num_returns)) {
367411 const auto stack_idx = num_returns - idx - 1 ;
368412 const c10::TypePtr& ret_type = schema.returns ()[idx].type ();
369- stack[stack_idx] = from_ivalue (ret_type, torch::jit::pop (ivalue_stack));
413+ stack[stack_idx] = from_ivalue (
414+ ret_type, torch::jit::pop (ivalue_stack), extension_build_version);
370415 }
371416 });
372417}
0 commit comments