
Commit 254d4e6

max_pool2d_backward batch rule
1 parent 731e05f commit 254d4e6

4 files changed, +90 -54 lines changed

functorch/csrc/BatchRulesHelper.h

Lines changed: 89 additions & 0 deletions
@@ -264,6 +264,71 @@ inline void boxed_existing_bdim_all_batch_rule(
 #define EXISTING_BDIM_ALL_BOXED(op) \
   m.impl(#op, torch::CppFunction::makeFromBoxedFunction<boxed_existing_bdim_all_batch_rule>());

+template <int64_t feature_rank>
+inline void boxed_all_tensors_have_optional_bdim(
+    const c10::OperatorHandle& op, torch::jit::Stack* stack) {
+  const auto& schema = op.schema();
+  const auto num_returns = schema.returns().size();
+  const auto num_arguments = schema.arguments().size();
+
+  c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
+  auto maybe_layer = maybeCurrentDynamicLayer();
+  TORCH_INTERNAL_ASSERT(maybe_layer.has_value());
+  int64_t cur_level = maybe_layer->layerId();
+
+  int64_t args_begin = stack->size() - num_arguments;
+  SmallVector<UnpackedBatchedTensor, 5> tensor_inputs;
+  SmallVector<int64_t, 5> tensor_pos;
+  int64_t batch_size;
+
+  find_and_unpack_tensors(
+      stack, num_arguments, cur_level,
+      &tensor_inputs, &tensor_pos, &batch_size);
+
+  optional<bool> is_no_batch_dim_case;
+
+  for (const auto tensor_idx : c10::irange(0, tensor_inputs.size())) {
+    const auto& value = std::get<0>(tensor_inputs[tensor_idx]);
+    auto bdim = std::get<1>(tensor_inputs[tensor_idx]);
+    const auto logical_rank = rankWithoutBatchDim(value, bdim);
+
+    if (!is_no_batch_dim_case.has_value()) {
+      is_no_batch_dim_case = (logical_rank == feature_rank);
+    }
+    auto value_ = ensure_has_bdim(value, bdim.has_value(), batch_size);
+    if (!bdim.has_value()) {
+      bdim = 0;
+    }
+    if (*is_no_batch_dim_case) {
+      TORCH_INTERNAL_ASSERT(logical_rank == feature_rank);
+      (*stack)[args_begin + tensor_pos[tensor_idx]] = moveBatchDimToFront(value_, bdim);
+      continue;
+    }
+    TORCH_INTERNAL_ASSERT(logical_rank == feature_rank + 1);
+    (*stack)[args_begin + tensor_pos[tensor_idx]] = reshape_dim_into(*bdim, 0, value_);
+  }
+
+  op.callBoxed(stack);
+
+  for (const auto idx : c10::irange(args_begin, args_begin + num_returns)) {
+    const auto& ret = (*stack)[idx];
+    TORCH_INTERNAL_ASSERT(ret.isTensor(),
+        "This boxed batching rule does not currently support ops that return non-tensor values");
+    if (*is_no_batch_dim_case) {
+      (*stack)[idx] = makeBatched(ret.toTensor(), 0, cur_level);
+    } else {
+      (*stack)[idx] = makeBatched(reshape_dim_outof(0, batch_size, ret.toTensor()), 0, cur_level);
+    }
+  }
+}
+
+// Useful for many NN operators.
+// The operator must satisfy the following:
+// - All arguments must accept an optional batch dim.
+// - All arguments must be the same rank
+#define ALL_TENSORS_HAVE_OPTIONAL_BDIM_BOXED(feature_rank, op) \
+  m.impl(#op, torch::CppFunction::makeFromBoxedFunction<boxed_all_tensors_have_optional_bdim<feature_rank>>());
 template <typename A, A a, typename C>
 struct ExistingBdimBatchRuleHelper;

@@ -304,5 +369,29 @@ Tensor& unary_inplace_batch_rule(Tensor& self, optional<int64_t>, ExtraArgs... e
   return self;
 }

+inline int64_t get_bdim_size3(
+    const Tensor& a_value, optional<int64_t> a_bdim,
+    const Tensor& b_value, optional<int64_t> b_bdim,
+    const Tensor& c_value, optional<int64_t> c_bdim) {
+  if (a_bdim)
+    return a_value.size(*a_bdim);
+  if (b_bdim)
+    return b_value.size(*b_bdim);
+  if (c_bdim)
+    return c_value.size(*c_bdim);
+  TORCH_INTERNAL_ASSERT(false);
+}
+
+inline int64_t get_bdim_size2(
+    const Tensor& a_value, optional<int64_t> a_bdim,
+    const Tensor& b_value, optional<int64_t> b_bdim) {
+  if (a_bdim)
+    return a_value.size(*a_bdim);
+  if (b_bdim)
+    return b_value.size(*b_bdim);
+  TORCH_INTERNAL_ASSERT(false);
+}
+
+
}}
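
The boxed rule above keys off the feature_rank template parameter: if every tensor argument has logical rank equal to feature_rank, the inputs are treated as unbatched per example and the vmap dimension is simply moved to the front; if the rank is feature_rank + 1, the vmap dimension is folded into the existing batch dimension (dim 0) before the op runs and split back out of the result. Below is a rough Python sketch of the two rank cases, using the forward max_pool2d purely as an illustration (shapes are made up; unbatched pooling inputs are (C, H, W), which is why the registration later in this commit uses feature_rank = 3):

import torch
import torch.nn.functional as F
from functorch import vmap

# Case 1: per-example inputs are unbatched 3-D (C, H, W) tensors,
# i.e. logical rank == feature_rank (3); the rule only needs to move
# the vmap dim to the front.
x = torch.randn(5, 3, 8, 8)                   # 5 examples to vmap over
out = vmap(lambda t: F.max_pool2d(t, 2))(x)   # -> (5, 3, 4, 4)

# Case 2: per-example inputs already carry a batch dim, 4-D (N, C, H, W),
# i.e. logical rank == feature_rank + 1; the vmap dim is folded into N
# before the call and split back out of the result.
y = torch.randn(5, 2, 3, 8, 8)
out2 = vmap(lambda t: F.max_pool2d(t, 2))(y)  # -> (5, 2, 3, 4, 4)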

functorch/csrc/BatchRulesPooling.cpp

Lines changed: 1 addition & 52 deletions
@@ -11,57 +11,6 @@

 namespace at { namespace functorch {

-std::tuple<Tensor,int64_t> max_pool2d_with_indices_backward_batch_rule(
-    const Tensor & grad_output, optional<int64_t> grad_output_bdim,
-    const Tensor & self, optional<int64_t> self_bdim,
-    IntArrayRef kernel_size, IntArrayRef stride, IntArrayRef padding,
-    IntArrayRef dilation, bool ceil_mode,
-    const Tensor & indices, optional<int64_t> indices_bdim) {
-  TORCH_INTERNAL_ASSERT(grad_output_bdim && self_bdim && indices_bdim);
-
-  auto bdim_size = self.size(*self_bdim);
-  auto grad_output_ = reshape_dim_into(*grad_output_bdim, 0, grad_output);
-  auto self_ = reshape_dim_into(*self_bdim, 0, self);
-  auto indices_ = reshape_dim_into(*indices_bdim, 0, indices);
-
-  auto result = at::max_pool2d_with_indices_backward(
-      grad_output_, self_, kernel_size, stride, padding, dilation, ceil_mode,
-      indices_);
-
-  result = reshape_dim_outof(0, bdim_size, result);
-  return std::make_tuple(result, 0);
-}
-
-Tensor max_pool2d_with_indices_backward_plumbing(const Tensor & grad_output, const Tensor & self, IntArrayRef kernel_size, IntArrayRef stride, IntArrayRef padding, IntArrayRef dilation, bool ceil_mode, const Tensor & indices) {
-  auto maybe_layer = maybeCurrentDynamicLayer();
-  TORCH_INTERNAL_ASSERT(maybe_layer.has_value());
-  int64_t cur_level = maybe_layer->layerId();
-
-  Tensor grad_output_value;
-  optional<int64_t> grad_output_bdim;
-  std::tie(grad_output_value, grad_output_bdim) = unwrapTensorAtLevel(grad_output, cur_level);
-  Tensor self_value;
-  optional<int64_t> self_bdim;
-  std::tie(self_value, self_bdim) = unwrapTensorAtLevel(self, cur_level);
-  Tensor indices_value;
-  optional<int64_t> indices_bdim;
-  std::tie(indices_value, indices_bdim) = unwrapTensorAtLevel(indices, cur_level);
-
-  if (self_bdim && grad_output_bdim && indices_bdim) {
-    c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
-    auto result = max_pool2d_with_indices_backward_batch_rule(
-        grad_output_value, grad_output_bdim,
-        self_value, self_bdim,
-        kernel_size, stride, padding, dilation, ceil_mode,
-        indices_value, indices_bdim);
-    return makeBatched(std::get<0>(result), std::get<1>(result), cur_level);
-  }
-
-  static auto op = c10::Dispatcher::singleton()
-      .findSchemaOrThrow("aten::max_pool2d_with_indices_backward", "");
-  return slow_fallback<Tensor>(op, { grad_output, self, kernel_size, stride, padding, dilation, ceil_mode, indices });
-}
-
 std::tuple<Tensor,optional<int64_t>,Tensor,optional<int64_t>>
 max_pool2d_with_indices_batch_rule(
     const Tensor& self, optional<int64_t> self_bdim,
@@ -89,9 +38,9 @@ max_pool2d_with_indices_batch_rule(
 TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
   EXISTING_BDIM(_adaptive_avg_pool2d);
   EXISTING_BDIM(avg_pool2d);
-  m.impl("max_pool2d_with_indices_backward", max_pool2d_with_indices_backward_plumbing);
   EXISTING_BDIM_ALL_BOXED(avg_pool2d_backward);
   VMAP_SUPPORT("max_pool2d_with_indices", max_pool2d_with_indices_batch_rule);
+  ALL_TENSORS_HAVE_OPTIONAL_BDIM_BOXED(3, max_pool2d_with_indices_backward);
 }

}}
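
For reference, the xfail removed from test/test_ops.py below corresponds to this kind of use: vmapping a VJP of max_pool2d dispatches to max_pool2d_with_indices_backward, which now hits the boxed batch rule instead of the hand-written plumbing (and its slow fallback) deleted above. A minimal sketch with made-up shapes:

import torch
import torch.nn.functional as F
from functorch import vmap, vjp

x = torch.randn(4, 3, 8, 8)   # 4 unbatched (C, H, W) examples

def pool_vjp(inp):
    # The backward of max_pool2d goes through max_pool2d_with_indices_backward.
    out, vjp_fn = vjp(lambda t: F.max_pool2d(t, kernel_size=2), inp)
    (grad_inp,) = vjp_fn(torch.ones_like(out))
    return grad_inp

per_example_grads = vmap(pool_vjp)(x)   # -> (4, 3, 8, 8), same shape as x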

test/test_ops.py

Lines changed: 0 additions & 1 deletion
@@ -473,7 +473,6 @@ def test_vmapvjp(self, device, dtype, op):
     xfail('nn.functional.nll_loss'),
     xfail('block_diag'),
     xfail('nn.functional.dropout'),
-    xfail('nn.functional.max_pool2d'),
     xfail('nn.functional.batch_norm'),
 })
 def test_vmapvjp_has_batch_rule(self, device, dtype, op):

test/test_vmap.py

Lines changed: 0 additions & 1 deletion
@@ -3083,7 +3083,6 @@ def test_vmap_exhaustive(self, device, dtype, op):
     xfail('renorm'),
     xfail('repeat_interleave'),
     xfail('resize_as_'),
-    xfail('scatter'),
     xfail('take'),
     xfail('take_along_dim'),
     xfail('tensor_split'),
