@@ -11,6 +11,53 @@
 
 namespace at { namespace functorch {
 
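+// Descriptive note (not in the original patch): this helper moves (or
+// materializes) the batch dim of `value` to the front. In the no-batch-dim
+// (CHW) case the batch dim simply becomes dim 0; otherwise it is flattened
+// into dim 0 together with the existing leading dimension.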
+static Tensor reshape_bdim_into_front(
+    const Tensor& value,
+    optional<int64_t> bdim,
+    int64_t batch_size,
+    bool is_no_batch_dim_case) {
+  auto value_ = ensure_has_bdim(value, bdim.has_value(), batch_size);
+  if (!bdim.has_value()) {
+    bdim = 0;
+  }
+  if (is_no_batch_dim_case) {
+    return moveBatchDimToFront(value_, bdim);
+  }
+  return reshape_dim_into(*bdim, 0, value_);
+}
+
+// We can't use ALL_TENSORS_HAVE_OPTIONAL_BDIM_BOXED because the CUDA
+// kernel rightfully assumes that indices is contiguous.
+std::tuple<Tensor,optional<int64_t>> max_pool2d_with_indices_backward_batch_rule(
+    const Tensor& gradOutput, optional<int64_t> gradOutput_bdim,
+    const Tensor& input, optional<int64_t> input_bdim,
+    IntArrayRef kernel_size,
+    IntArrayRef stride,
+    IntArrayRef padding,
+    IntArrayRef dilation,
+    bool ceil_mode,
+    const Tensor& indices, optional<int64_t> indices_bdim) {
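+  // Editorial note: the XOR below asserts that `input` and `indices` either
+  // both have a batch dim or both lack one.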
+  TORCH_INTERNAL_ASSERT(input_bdim.has_value() ^ !indices_bdim.has_value());
+  const auto bdim_size = get_bdim_size2(gradOutput, gradOutput_bdim, input, input_bdim);
+  const auto input_logical_rank = rankWithoutBatchDim(input, input_bdim);
+  bool chw_case = input_logical_rank == 3;
+
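+  // Editorial note: fold the vmap batch dim into the leading dim so the kernel
+  // sees an ordinary 4-D problem: NCHW inputs become (B*N)CHW, while CHW
+  // inputs simply become BCHW.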
+  const auto gradOutput_ = reshape_bdim_into_front(gradOutput, gradOutput_bdim, bdim_size, chw_case);
+  const auto input_ = reshape_bdim_into_front(input, input_bdim, bdim_size, chw_case);
+  const auto indices_ = reshape_bdim_into_front(indices, indices_bdim, bdim_size, chw_case);
+
+  const auto result = at::max_pool2d_with_indices_backward(
+      gradOutput_, input_, kernel_size, stride, padding, dilation, ceil_mode,
+      // max_pool2d_with_indices rightfully assumes that indices is contiguous
+      indices_.contiguous());
+
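+  // Editorial note: in the CHW case the result is already BCHW with the batch
+  // dim at the front; otherwise split the flattened (B*N) dim back out of dim 0.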
+  if (chw_case) {
+    return std::make_tuple(std::move(result), 0);
+  } else {
+    return std::make_tuple(reshape_dim_outof(0, bdim_size, result), 0);
+  }
+}
+
 std::tuple<Tensor,optional<int64_t>,Tensor,optional<int64_t>>
 max_pool2d_with_indices_batch_rule(
     const Tensor& self, optional<int64_t> self_bdim,
@@ -40,7 +87,7 @@ TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
   EXISTING_BDIM(avg_pool2d);
   EXISTING_BDIM_ALL_BOXED(avg_pool2d_backward);
   VMAP_SUPPORT("max_pool2d_with_indices", max_pool2d_with_indices_batch_rule);
-  ALL_TENSORS_HAVE_OPTIONAL_BDIM_BOXED(3, max_pool2d_with_indices_backward);
+  VMAP_SUPPORT("max_pool2d_with_indices_backward", max_pool2d_with_indices_backward_batch_rule);
 }
 
 }}