Skip to content

Commit 8f51556

Browse files
mikaylagawarecki authored and pytorchmergebot committed
Add scaffolding for aoti_torch_call_dispatcher BC with native ops (pytorch#163683)
Part 1 of plan in https://docs.google.com/document/d/1MaX51H5aEQE5XnOlnZIpf9oCYwzGrTWkgBACxNzsmWE/edit?usp=sharing - Upgrade `aoti_torch_call_dispatcher` to v2 with an `extension_build_version` - Allow registration of StableIValue stack --> IValue stack adapters for schema changes #### Note: This PR does not include a linter that tells the user to add the upgrader if the schema changes, which is an important piece that will be added in a separate PR Pull Request resolved: pytorch#163683 Approved by: https://github.com/janeyx99 ghstack dependencies: pytorch#164356, pytorch#166373
1 parent c0bbda3 commit 8f51556

File tree

3 files changed

+227
-0
lines changed

3 files changed

+227
-0
lines changed

torch/csrc/shim_common.cpp

Lines changed: 154 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
#include <torch/csrc/stable/library.h>
99
#include <torch/library.h>
1010

11+
#include <torch/csrc/stable/c/shim.h>
12+
1113
static StableIValue from_ivalue(
1214
const c10::TypePtr& type,
1315
const c10::IValue& ivalue) {
@@ -216,3 +218,155 @@ AOTITorchError aoti_torch_call_dispatcher(
216218
}
217219
});
218220
}
221+
222+
// Schema Adapter Infrastructure
223+
// SchemaAdapterRegistry contains the adapters registered via
224+
// register_schema_adapter that define how to convert the StableIValue argument
225+
// stack to an IValue stack when changes are made to the schema of an ATen
226+
// function. This should only be relevant in the context of calling
227+
// torch_call_dispatcher.
228+
229+
// Currently this only adapts the argument stack.
230+
// C++ default argument resolution will happen at compile time in the
231+
// torch/csrc/stable/ops.h header, so extensions always pass complete argument
232+
// lists for the version they build against's schema. As such, this is only
233+
// needed if a new argument is added to the schema
234+
//
235+
// This is not declared in the stable shim.h,
236+
// so we **do not make any guarantees that the signature of this will not
237+
// change**. If there is a need to define similar infrastructure for the returns
238+
// of an aten function we can update this.
239+
240+
namespace {
241+
using SchemaAdapterFn = std::function<torch::jit::Stack(
242+
const c10::FunctionSchema& current_schema,
243+
const StableIValue* extension_stack,
244+
uint64_t extension_build_version)>;
245+
246+
// Global registry for schema adapters
247+
class SchemaAdapterRegistry {
248+
private:
249+
std::unordered_map<
250+
std::string,
251+
std::vector<std::pair<uint64_t, SchemaAdapterFn>>>
252+
adapters_;
253+
254+
public:
255+
static SchemaAdapterRegistry& instance() {
256+
static SchemaAdapterRegistry registry;
257+
return registry;
258+
}
259+
260+
void register_adapter(
261+
const std::string& op_name,
262+
uint64_t
263+
applies_to_versions_below, // versions below this need the adapter
264+
SchemaAdapterFn adapter) {
265+
adapters_[op_name].emplace_back(applies_to_versions_below, adapter);
266+
// Sort by version ascending - this allows us to find the first (most
267+
// specific) match
268+
std::sort(
269+
adapters_[op_name].begin(),
270+
adapters_[op_name].end(),
271+
[](const auto& a, const auto& b) { return a.first < b.first; });
272+
}
273+
274+
std::optional<SchemaAdapterFn> get_adapter(
275+
const std::string& op_name,
276+
uint64_t extension_version) {
277+
auto it = adapters_.find(op_name);
278+
if (it == adapters_.end())
279+
return std::nullopt;
280+
281+
// Find the first adapter that applies (most specific due to ascending sort)
282+
for (const auto& [applies_to_versions_below, adapter] : it->second) {
283+
if (extension_version < applies_to_versions_below) {
284+
return adapter;
285+
}
286+
}
287+
return std::nullopt;
288+
}
289+
};
290+
291+
// Internal API for registering adapters that define how to convert the
292+
// StableIValue **argument** stack to an IValue stack when changes are
293+
// made to the schema of a function. adapter_fn will be used if
294+
// extension_build_version < applies_to_versions_below.
295+
[[maybe_unused]] AOTITorchError register_schema_adapter(
296+
const char* op_name,
297+
uint64_t applies_to_versions_below,
298+
SchemaAdapterFn adapter_fn) {
299+
AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
300+
auto& registry = SchemaAdapterRegistry::instance();
301+
registry.register_adapter(
302+
std::string(op_name), applies_to_versions_below, std::move(adapter_fn));
303+
});
304+
}
305+
306+
} // namespace
307+
308+
// Function to register test schema adapters for _test_schema_upgrader
309+
// This demonstrates the adapter registration pattern (internal use only)
310+
static AOTITorchError _register_adapters() {
311+
// ** Schema adapters should be registered here**
312+
// Refer to https://github.com/pytorch/pytorch/pull/165284/ for an example.
313+
//
314+
// if (auto err = register_schema_adapter(
315+
// "aten::your_op",
316+
// VERSION_FOO, // applies to versions < VERSION_FOO
317+
// adapt_v1_to_vfoo)) {
318+
// return err;
319+
// }
320+
return AOTI_TORCH_SUCCESS;
321+
}
322+
323+
// Static initialization to automatically register test adapters
324+
static struct AdapterInitializer {
325+
AdapterInitializer() {
326+
// Register the test adapters when the library loads
327+
_register_adapters();
328+
}
329+
} adapter_initializer;
330+
331+
AOTI_TORCH_EXPORT AOTITorchError torch_call_dispatcher(
332+
const char* opName,
333+
const char* overloadName,
334+
StableIValue* stack,
335+
// version of stable headers used to build the extension: necessary for
336+
// applying schema adapters
337+
uint64_t extension_build_version) {
338+
AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
339+
const auto op =
340+
c10::Dispatcher::singleton().findSchemaOrThrow(opName, overloadName);
341+
const auto& schema = op.schema();
342+
const auto num_returns = schema.returns().size();
343+
const auto num_arguments = schema.arguments().size();
344+
345+
torch::jit::Stack ivalue_stack;
346+
auto& registry = SchemaAdapterRegistry::instance();
347+
348+
// Check if we need an adapter for this operation
349+
if (auto adapter = registry.get_adapter(opName, extension_build_version)) {
350+
// Use adapter to create IValue stack
351+
ivalue_stack = (*adapter)(schema, stack, extension_build_version);
352+
} else {
353+
// No adapter needed - implementation matches aoti_torch_call_dispatcher
354+
ivalue_stack.reserve(std::max(num_arguments, num_returns));
355+
for (const auto idx : c10::irange(num_arguments)) {
356+
auto stable_ivalue = stack[idx];
357+
auto arg_type = schema.arguments()[idx].type();
358+
torch::jit::push(ivalue_stack, to_ivalue(arg_type, stable_ivalue));
359+
}
360+
}
361+
362+
op.callBoxed(ivalue_stack);
363+
364+
// there should then be num_returns IValues on the stack, which
365+
// we will convert to StableIValue and repopulate user input stack
366+
for (const auto idx : c10::irange(num_returns)) {
367+
const auto stack_idx = num_returns - idx - 1;
368+
const c10::TypePtr& ret_type = schema.returns()[idx].type();
369+
stack[stack_idx] = from_ivalue(ret_type, torch::jit::pop(ivalue_stack));
370+
}
371+
});
372+
}

torch/csrc/stable/c/shim.h

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
#ifndef STABLE_TORCH_SHIM
2+
#define STABLE_TORCH_SHIM
3+
4+
#include <torch/csrc/inductor/aoti_torch/c/shim.h>
5+
6+
#include <torch/csrc/stable/version.h>
7+
8+
// This header defines stable C API extensions for backward/forward
9+
// compatibility when calling ATen operations through the dispatcher.
10+
//
11+
// This is separate from the main AOTI shim to provide versioning capabilities
12+
// for schema changes in native ATen functions.
13+
14+
#ifdef __cplusplus
15+
extern "C" {
16+
#endif
17+
18+
#if TORCH_FEATURE_VERSION >= TORCH_VERSION_2_10_0
19+
using StableIValue = uint64_t;
20+
21+
// Has the same semantic as aoti_torch_call_dispatcher, but takes an
22+
// additional argument for the extension build version. This is
23+
// needed for backward compatibility when calling native functions via
24+
// the dispatcher. The caller should pass in the libtorch version the
25+
// extension is building with (NOT target version).
26+
AOTI_TORCH_EXPORT AOTITorchError torch_call_dispatcher(
27+
const char* opName,
28+
const char* overloadName,
29+
StableIValue* stack,
30+
uint64_t extension_build_version);
31+
32+
// Version-aware variant of aoti_torch_library_impl that takes an
33+
// extension_build_version parameter for backward compatibility
34+
AOTI_TORCH_EXPORT AOTITorchError torch_library_impl(
35+
TorchLibraryHandle self,
36+
const char* name,
37+
void (*fn)(StableIValue*, uint64_t, uint64_t),
38+
uint64_t extension_build_version);
39+
40+
#endif // TORCH_FEATURE_VERSION >= TORCH_VERSION_2_10_0
41+
42+
#ifdef __cplusplus
43+
} // extern "C"
44+
#endif
45+
46+
#endif // STABLE_TORCH_SHIM

torch/csrc/stable/ops.h

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
#include <vector>
99

1010
#include <torch/csrc/inductor/aoti_torch/generated/c_shim_aten.h>
11+
#include <torch/csrc/stable/c/shim.h>
12+
#include <torch/csrc/stable/version.h>
1113
#include <torch/headeronly/core/ScalarType.h>
1214
#include <torch/headeronly/macros/Macros.h>
1315

@@ -25,8 +27,13 @@ inline torch::stable::Tensor empty_like(const torch::stable::Tensor& self) {
2527
torch::stable::detail::from(std::nullopt),
2628
torch::stable::detail::from(std::nullopt),
2729
torch::stable::detail::from(std::nullopt)};
30+
#if TORCH_FEATURE_VERSION >= TORCH_VERSION_2_10_0
31+
TORCH_ERROR_CODE_CHECK(torch_call_dispatcher(
32+
"aten::empty_like", "", stack.data(), TORCH_ABI_VERSION));
33+
#else
2834
TORCH_ERROR_CODE_CHECK(
2935
aoti_torch_call_dispatcher("aten::empty_like", "", stack.data()));
36+
#endif
3037
return torch::stable::detail::to<torch::stable::Tensor>(stack[0]);
3138
}
3239

@@ -201,8 +208,13 @@ inline torch::stable::Tensor transpose(
201208
torch::stable::detail::from(self),
202209
torch::stable::detail::from(dim0),
203210
torch::stable::detail::from(dim1)};
211+
#if TORCH_FEATURE_VERSION >= TORCH_VERSION_2_10_0
212+
TORCH_ERROR_CODE_CHECK(torch_call_dispatcher(
213+
"aten::transpose", "int", stack.data(), TORCH_ABI_VERSION));
214+
#else
204215
TORCH_ERROR_CODE_CHECK(
205216
aoti_torch_call_dispatcher("aten::transpose", "int", stack.data()));
217+
#endif
206218
return torch::stable::detail::to<torch::stable::Tensor>(stack[0]);
207219
}
208220

@@ -212,8 +224,13 @@ inline torch::stable::Tensor transpose(
212224
inline torch::stable::Tensor zero_(torch::stable::Tensor& self) {
213225
const auto num_args = 1;
214226
std::array<StableIValue, num_args> stack{torch::stable::detail::from(self)};
227+
#if TORCH_FEATURE_VERSION >= TORCH_VERSION_2_10_0
228+
TORCH_ERROR_CODE_CHECK(torch_call_dispatcher(
229+
"aten::zero_", "", stack.data(), TORCH_ABI_VERSION));
230+
#else
215231
TORCH_ERROR_CODE_CHECK(
216232
aoti_torch_call_dispatcher("aten::zero_", "", stack.data()));
233+
#endif
217234
return torch::stable::detail::to<torch::stable::Tensor>(stack[0]);
218235
}
219236

@@ -228,8 +245,13 @@ inline torch::stable::Tensor copy_(
228245
torch::stable::detail::from(self),
229246
torch::stable::detail::from(src),
230247
torch::stable::detail::from(non_blocking.value_or(false))};
248+
#if TORCH_FEATURE_VERSION >= TORCH_VERSION_2_10_0
249+
TORCH_ERROR_CODE_CHECK(torch_call_dispatcher(
250+
"aten::copy_", "", stack.data(), TORCH_ABI_VERSION));
251+
#else
231252
TORCH_ERROR_CODE_CHECK(
232253
aoti_torch_call_dispatcher("aten::copy_", "", stack.data()));
254+
#endif
233255
return torch::stable::detail::to<torch::stable::Tensor>(stack[0]);
234256
}
235257

@@ -240,8 +262,13 @@ inline torch::stable::Tensor clone(const torch::stable::Tensor& self) {
240262
std::array<StableIValue, num_args> stack{
241263
torch::stable::detail::from(self),
242264
torch::stable::detail::from(std::nullopt)};
265+
#if TORCH_FEATURE_VERSION >= TORCH_VERSION_2_10_0
266+
TORCH_ERROR_CODE_CHECK(torch_call_dispatcher(
267+
"aten::clone", "", stack.data(), TORCH_ABI_VERSION));
268+
#else
243269
TORCH_ERROR_CODE_CHECK(
244270
aoti_torch_call_dispatcher("aten::clone", "", stack.data()));
271+
#endif
245272
return torch::stable::detail::to<torch::stable::Tensor>(stack[0]);
246273
}
247274

0 commit comments

Comments (0)