@@ -7,12 +7,87 @@
 
 #include <functorch/csrc/BatchRulesHelper.h>
 #include <ATen/Operators.h>
+#include <ATen/FunctionalTensorWrapper.h>
 #include <functorch/csrc/PlumbingHelper.h>
 #include <functorch/csrc/BatchedFallback.h>
 #include <ATen/core/dispatch/Dispatcher.h>
 
 namespace at { namespace functorch {
 
+at::Tensor sync_and_unwrap_functional_output(at::Tensor out_functional) {
+  TORCH_INTERNAL_ASSERT(at::functionalization::impl::isFunctionalTensor(out_functional));
+  auto out_wrapper_impl = at::functionalization::impl::unsafeGetFunctionalWrapper(out_functional);
+  out_wrapper_impl->sync_();
+  auto out_unwrapped = out_wrapper_impl->value();
+  return out_unwrapped;
+}
+
+c10::List<at::Tensor> sync_and_unwrap_functional_output(const c10::List<at::Tensor>& t_list) {
+  c10::List<Tensor> outputs;
+  outputs.reserve(t_list.size());
+  for (const auto i : c10::irange(t_list.size())) {
+    outputs.push_back(sync_and_unwrap_functional_output(t_list[i]));
+  }
+  return outputs;
+}
+
+void decompose_functional(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
+  const auto& schema = op.schema();
+
+  const auto num_arguments = schema.arguments().size();
+  const auto arguments = torch::jit::last(stack, num_arguments);
+  const auto arguments_begin = stack->size() - num_arguments;
+
+  // Step 1: Wrap any tensor inputs into Functional tensors
+  // and put them on the stack at the correct indices.
+  for (const auto idx : c10::irange(arguments.size())) {
+    const auto& ivalue = arguments[idx];
+    if (ivalue.isTensor()) {
+      auto functional_ivalue = at::functionalization::impl::to_functional_tensor(ivalue.toTensor());
+      (*stack)[arguments_begin + idx] = std::move(functional_ivalue);
+    } else if (ivalue.isTensorList()) {
+      auto functional_ivalue = at::functionalization::impl::to_functional_tensor(ivalue.toTensorList());
+      (*stack)[arguments_begin + idx] = std::move(functional_ivalue);
+    }
+  }
+
+  // Step 2: set up TLS such that we hit the functionalization kernels before the batching rules.
+  // Note: this relies on the fact that Functionalization > BatchMode in DispatchKey.h
+  c10::impl::IncludeDispatchKeyGuard include_guard(c10::DispatchKeySet(c10::DispatchKey::Functionalize));
+
+  // Step 3: redispatch to native kernel
+  // TODO: this is technically kind of sketchy, since we're relying on the fact
+  // that the composite kernel is registered to a particular dispatch key.
+  // In reality, a C++ extension could register their own custom kernels to any dispatch key, which would override
+  // the composite kernel entry.
+  // I'm using CPU because C++ extensions that register custom kernels to existing composite operators are pretty uncommon,
+  // and only really matter for out-of-tree keys like XLA.
+  // I wonder if we should make "alias dispatch key kernels" a runtime-accessible property on the OperatorHandle?
+  op.redispatchBoxed(c10::DispatchKeySet(c10::DispatchKey::CPU), stack);
+
+  const auto& schema_returns = op.schema().returns();
+  const auto num_returns = schema_returns.size();
+  auto returns = torch::jit::last(stack, num_returns);
+  const auto returns_begin = stack->size() - num_returns;
+
+  // Step 4: Unwrap each functional output tensor, syncing any pending updates
+  for (const auto idx : c10::irange(returns.size())) {
+    if (returns[idx].isTensor()) {
+      const auto& out_functional = returns[idx].toTensor();
+      auto out_unwrapped = sync_and_unwrap_functional_output(out_functional);
+      (*stack)[returns_begin + idx] = c10::IValue(out_unwrapped);
+    } else if (returns[idx].isTensorList()) {
+      const auto& out_functional = returns[idx].toTensorList();
+      auto out_unwrapped = sync_and_unwrap_functional_output(out_functional);
+      (*stack)[returns_begin + idx] = c10::IValue(out_unwrapped);
+    }
+  }
+}
+
+#define DECOMPOSE_FUNCTIONAL(op) \
+  m.impl(#op, torch::CppFunction::makeFromBoxedFunction<&decompose_functional>());
+
+
 #define OP_DECOMPOSE(op) m.impl(#op, static_cast<decltype(&ATEN_FN(op))>(native::op));
 #define OP_DECOMPOSE2(op, overload) m.impl(#op"."#overload, static_cast<decltype(&ATEN_FN2(op, overload))>(native::op));
 
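Note: the boxed kernel above effectively round-trips every tensor through a FunctionalTensorWrapper: wrap the inputs (Step 1), run the composite kernel with the Functionalize key included in the TLS (Steps 2-3), then sync and unwrap the outputs (Step 4). The snippet below is a minimal illustrative sketch of that per-tensor lifecycle, not part of the diff; `roundtrip_through_functionalization` is a made-up name, and the sketch assumes the `sync_and_unwrap_functional_output` helper defined in the hunk above.

    // Illustrative only -- shows the wrap -> run -> sync -> unwrap cycle that
    // decompose_functional applies to each tensor argument and return value.
    at::Tensor roundtrip_through_functionalization(const at::Tensor& t) {
      // Step 1: wrap the plain tensor in a FunctionalTensorWrapper.
      at::Tensor wrapped = at::functionalization::impl::to_functional_tensor(t);
      // Steps 2-3: the redispatched composite kernel would run here, with the
      // Functionalize key enabled, possibly recording pending view/in-place
      // updates on the wrapper.
      // Step 4: flush any pending updates and pull out the underlying tensor.
      return sync_and_unwrap_functional_output(wrapped);
    }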
@@ -149,6 +224,8 @@ TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
   OP_DECOMPOSE(arctan2);
   OP_DECOMPOSE(layer_norm);
   OP_DECOMPOSE(diag_backward);
+  DECOMPOSE_FUNCTIONAL(diag_embed);
+  DECOMPOSE_FUNCTIONAL(block_diag);
 }
 
 }}
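For clarity, the new macro simply binds the stringized op name to the boxed decomposition kernel under functorch's batched dispatch key. Roughly, `DECOMPOSE_FUNCTIONAL(diag_embed)` inside the `TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m)` block expands to:

    // Approximate preprocessor expansion of DECOMPOSE_FUNCTIONAL(diag_embed):
    // register the boxed decompose_functional kernel for aten::diag_embed.
    m.impl("diag_embed",
           torch::CppFunction::makeFromBoxedFunction<&decompose_functional>());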