
Commit 93c8575

Fix max_pool2d batch rule (#202)
1 parent 8a96f2b · commit 93c8575

File tree

2 files changed: +48 −2 lines


functorch/csrc/BatchRulesPooling.cpp

Lines changed: 48 additions & 1 deletion
@@ -11,6 +11,53 @@
 
 namespace at { namespace functorch {
 
+static Tensor reshape_bdim_into_front(
+    const Tensor& value,
+    optional<int64_t> bdim,
+    int64_t batch_size,
+    bool is_no_batch_dim_case) {
+  auto value_ = ensure_has_bdim(value, bdim.has_value(), batch_size);
+  if (!bdim.has_value()) {
+    bdim = 0;
+  }
+  if (is_no_batch_dim_case) {
+    return moveBatchDimToFront(value_, bdim);
+  }
+  return reshape_dim_into(*bdim, 0, value_);
+}
+
+// We can't use ALL_TENSORS_HAVE_OPTIONAL_BDIM_BOXED because the CUDA
+// kernel rightfully assumes that indices is contiguous.
+std::tuple<Tensor,optional<int64_t>> max_pool2d_with_indices_backward_batch_rule(
+    const Tensor& gradOutput, optional<int64_t> gradOutput_bdim,
+    const Tensor& input, optional<int64_t> input_bdim,
+    IntArrayRef kernel_size,
+    IntArrayRef stride,
+    IntArrayRef padding,
+    IntArrayRef dilation,
+    bool ceil_mode,
+    const Tensor& indices, optional<int64_t> indices_bdim) {
+  TORCH_INTERNAL_ASSERT(input_bdim.has_value() ^ !indices_bdim.has_value());
+  const auto bdim_size = get_bdim_size2(gradOutput, gradOutput_bdim, input, input_bdim);
+  const auto input_logical_rank = rankWithoutBatchDim(input, input_bdim);
+  bool chw_case = input_logical_rank == 3;
+
+  const auto gradOutput_ = reshape_bdim_into_front(gradOutput, gradOutput_bdim, bdim_size, chw_case);
+  const auto input_ = reshape_bdim_into_front(input, input_bdim, bdim_size, chw_case);
+  const auto indices_ = reshape_bdim_into_front(indices, indices_bdim, bdim_size, chw_case);
+
+  const auto result = at::max_pool2d_with_indices_backward(
+      gradOutput_, input_, kernel_size, stride, padding, dilation, ceil_mode,
+      // max_pool2d_with_indices rightfully assumes that indices is contiguous
+      indices_.contiguous());
+
+  if (chw_case) {
+    return std::make_tuple(std::move(result), 0);
+  } else {
+    return std::make_tuple(reshape_dim_outof(0, bdim_size, result), 0);
+  }
+}
+
 std::tuple<Tensor,optional<int64_t>,Tensor,optional<int64_t>>
 max_pool2d_with_indices_batch_rule(
     const Tensor& self, optional<int64_t> self_bdim,
@@ -40,7 +87,7 @@ TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
   EXISTING_BDIM(avg_pool2d);
   EXISTING_BDIM_ALL_BOXED(avg_pool2d_backward);
   VMAP_SUPPORT("max_pool2d_with_indices", max_pool2d_with_indices_batch_rule);
-  ALL_TENSORS_HAVE_OPTIONAL_BDIM_BOXED(3, max_pool2d_with_indices_backward);
+  VMAP_SUPPORT("max_pool2d_with_indices_backward", max_pool2d_with_indices_backward_batch_rule);
 }
 
 }}
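The gist of the fix, per the in-code comment: the CUDA kernel for max_pool2d_with_indices_backward assumes a contiguous indices tensor, so the generic ALL_TENSORS_HAVE_OPTIONAL_BDIM_BOXED registration is replaced with a dedicated batch rule that reshapes the vmap dimension to the front and passes indices_.contiguous() to the kernel. As a minimal sketch (not part of the commit; shapes and tensor names are illustrative), this is the kind of call that now routes through the new rule:

# Minimal sketch (not part of the commit): exercising the new backward
# batch rule through functorch's Python API.
import torch
import torch.nn.functional as F
from functorch import grad, vmap

def pooled_sum(x):
    # F.max_pool2d dispatches to max_pool2d_with_indices; differentiating
    # the scalar sum goes through max_pool2d_with_indices_backward.
    return F.max_pool2d(x, kernel_size=2).sum()

# "chw_case": each vmapped sample is a 3D (C, H, W) tensor, so the rule
# only moves the vmap dim to the front before calling the kernel.
per_sample_grads = vmap(grad(pooled_sum))(torch.randn(8, 3, 16, 16))

# 4D samples: the vmap dim is folded into the existing batch dim
# (reshape_bdim_into_front) and split back out of the result.
per_sample_grads_4d = vmap(grad(pooled_sum))(torch.randn(8, 2, 3, 16, 16))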

test/test_ops.py

Lines changed: 0 additions & 1 deletion
@@ -350,7 +350,6 @@ def vjp_of_vjp(*args_and_cotangents):
350350
xfail('block_diag'),
351351
xfail('nn.functional.dropout'),
352352
xfail('nn.functional.nll_loss'),
353-
xfail('nn.functional.max_pool2d', device_type='cuda'),
354353
}))
355354
def test_vmapvjp(self, device, dtype, op):
356355
# These are too annoying to put into the list above
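With the backward batch rule in place, the CUDA-specific xfail for nn.functional.max_pool2d is dropped, so test_vmapvjp now expects vmap-of-vjp for this op to pass on CUDA. A rough illustration of that composition (not the test code itself; shapes and helper names are made up):

import torch
import torch.nn.functional as F
from functorch import vjp, vmap

def pool(sample):
    return F.max_pool2d(sample, kernel_size=2)

def vjp_of_pool(sample):
    # Pull back a ones cotangent through max_pool2d for a single sample.
    out, vjp_fn = vjp(pool, sample)
    (grad_sample,) = vjp_fn(torch.ones_like(out))
    return grad_sample

# vmap over the vjp calls max_pool2d_with_indices_backward on batched
# tensors, the case that was previously marked as failing on CUDA.
x = torch.randn(4, 3, 8, 8)
per_sample_grads = vmap(vjp_of_pool)(x)
assert per_sample_grads.shape == x.shape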
