|
8 | 8 | #include <torch/csrc/stable/library.h> |
9 | 9 | #include <torch/library.h> |
10 | 10 |
|
| 11 | +#include <torch/csrc/stable/c/shim.h> |
| 12 | + |
11 | 13 | static StableIValue from_ivalue( |
12 | 14 | const c10::TypePtr& type, |
13 | 15 | const c10::IValue& ivalue) { |
@@ -216,3 +218,155 @@ AOTITorchError aoti_torch_call_dispatcher( |
216 | 218 | } |
217 | 219 | }); |
218 | 220 | } |
| 221 | + |
| 222 | +// Schema Adapter Infrastructure |
| 223 | +// SchemaAdapterRegistry contains the adapters registered via |
| 224 | +// register_schema_adapter that define how to convert the StableIValue argument |
| 225 | +// stack to an IValue stack when changes are made to the schema of an ATen |
| 226 | +// function. This should only be relevant in the context of calling |
| 227 | +// torch_call_dispatcher. |
| 228 | + |
| 229 | +// Currently this only adapts the argument stack. |
| 230 | +// C++ default argument resolution will happen at compile time in the |
| 231 | +// torch/csrc/stable/ops.h header, so extensions always pass complete argument |
| 232 | +// lists for the version they build against's schema. As such, this is only |
| 233 | +// needed if a new argument is added to the schema |
| 234 | +// |
| 235 | +// This is not declared in the stable shim.h, |
| 236 | +// so we **do not make any guarantees that the signature of this will not |
| 237 | +// change**. If there is a need to define similar infrastructure for the returns |
| 238 | +// of an aten function we can update this. |
| 239 | + |
| 240 | +namespace { |
| 241 | +using SchemaAdapterFn = std::function<torch::jit::Stack( |
| 242 | + const c10::FunctionSchema& current_schema, |
| 243 | + const StableIValue* extension_stack, |
| 244 | + uint64_t extension_build_version)>; |
| 245 | + |
| 246 | +// Global registry for schema adapters |
| 247 | +class SchemaAdapterRegistry { |
| 248 | + private: |
| 249 | + std::unordered_map< |
| 250 | + std::string, |
| 251 | + std::vector<std::pair<uint64_t, SchemaAdapterFn>>> |
| 252 | + adapters_; |
| 253 | + |
| 254 | + public: |
| 255 | + static SchemaAdapterRegistry& instance() { |
| 256 | + static SchemaAdapterRegistry registry; |
| 257 | + return registry; |
| 258 | + } |
| 259 | + |
| 260 | + void register_adapter( |
| 261 | + const std::string& op_name, |
| 262 | + uint64_t |
| 263 | + applies_to_versions_below, // versions below this need the adapter |
| 264 | + SchemaAdapterFn adapter) { |
| 265 | + adapters_[op_name].emplace_back(applies_to_versions_below, adapter); |
| 266 | + // Sort by version ascending - this allows us to find the first (most |
| 267 | + // specific) match |
| 268 | + std::sort( |
| 269 | + adapters_[op_name].begin(), |
| 270 | + adapters_[op_name].end(), |
| 271 | + [](const auto& a, const auto& b) { return a.first < b.first; }); |
| 272 | + } |
| 273 | + |
| 274 | + std::optional<SchemaAdapterFn> get_adapter( |
| 275 | + const std::string& op_name, |
| 276 | + uint64_t extension_version) { |
| 277 | + auto it = adapters_.find(op_name); |
| 278 | + if (it == adapters_.end()) |
| 279 | + return std::nullopt; |
| 280 | + |
| 281 | + // Find the first adapter that applies (most specific due to ascending sort) |
| 282 | + for (const auto& [applies_to_versions_below, adapter] : it->second) { |
| 283 | + if (extension_version < applies_to_versions_below) { |
| 284 | + return adapter; |
| 285 | + } |
| 286 | + } |
| 287 | + return std::nullopt; |
| 288 | + } |
| 289 | +}; |
| 290 | + |
| 291 | +// Internal API for registering adapters that define how to convert the |
| 292 | +// StableIValue **argument** stack to an IValue stack when changes are |
| 293 | +// made to the schema of a function. adapter_fn will be used if |
| 294 | +// extension_build_version < applies_to_versions_below. |
| 295 | +[[maybe_unused]] AOTITorchError register_schema_adapter( |
| 296 | + const char* op_name, |
| 297 | + uint64_t applies_to_versions_below, |
| 298 | + SchemaAdapterFn adapter_fn) { |
| 299 | + AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({ |
| 300 | + auto& registry = SchemaAdapterRegistry::instance(); |
| 301 | + registry.register_adapter( |
| 302 | + std::string(op_name), applies_to_versions_below, std::move(adapter_fn)); |
| 303 | + }); |
| 304 | +} |
| 305 | + |
| 306 | +} // namespace |
| 307 | + |
| 308 | +// Function to register test schema adapters for _test_schema_upgrader |
| 309 | +// This demonstrates the adapter registration pattern (internal use only) |
| 310 | +static AOTITorchError _register_adapters() { |
| 311 | + // ** Schema adapters should be registered here** |
| 312 | + // Refer to https://github.com/pytorch/pytorch/pull/165284/ for an example. |
| 313 | + // |
| 314 | + // if (auto err = register_schema_adapter( |
| 315 | + // "aten::your_op", |
| 316 | + // VERSION_FOO, // applies to versions < VERSION_FOO |
| 317 | + // adapt_v1_to_vfoo)) { |
| 318 | + // return err; |
| 319 | + // } |
| 320 | + return AOTI_TORCH_SUCCESS; |
| 321 | +} |
| 322 | + |
| 323 | +// Static initialization to automatically register test adapters |
| 324 | +static struct AdapterInitializer { |
| 325 | + AdapterInitializer() { |
| 326 | + // Register the test adapters when the library loads |
| 327 | + _register_adapters(); |
| 328 | + } |
| 329 | +} adapter_initializer; |
| 330 | + |
| 331 | +AOTI_TORCH_EXPORT AOTITorchError torch_call_dispatcher( |
| 332 | + const char* opName, |
| 333 | + const char* overloadName, |
| 334 | + StableIValue* stack, |
| 335 | + // version of stable headers used to build the extension: necessary for |
| 336 | + // applying schema adapters |
| 337 | + uint64_t extension_build_version) { |
| 338 | + AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({ |
| 339 | + const auto op = |
| 340 | + c10::Dispatcher::singleton().findSchemaOrThrow(opName, overloadName); |
| 341 | + const auto& schema = op.schema(); |
| 342 | + const auto num_returns = schema.returns().size(); |
| 343 | + const auto num_arguments = schema.arguments().size(); |
| 344 | + |
| 345 | + torch::jit::Stack ivalue_stack; |
| 346 | + auto& registry = SchemaAdapterRegistry::instance(); |
| 347 | + |
| 348 | + // Check if we need an adapter for this operation |
| 349 | + if (auto adapter = registry.get_adapter(opName, extension_build_version)) { |
| 350 | + // Use adapter to create IValue stack |
| 351 | + ivalue_stack = (*adapter)(schema, stack, extension_build_version); |
| 352 | + } else { |
| 353 | + // No adapter needed - implementation matches aoti_torch_call_dispatcher |
| 354 | + ivalue_stack.reserve(std::max(num_arguments, num_returns)); |
| 355 | + for (const auto idx : c10::irange(num_arguments)) { |
| 356 | + auto stable_ivalue = stack[idx]; |
| 357 | + auto arg_type = schema.arguments()[idx].type(); |
| 358 | + torch::jit::push(ivalue_stack, to_ivalue(arg_type, stable_ivalue)); |
| 359 | + } |
| 360 | + } |
| 361 | + |
| 362 | + op.callBoxed(ivalue_stack); |
| 363 | + |
| 364 | + // there should then be num_returns IValues on the stack, which |
| 365 | + // we will convert to StableIValue and repopulate user input stack |
| 366 | + for (const auto idx : c10::irange(num_returns)) { |
| 367 | + const auto stack_idx = num_returns - idx - 1; |
| 368 | + const c10::TypePtr& ret_type = schema.returns()[idx].type(); |
| 369 | + stack[stack_idx] = from_ivalue(ret_type, torch::jit::pop(ivalue_stack)); |
| 370 | + } |
| 371 | + }); |
| 372 | +} |
0 commit comments