
Commit abe4c4d

add batching rule for block_diag, kill DECOMPOSE_FUNCTIONAL (#814)
* add batching rule for block_diag, kill DECOMPOSE_FUNCTIONAL
1 parent 9fa0265 commit abe4c4d
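For context, a rough sketch of the user-facing behavior this commit targets, assuming functorch's vmap API of this era; the tensor names and shapes below are illustrative, not taken from the commit. With a real batching rule registered for block_diag, vmap over torch.block_diag no longer relies on the functionalize-and-decompose fallback removed here.

# Illustrative sketch only; assumes the standalone functorch package.
import torch
from functorch import vmap

# Two batches of small blocks; dim 0 is the vmapped dimension.
a = torch.randn(3, 2, 2)
b = torch.randn(3, 4, 4)

# Per sample, block_diag builds a (2+4) x (2+4) block-diagonal matrix,
# so the vmapped result has shape (3, 6, 6).
out = vmap(torch.block_diag)(a, b)

# Matches a manual loop over the batch dimension.
expected = torch.stack([torch.block_diag(a[i], b[i]) for i in range(3)])
assert torch.allclose(out, expected)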

File tree

4 files changed: +29 -86 lines changed


functorch/csrc/BatchRulesDecompositions.cpp

Lines changed: 0 additions & 84 deletions
@@ -15,88 +15,6 @@
 
 namespace at { namespace functorch {
 
-at::Tensor sync_and_unwrap_functional_output(at::Tensor out_functional) {
-  TORCH_INTERNAL_ASSERT(at::functionalization::impl::isFunctionalTensor(out_functional));
-  auto out_wrapper_impl = at::functionalization::impl::unsafeGetFunctionalWrapper(out_functional);
-  out_wrapper_impl->sync_();
-  auto out_unwrapped = out_wrapper_impl->value();
-  return out_unwrapped;
-}
-
-c10::List<at::Tensor> sync_and_unwrap_functional_output(const c10::List<at::Tensor>& t_list) {
-  c10::List<Tensor> outputs;
-  outputs.reserve(t_list.size());
-  for (const auto i : c10::irange(t_list.size())) {
-    outputs.push_back(sync_and_unwrap_functional_output(t_list[i]));
-  }
-  return outputs;
-}
-
-void decompose_functional(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
-  const auto& schema = op.schema();
-
-  const auto num_arguments = schema.arguments().size();
-  const auto arguments = torch::jit::last(stack, num_arguments);
-  const auto arguments_begin = stack->size() - num_arguments;
-  //
-  // Step 1: Wrap any tensor inputs into Functional tensors
-  // and put them on the stack at the correct indices.
-  for (const auto idx : c10::irange(arguments.size())) {
-    const auto& ivalue = arguments[idx];
-    if (ivalue.isTensor()) {
-      auto functional_ivalue = at::functionalization::impl::to_functional_tensor(ivalue.toTensor());
-      (*stack)[arguments_begin + idx] = std::move(functional_ivalue);
-    } else if (ivalue.isTensorList()) {
-      auto functional_ivalue = at::functionalization::impl::to_functional_tensor(ivalue.toTensorList());
-      (*stack)[arguments_begin + idx] = std::move(functional_ivalue);
-    }
-  }
-
-  // Step 2: set up TLS such that we hit the functionalization kernels before the batching rules.
-  // Note: this relies on the fact that Functionalize > FuncTorchBatched in DispatchKey.h.
-  // Also, adding Functionalize to the include set isn't enough: we also need to remove it from the exclude set.
-  // That's because functorch DynamicLayer logic may have added Functionalize to the exclude set beforehand.
-  auto local_keyset = c10::impl::tls_local_dispatch_key_set();
-  local_keyset.excluded_ = local_keyset.excluded_.remove(c10::DispatchKey::Functionalize);
-  local_keyset.included_ = local_keyset.included_.add(c10::DispatchKey::Functionalize);
-  c10::impl::ForceDispatchKeyGuard guard(local_keyset);
-
-  at::functionalization::impl::FunctionalizationReapplyViewsGuard functional_guard(true);
-
-  // Step 3: redispatch to native kernel
-  // TODO: this is technically kind of sketchy, since we're relying on the fact
-  // that the composite kernel is registered to a particular dispatch key.
-  // In reality, a C++ extension could register their own custom kernels to any dispatch key, which would override
-  // the composite kernel entry.
-  // I'm using CPU because C++ extensions that register custom kernels to existing composite operators are pretty uncommon,
-  // and only really matter for out-of-tree keys like XLA.
-  // I wonder if we should make "alias dispatch key kernels" a runtime-accessible property on the OperatorHandle?
-  op.redispatchBoxed(c10::DispatchKeySet(c10::DispatchKey::CPU), stack);
-
-  const auto& schema_returns = op.schema().returns();
-  const auto& num_returns = schema_returns.size();
-  auto returns = torch::jit::last(stack, num_returns);
-  const auto returns_begin = stack->size() - num_returns;
-
-  // Step 4: Unwrap each functional output tensor, syncing any pending updates
-  for (const auto idx : c10::irange(returns.size())) {
-    if (returns[idx].isTensor()) {
-      const auto& out_functional = returns[idx].toTensor();
-      auto out_unwrapped = sync_and_unwrap_functional_output(out_functional);
-      (*stack)[returns_begin + idx] = c10::IValue(out_unwrapped);
-    } else if (returns[idx].isTensorList()) {
-      const auto& out_functional = returns[idx].toTensorList();
-      auto out_unwrapped = sync_and_unwrap_functional_output(out_functional);
-      (*stack)[returns_begin + idx] = c10::IValue(out_unwrapped);
-    }
-  }
-}
-
-
-#define DECOMPOSE_FUNCTIONAL(op) \
-  m.impl(#op, torch::CppFunction::makeFromBoxedFunction<&decompose_functional>());
-
-
 #define OP_DECOMPOSE(op) m.impl(#op, static_cast<decltype(&ATEN_FN(op))>(native::op));
 #define OP_DECOMPOSE2(op, overload) m.impl(#op"."#overload, static_cast<decltype(&ATEN_FN2(op, overload))>(native::op));
 

@@ -315,8 +233,6 @@ TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
   OP_DECOMPOSE(pad);
   OP_DECOMPOSE(_pad_circular);
 
-  DECOMPOSE_FUNCTIONAL(block_diag);
-
   // divide, alias for div
   OP_DECOMPOSE2(divide, Tensor);
   OP_DECOMPOSE2(divide_, Tensor);

functorch/csrc/BatchingRegistrations.cpp

Lines changed: 29 additions & 0 deletions
@@ -557,6 +557,34 @@ Tensor cat_batching_rule(TensorList tensors, int64_t dim) {
   return physical_views[0].getPhysicalToLogicalMap().apply(result);
 }
 
+Tensor block_diag_batching_rule(TensorList tensors) {
+  if (!participatesInCurrentLevel(tensors)) {
+    c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
+    return at::block_diag(tensors);
+  }
+  auto physical_views = MultiBatchVmapTransform::logicalToPhysical(tensors);
+  auto physical_tensors = fmap(
+      physical_views, [](const VmapPhysicalView& view) -> Tensor { return view.tensor(); });
+  TORCH_INTERNAL_ASSERT(
+      tensors.size() > 0, "The dispatcher should not have dispatched here otherwise.");
+  // Implementing this as a dummy for loop for now, since I'm not sure how to do it any better.
+  // I'm probably not accounting for potentially multiple batched dimensions?
+  auto bdim = physical_tensors[0].size(0);
+  std::vector<Tensor> batched_outputs;
+  batched_outputs.reserve(bdim);
+  for (const auto& i : c10::irange(bdim)) {
+    std::vector<Tensor> inputs_for_batch;
+    inputs_for_batch.reserve(physical_tensors.size());
+    for (const auto& t : physical_tensors) {
+      inputs_for_batch.push_back(t[i]);
+    }
+    auto out_for_batch = at::block_diag(inputs_for_batch);
+    batched_outputs.push_back(out_for_batch.unsqueeze(0));
+  }
+  auto result = at::cat(batched_outputs);
+  return physical_views[0].getPhysicalToLogicalMap().apply(result);
+}
+
 Tensor stack_batching_rule(TensorList tensors, int64_t dim) {
   if (!participatesInCurrentLevel(tensors)) {
     c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);

@@ -666,6 +694,7 @@ TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
   m.impl("split_with_sizes", split_with_sizes_batching_rule);
   m.impl("unbind.int", unbind_batching_rule);
   m.impl("cat", cat_batching_rule);
+  m.impl("block_diag", block_diag_batching_rule);
   m.impl("stack", stack_batching_rule);
 
   // still legacy b/c needs special inplace rules
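The loop inside block_diag_batching_rule above amounts to a slice, apply, and concatenate decomposition over the batch dimension. A rough Python sketch of that logic follows, assuming batch-dim-first physical tensors; the function name block_diag_batched and the example shapes are illustrative stand-ins, not part of the commit.

# Sketch of the per-batch decomposition performed by the C++ rule above.
# `physical_tensors` stands in for the batch-dim-first physical tensors;
# this is an illustration, not the actual dispatcher-level code path.
import torch

def block_diag_batched(physical_tensors):
    bdim = physical_tensors[0].shape[0]
    batched_outputs = []
    for i in range(bdim):
        # Slice out batch element i from every input, ...
        inputs_for_batch = [t[i] for t in physical_tensors]
        # ... run the unbatched op, ...
        out_for_batch = torch.block_diag(*inputs_for_batch)
        # ... and re-add a leading batch dimension.
        batched_outputs.append(out_for_batch.unsqueeze(0))
    # Concatenate along the batch dimension, mirroring at::cat in the rule.
    return torch.cat(batched_outputs)

x = torch.randn(3, 2, 2)
y = torch.randn(3, 1, 5)
print(block_diag_batched([x, y]).shape)  # torch.Size([3, 3, 7])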

test/test_ops.py

Lines changed: 0 additions & 1 deletion
@@ -1136,7 +1136,6 @@ def get_vjp(cotangents, *primals):
         xfail('_masked.softmin', ''),
         xfail('amax', ''),
         xfail('amin', ''),
-        xfail('block_diag', ''),
         xfail('cdist', ''),
         xfail('cholesky', ''),
         xfail('eig', ''),

test/test_vmap.py

Lines changed: 0 additions & 1 deletion
@@ -3262,7 +3262,6 @@ def test_vmap_exhaustive(self, device, dtype, op):
         xfail('linalg.lu_factor_ex', ''),
         xfail('diagflat', ''),
         xfail('special.log_ndtr'),
-        xfail('block_diag'),  # aten::slice_copy.Tensor hit the vmap fallback which is currently disabled
         xfail('nn.functional.triplet_margin_loss', ''),
         xfail('nn.functional.pdist', ''),
         xfail('scatter_reduce', 'sum'),
