
Commit ea180d2

Fix batch rule for max_pool2d

1 parent: c73dd32

4 files changed: +24 −34 lines

functorch/csrc/BatchRulesPooling.cpp
Lines changed: 24 additions & 0 deletions

@@ -62,12 +62,36 @@ Tensor max_pool2d_with_indices_backward_plumbing(const Tensor & grad_output, con
   return slow_fallback<Tensor>(op, { grad_output, self, kernel_size, stride, padding, dilation, ceil_mode, indices });
 }
 
+std::tuple<Tensor,optional<int64_t>,Tensor,optional<int64_t>>
+max_pool2d_with_indices_batch_rule(
+    const Tensor& self, optional<int64_t> self_bdim,
+    IntArrayRef kernel_size, IntArrayRef stride,
+    IntArrayRef padding, IntArrayRef dilation, bool ceil_mode) {
+  auto logical_rank = rankWithoutBatchDim(self, self_bdim);
+  TORCH_INTERNAL_ASSERT(logical_rank == 3 || logical_rank == 4);
+  // Logical Tensor[C, H, W]: batch dim to front gives [B, C, H, W], which pools directly
+  if (logical_rank == 3) {
+    auto self_ = moveBatchDimToFront(self, self_bdim);
+    auto result = at::max_pool2d_with_indices(
+        self_, kernel_size, stride, padding, dilation, ceil_mode);
+    return std::make_tuple(std::move(std::get<0>(result)), 0, std::move(std::get<1>(result)), 0);
+  }
+  // Logical Tensor[N, C, H, W]: fold Tensor[B, N, C, H, W] -> Tensor[B * N, C, H, W], pool, split back out
+  auto bdim_size = self.size(*self_bdim);
+  auto self_ = reshape_dim_into(*self_bdim, 0, self);
+  auto result = at::max_pool2d_with_indices(
+      self_, kernel_size, stride, padding, dilation, ceil_mode);
+  return std::make_tuple(
+      reshape_dim_outof(0, bdim_size, std::get<0>(result)), 0,
+      reshape_dim_outof(0, bdim_size, std::get<1>(result)), 0);
+}
 
 TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
   EXISTING_BDIM(_adaptive_avg_pool2d);
   EXISTING_BDIM(avg_pool2d);
   m.impl("max_pool2d_with_indices_backward", max_pool2d_with_indices_backward_plumbing);
   EXISTING_BDIM_ALL_BOXED(avg_pool2d_backward);
+  VMAP_SUPPORT("max_pool2d_with_indices", max_pool2d_with_indices_batch_rule);
 }
 
 }}
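For context (not part of the commit): a minimal sketch of what this batch rule enables from the Python side, assuming the functorch-era vmap import. F.max_pool2d(..., return_indices=True) dispatches to max_pool2d_with_indices, so a vmapped call now hits the rule above.

    import torch
    import torch.nn.functional as F
    from functorch import vmap

    # Each example is a logical [C, H, W] = [3, 8, 8] tensor (the rank-3 branch
    # above); vmap supplies the batch dimension.
    x = torch.randn(4, 3, 8, 8)
    pool = lambda t: F.max_pool2d(t, kernel_size=2, return_indices=True)
    out, indices = vmap(pool)(x)
    print(out.shape)      # torch.Size([4, 3, 4, 4])
    print(indices.shape)  # torch.Size([4, 3, 4, 4])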

functorch/csrc/BatchingRegistrations.cpp
Lines changed: 0 additions & 30 deletions

@@ -94,33 +94,6 @@ static bool participatesInCurrentLevel(TensorList self) {
   return false;
 }
 
-std::tuple<Tensor,Tensor> max_pool2d_with_indices_batching_rule(
-    const Tensor & self, IntArrayRef kernel_size, IntArrayRef stride,
-    IntArrayRef padding, IntArrayRef dilation, bool ceil_mode) {
-  if (!participatesInCurrentLevel(self)) {
-    c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
-    return at::max_pool2d_with_indices(
-        self, kernel_size, stride, padding, dilation, ceil_mode);
-  }
-  auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
-  TORCH_INTERNAL_ASSERT(self_physical.tensor().dim() == 5);
-
-  auto N = self_physical.tensor().size(0);
-  auto M = self_physical.tensor().size(1);
-  auto physical = self_physical.tensor().flatten(0, 1);
-
-  auto result = max_pool2d_with_indices_batching_rule(physical,
-      kernel_size, stride, padding, dilation, ceil_mode);
-
-  auto first = std::get<0>(result).unflatten(0, {N, M});
-  auto second = std::get<1>(result).unflatten(0, {N, M});
-
-  first = self_physical.getPhysicalToLogicalMap().apply(first);
-  second = self_physical.getPhysicalToLogicalMap().apply(second);
-  return std::make_tuple<Tensor, Tensor>(std::move(first), std::move(second));
-}
-
-
 bool isPhysicalScalarTensor(const Tensor& logical_tensor) {
   if (logical_tensor.dim() > 0) {
     return false;

@@ -906,9 +879,6 @@ TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
   // m.impl("_add_batch_dim", native::_add_batch_dim);
   // m.impl("_remove_batch_dim", native::_remove_batch_dim);
 
-  m.impl("max_pool2d", at::native::max_pool2d); // composite
-  m.impl("max_pool2d_with_indices", max_pool2d_with_indices_batching_rule);
-
   m.impl("is_complex", native::is_complex);
   //
   // // inplace operations
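The deleted fallback above did the flatten/unflatten bookkeeping by hand (flatten(0, 1) plus unflatten); the new rule delegates it to the reshape_dim_into / reshape_dim_outof helpers. As a rough illustration of what those helpers do, here are hypothetical Python analogues (names mirror the C++; this is not the actual implementation):

    import torch

    # Hypothetical Python analogues of the C++ helpers, for illustration only.
    def reshape_dim_into(src, dst, t):
        # Move dim `src` next to `dst` and fold the two into one dimension.
        t = t.movedim(src, dst)
        shape = list(t.shape)
        shape[dst:dst + 2] = [shape[dst] * shape[dst + 1]]
        return t.reshape(shape)

    def reshape_dim_outof(dim, size, t):
        # Split dim `dim` back out into [size, remainder].
        shape = list(t.shape)
        shape[dim:dim + 1] = [size, shape[dim] // size]
        return t.reshape(shape)

    x = torch.randn(2, 5, 3, 8, 8)        # [B, N, C, H, W]
    flat = reshape_dim_into(0, 0, x)      # [B * N, C, H, W] = [10, 3, 8, 8]
    back = reshape_dim_outof(0, 2, flat)  # back to [2, 5, 3, 8, 8]
    assert back.shape == x.shape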

test/test_ops.py
Lines changed: 0 additions & 2 deletions

@@ -350,7 +350,6 @@ def vjp_of_vjp(*args_and_cotangents):
     xfail('nanmean'),
     xfail('block_diag'),
     xfail('nn.functional.dropout'),
-    xfail('nn.functional.max_pool2d'),
     xfail('nn.functional.nll_loss'),
 }))
 def test_vmapvjp(self, device, dtype, op):

@@ -523,7 +522,6 @@ def test():
     xfail('nanmean'),
     xfail('vstack'),
     xfail('block_diag'),
-    xfail('nn.functional.max_pool2d'),
     xfail('nn.functional.batch_norm'),
     xfail('nn.functional.nll_loss'),
 }))
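With these xfails gone, test_vmapvjp now expects vmap-of-vjp compositions over max_pool2d to work. A sketch of that kind of composition, assuming functorch's vjp and vmap APIs:

    import torch
    import torch.nn.functional as F
    from functorch import vmap, vjp

    x = torch.randn(4, 3, 8, 8)
    cotangents = torch.ones(4, 3, 4, 4)
    pool = lambda t: F.max_pool2d(t, kernel_size=2)

    # Per-example VJP of pooling, batched with vmap in a single call.
    def per_example_grad(xi, ct):
        _, vjp_fn = vjp(pool, xi)
        return vjp_fn(ct)[0]

    grads = vmap(per_example_grad)(x, cotangents)
    print(grads.shape)  # torch.Size([4, 3, 8, 8])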

test/test_vmap.py
Lines changed: 0 additions & 2 deletions

@@ -3007,7 +3007,6 @@ class TestVmapOperatorsOpInfo(TestCase):
     xfail('svd', device_type='cuda'),
     xfail('linalg.svd', device_type='cuda'),
     xfail('index_put'),
-    xfail('nn.functional.max_pool2d'),
     xfail('nn.functional.batch_norm'),
     xfail('nn.functional.nll_loss'),
 })

@@ -3105,7 +3104,6 @@ def test_vmap_exhaustive(self, device, dtype, op):
     xfail('vstack'),
     xfail('block_diag'),
     xfail('nn.functional.dropout'),
-    xfail('nn.functional.max_pool2d'),
     xfail('nn.functional.conv2d', ''),
     xfail('nn.functional.batch_norm'),
 })
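The exhaustive vmap tests compare a vmapped op against a per-example loop; roughly, the property now expected to hold for max_pool2d looks like this (sketch, assuming functorch's vmap):

    import torch
    import torch.nn.functional as F
    from functorch import vmap

    x = torch.randn(5, 3, 8, 8)
    pool = lambda t: F.max_pool2d(t, kernel_size=2)

    batched = vmap(pool)(x)                     # one batched call
    looped = torch.stack([pool(t) for t in x])  # reference per-example loop
    assert torch.allclose(batched, looped)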
