Skip to content
This repository was archived by the owner on Aug 21, 2025. It is now read-only.

Commit 19fa4a1

Browse files
authored
add conditional functionalization (#235)
* add conditional functionalization * fix an RAII bug </3 * fix an RAII bug </3 * remove swap files * also functionalize block_diag * add a vmap(vmap(diag_embed)) test, update usage of sync() in the conditional fallback logic
1 parent b8fc6d0 commit 19fa4a1

File tree

3 files changed

+83
-2
lines changed

3 files changed

+83
-2
lines changed

functorch/csrc/BatchRulesDecompositions.cpp

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,87 @@
77

88
#include <functorch/csrc/BatchRulesHelper.h>
99
#include <ATen/Operators.h>
10+
#include <ATen/FunctionalTensorWrapper.h>
1011
#include <functorch/csrc/PlumbingHelper.h>
1112
#include <functorch/csrc/BatchedFallback.h>
1213
#include <ATen/core/dispatch/Dispatcher.h>
1314

1415
namespace at { namespace functorch {
1516

17+
// Syncs any pending functionalization updates on `out_functional` and
// returns the underlying (unwrapped) tensor.
// Precondition: `out_functional` wraps a FunctionalTensorWrapper
// (asserted below).
// Note: takes the tensor by const& — the original by-value signature
// forced an unnecessary refcount bump on every call.
at::Tensor sync_and_unwrap_functional_output(const at::Tensor& out_functional) {
  TORCH_INTERNAL_ASSERT(at::functionalization::impl::isFunctionalTensor(out_functional));
  auto out_wrapper_impl = at::functionalization::impl::unsafeGetFunctionalWrapper(out_functional);
  // Flush any mutations queued on the wrapper before extracting the value.
  out_wrapper_impl->sync_();
  return out_wrapper_impl->value();
}
24+
25+
// List overload: syncs and unwraps every functional tensor in `t_list`,
// returning a new list of the unwrapped values.
c10::List<at::Tensor> sync_and_unwrap_functional_output(const c10::List<at::Tensor>& t_list) {
  c10::List<Tensor> unwrapped;
  unwrapped.reserve(t_list.size());
  for (size_t i = 0; i < t_list.size(); ++i) {
    unwrapped.push_back(sync_and_unwrap_functional_output(t_list[i]));
  }
  return unwrapped;
}
33+
34+
// Boxed fallback kernel: runs `op` through the functionalization pass before
// decomposing it, so that composite kernels containing in-place/view ops
// (e.g. diag_embed, block_diag) can be vmapped safely.
// The boxed calling convention is: arguments are popped off `stack` and
// results are pushed back in their place.
void decompose_functional(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
  const auto& schema = op.schema();

  const auto num_arguments = schema.arguments().size();
  const auto arguments = torch::jit::last(stack, num_arguments);
  const auto arguments_begin = stack->size() - num_arguments;

  // Step 1: Wrap any tensor inputs into Functional tensors
  // and put them on the stack at the correct indices.
  for (const auto idx : c10::irange(arguments.size())) {
    const auto& ivalue = arguments[idx];
    if (ivalue.isTensor()) {
      auto functional_ivalue = at::functionalization::impl::to_functional_tensor(ivalue.toTensor());
      (*stack)[arguments_begin + idx] = std::move(functional_ivalue);
    } else if (ivalue.isTensorList()) {
      auto functional_ivalue = at::functionalization::impl::to_functional_tensor(ivalue.toTensorList());
      (*stack)[arguments_begin + idx] = std::move(functional_ivalue);
    }
  }

  // Step 2: set up TLS such that we hit the functionalization kernels before the batching rules.
  // Note: this relies on the fact that Functionalization > BatchMode in DispatchKey.h
  c10::impl::IncludeDispatchKeyGuard include_guard(c10::DispatchKeySet(c10::DispatchKey::Functionalize));

  // Step 3: redispatch to native kernel
  // TODO: this is technically kind of sketchy, since we're relying on the fact
  // that the composite kernel is registered to a particular dispatch key.
  // In reality, a C++ extension could register their own custom kernels to any dispatch key, which would override
  // the composite kernel entry.
  // I'm using CPU because C++ extensions that register custom kernels to existing composite operators are pretty uncommon,
  // and only really matter for out-of-tree keys like XLA.
  // I wonder if we should make "alias dispatch key kernels" a runtime-accessible property on the OperatorHandle?
  op.redispatchBoxed(c10::DispatchKeySet(c10::DispatchKey::CPU), stack);

  const auto& schema_returns = op.schema().returns();
  // Plain value, not `const auto&`: binding a reference to the temporary
  // returned by size() relied on lifetime extension for no benefit.
  const auto num_returns = schema_returns.size();
  auto returns = torch::jit::last(stack, num_returns);
  const auto returns_begin = stack->size() - num_returns;

  // Step 4: Unwrap each functional output tensor, syncing any pending updates
  for (const auto idx : c10::irange(returns.size())) {
    if (returns[idx].isTensor()) {
      const auto& out_functional = returns[idx].toTensor();
      auto out_unwrapped = sync_and_unwrap_functional_output(out_functional);
      (*stack)[returns_begin + idx] = c10::IValue(out_unwrapped);
    } else if (returns[idx].isTensorList()) {
      const auto& out_functional = returns[idx].toTensorList();
      auto out_unwrapped = sync_and_unwrap_functional_output(out_functional);
      (*stack)[returns_begin + idx] = c10::IValue(out_unwrapped);
    }
  }
}
86+
87+
// Registers the boxed `decompose_functional` fallback for `op`, so the op is
// functionalized before its composite decomposition runs under vmap.
#define DECOMPOSE_FUNCTIONAL(op) \
  m.impl(#op, torch::CppFunction::makeFromBoxedFunction<&decompose_functional>());

// Registers the at::native composite implementation of `op` (or of the named
// overload, for OP_DECOMPOSE2) directly as the batched kernel.
#define OP_DECOMPOSE(op) m.impl(#op, static_cast<decltype(&ATEN_FN(op))>(native::op));
#define OP_DECOMPOSE2(op, overload) m.impl(#op"."#overload, static_cast<decltype(&ATEN_FN2(op, overload))>(native::op));
1893

@@ -149,6 +224,8 @@ TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
149224
OP_DECOMPOSE(arctan2);
150225
OP_DECOMPOSE(layer_norm);
151226
OP_DECOMPOSE(diag_backward);
227+
DECOMPOSE_FUNCTIONAL(diag_embed);
228+
DECOMPOSE_FUNCTIONAL(block_diag);
152229
}
153230

154231
}}

test/test_ops.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -749,7 +749,6 @@ def test():
749749
xfail('index_put'),
750750
xfail('linalg.multi_dot'),
751751
xfail('vstack'),
752-
xfail('block_diag'),
753752
xfail('nn.functional.batch_norm'),
754753
xfail('cdist'),
755754
xfail('lu_solve'),

test/test_vmap.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,12 @@ def test_nested_with_same_map_dim(self):
143143
output = vmap(vmap(vmap(torch.mul)))(x, y)
144144
self.assertEqual(output, x * y)
145145

146+
def test_nested_with_diag_embed(self):
    # diag_embed requires special testing because it is registered with conditional functionalization.
    inp = torch.randn(3, 3, 5)
    result = vmap(vmap(torch.diag_embed))(inp)
    self.assertEqual(result, torch.diag_embed(inp))
151+
146152
def test_nested_with_different_map_dim(self):
147153
x = torch.randn(2, 3)
148154
y = torch.randn(5, 3)
@@ -3089,7 +3095,6 @@ class TestVmapOperatorsOpInfo(TestCase):
30893095
xfail('vstack'),
30903096
xfail('dstack'),
30913097
xfail('linalg.multi_dot'),
3092-
xfail('block_diag'),
30933098
xfail('nn.functional.dropout'),
30943099
xfail('view_as_complex'),
30953100
xfail('H'),

0 commit comments

Comments
 (0)