15 | 15 |
16 | 16 | namespace at { namespace functorch {
17 | 17 |
18 |    | -at::Tensor sync_and_unwrap_functional_output(at::Tensor out_functional) {
19 |    | -  TORCH_INTERNAL_ASSERT(at::functionalization::impl::isFunctionalTensor(out_functional));
20 |    | -  auto out_wrapper_impl = at::functionalization::impl::unsafeGetFunctionalWrapper(out_functional);
21 |    | -  out_wrapper_impl->sync_();
22 |    | -  auto out_unwrapped = out_wrapper_impl->value();
23 |    | -  return out_unwrapped;
24 |    | -}
25 |    | -
26 |    | -c10::List<at::Tensor> sync_and_unwrap_functional_output(const c10::List<at::Tensor>& t_list) {
27 |    | -  c10::List<Tensor> outputs;
28 |    | -  outputs.reserve(t_list.size());
29 |    | -  for (const auto i : c10::irange(t_list.size())) {
30 |    | -    outputs.push_back(sync_and_unwrap_functional_output(t_list[i]));
31 |    | -  }
32 |    | -  return outputs;
33 |    | -}
34 |    | -
35 |    | -void decompose_functional(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
36 |    | -  const auto& schema = op.schema();
37 |    | -
38 |    | -  const auto num_arguments = schema.arguments().size();
39 |    | -  const auto arguments = torch::jit::last(stack, num_arguments);
40 |    | -  const auto arguments_begin = stack->size() - num_arguments;
41 |    | -  //
42 |    | -  // Step 1: Wrap any tensor inputs into Functional tensors
43 |    | -  // and put them on the stack at the correct indices.
44 |    | -  for (const auto idx : c10::irange(arguments.size())) {
45 |    | -    const auto& ivalue = arguments[idx];
46 |    | -    if (ivalue.isTensor()) {
47 |    | -      auto functional_ivalue = at::functionalization::impl::to_functional_tensor(ivalue.toTensor());
48 |    | -      (*stack)[arguments_begin + idx] = std::move(functional_ivalue);
49 |    | -    } else if (ivalue.isTensorList()) {
50 |    | -      auto functional_ivalue = at::functionalization::impl::to_functional_tensor(ivalue.toTensorList());
51 |    | -      (*stack)[arguments_begin + idx] = std::move(functional_ivalue);
52 |    | -    }
53 |    | -  }
54 |    | -
55 |    | -  // Step 2: set up TLS such that we hit the functionalization kernels before the batching rules.
56 |    | -  // Note: this relies on the fact that Functionalize > FuncTorchBatched in DispatchKey.h.
57 |    | -  // Also, adding Functionalize to the include set isn't enough: we also need to remove it from the exclude set.
58 |    | -  // That's because functorch DynamicLayer logic may have added Functionalize to the exclude set beforehand.
59 |    | -  auto local_keyset = c10::impl::tls_local_dispatch_key_set();
60 |    | -  local_keyset.excluded_ = local_keyset.excluded_.remove(c10::DispatchKey::Functionalize);
61 |    | -  local_keyset.included_ = local_keyset.included_.add(c10::DispatchKey::Functionalize);
62 |    | -  c10::impl::ForceDispatchKeyGuard guard(local_keyset);
63 |    | -
64 |    | -  at::functionalization::impl::FunctionalizationReapplyViewsGuard functional_guard(true);
65 |    | -
66 |    | -  // Step 3: redispatch to native kernel
67 |    | -  // TODO: this is technically kind of sketchy, since we're relying on the fact
68 |    | -  // that the composite kernel is registered to a particular dispatch key.
69 |    | -  // In reality, a C++ extension could register their own custom kernels to any dispatch key, which would override
70 |    | -  // the composite kernel entry.
71 |    | -  // I'm using CPU because C++ extensions that register custom kernels to existing composite operators are pretty uncommon,
72 |    | -  // and only really matter for out-of-tree keys like XLA.
73 |    | -  // I wonder if we should make "alias dispatch key kernels" a runtime-accessible property on the OperatorHandle?
74 |    | -  op.redispatchBoxed(c10::DispatchKeySet(c10::DispatchKey::CPU), stack);
75 |    | -
76 |    | -  const auto& schema_returns = op.schema().returns();
77 |    | -  const auto& num_returns = schema_returns.size();
78 |    | -  auto returns = torch::jit::last(stack, num_returns);
79 |    | -  const auto returns_begin = stack->size() - num_returns;
80 |    | -
81 |    | -  // Step 4: Unwrap each functional output tensor, syncing any pending updates
82 |    | -  for (const auto idx : c10::irange(returns.size())) {
83 |    | -    if (returns[idx].isTensor()) {
84 |    | -      const auto& out_functional = returns[idx].toTensor();
85 |    | -      auto out_unwrapped = sync_and_unwrap_functional_output(out_functional);
86 |    | -      (*stack)[returns_begin + idx] = c10::IValue(out_unwrapped);
87 |    | -    } else if (returns[idx].isTensorList()) {
88 |    | -      const auto& out_functional = returns[idx].toTensorList();
89 |    | -      auto out_unwrapped = sync_and_unwrap_functional_output(out_functional);
90 |    | -      (*stack)[returns_begin + idx] = c10::IValue(out_unwrapped);
91 |    | -    }
92 |    | -  }
93 |    | -}
94 |    | -
95 |    | -
96 |    | -#define DECOMPOSE_FUNCTIONAL(op) \
97 |    | -  m.impl(#op, torch::CppFunction::makeFromBoxedFunction<&decompose_functional>());
98 |    | -
99 |    | -
100 | 18 | #define OP_DECOMPOSE(op) m.impl(#op, static_cast<decltype(&ATEN_FN(op))>(native::op));
101 | 19 | #define OP_DECOMPOSE2(op, overload) m.impl(#op"."#overload, static_cast<decltype(&ATEN_FN2(op, overload))>(native::op));
102 | 20 |
@@ -315,8 +233,6 @@ TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
315 | 233 |   OP_DECOMPOSE(pad);
316 | 234 |   OP_DECOMPOSE(_pad_circular);
317 | 235 |
318 |     | -  DECOMPOSE_FUNCTIONAL(block_diag);
319 |     | -
320 | 236 |   // divide, alias for div
321 | 237 |   OP_DECOMPOSE2(divide, Tensor);
322 | 238 |   OP_DECOMPOSE2(divide_, Tensor);
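
For context, the `DECOMPOSE_FUNCTIONAL(block_diag)` line removed above expanded, via the macro deleted in the first hunk, to a boxed-kernel registration along the lines of the sketch below. This is only an illustration of the old mechanism, not code that remains after this change; it assumes the `decompose_functional` boxed kernel from the first hunk and the `FT_BATCHED_KEY` alias used in this file's `TORCH_LIBRARY_IMPL` block (presumably the FuncTorchBatched key, per the comment in the removed code).

```cpp
// Sketch: the registration DECOMPOSE_FUNCTIONAL(block_diag) used to produce.
// The boxed decompose_functional kernel wrapped tensor arguments as functional
// tensors, redispatched to the composite CPU kernel, then synced and unwrapped
// the functional outputs back onto the stack.
TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
  m.impl("block_diag",
         torch::CppFunction::makeFromBoxedFunction<&decompose_functional>());
}
```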