// Copyright (c) Facebook, Inc. and its affiliates.
// All rights reserved.
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include <functorch/csrc/BatchRulesHelper.h>
#include <functorch/csrc/PlumbingHelper.h>
#include <ATen/core/dispatch/Dispatcher.h>

namespace at { namespace functorch {

// convolution_batch_rule translated from jax with modifications:
// https://github.com/google/jax/blob/master/jax/_src/lax/lax.py#L3143

// PyTorch's convolution is different from JAX's conv_general_dilated:
// we do not support batch_group_count (which is needed for convolution backwards).
// Instead, there's a convolution_backward op that needs a batching rule.
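//
// Illustrative sketch (not part of this file) of the contract the rule below
// has to satisfy when vmapping over a batch of weights:
//   vmap(lambda w: F.conv2d(x, w))(weights)   # weights: (B, C_out, C_in, kH, kW)
// should match
//   torch.stack([F.conv2d(x, w) for w in weights])
// and analogously for a batched input or for batching over both arguments.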
std::tuple<Tensor,optional<int64_t>>
convolution_batch_rule(
    const Tensor& lhs, optional<int64_t> lhs_bdim,
    const Tensor& rhs, optional<int64_t> rhs_bdim,
    const optional<Tensor>& bias, optional<int64_t> bias_bdim,
    IntArrayRef stride, IntArrayRef padding, IntArrayRef dilation,
    bool transposed, IntArrayRef output_padding, int64_t groups) {
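  // The "spec" vectors map roles to physical dims: spec[0] is the batch (N)
  // dim of the input/output and the output-channel dim of the weight, spec[1]
  // is the channel dim, and the remaining entries are spatial dims. A
  // transposed convolution's weight is laid out as
  // (in_channels, out_channels / groups, *kernel), so its first two entries
  // are swapped below.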
  DimVector lhs_spec(stride.size() + 2);
  std::iota(lhs_spec.begin(), lhs_spec.end(), 0);
  DimVector rhs_spec = lhs_spec;
  DimVector out_spec = lhs_spec;
  if (transposed) {
    rhs_spec[0] = 1;
    rhs_spec[1] = 0;
  }

  // If the bias is batched, or the weight is batched while a bias is present,
  // the bias cannot be folded into the convolution call below; compute the
  // convolution without a bias and add the bias separately at the end.
  optional<Tensor> unbatched_bias;
  bool separate_bias;
  if ((rhs_bdim && bias && bias->defined()) || bias_bdim) {
    TORCH_INTERNAL_ASSERT(bias.has_value());
    TORCH_INTERNAL_ASSERT(bias->defined());
    unbatched_bias = nullopt;
    separate_bias = true;
  } else {
    unbatched_bias = bias;
    separate_bias = false;
  }
  std::tuple<Tensor, optional<int64_t>> result;
  if (lhs_bdim && !rhs_bdim) {
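    // Only the input is batched: fold the batch dim into the input's N dim,
    // run a single convolution, then split the batch dim back out of the
    // output's N dim.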
    auto new_x = reshape_dim_into(*lhs_bdim, lhs_spec[0], lhs);
    auto out = at::convolution(new_x, rhs, unbatched_bias, stride, padding, dilation, transposed, output_padding, groups);
    out = reshape_dim_outof(out_spec[0], lhs.sizes()[*lhs_bdim], out);
    result = std::make_tuple(out, out_spec[0]);
  } else if (!lhs_bdim && rhs_bdim) {
    if (groups == 1) {
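      // groups == 1: fold the batch dim into the weight's output-channel dim;
      // the output channels then carry the batch, which is split back out of
      // the output afterwards.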
      auto new_w = reshape_dim_into(*rhs_bdim, rhs_spec[0], rhs);
      auto out = at::convolution(lhs, new_w, unbatched_bias, stride, padding, dilation, transposed, output_padding, groups);
      out = reshape_dim_outof(out_spec[1], rhs.sizes()[*rhs_bdim], out);
      result = std::make_tuple(out, out_spec[1]);
    } else {
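      // groups > 1: the weight's group structure has to be preserved. Split
      // its grouped channel dim into (groups, channels_per_group), fold the
      // batch dim into the per-group channels, and flatten back so a single
      // grouped convolution handles every batch element; the batch dim is
      // then pulled back out of the output channels.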
      auto dim_with_groups = transposed ? 1 : 0;
      auto new_w = reshape_dim_outof(rhs_spec[dim_with_groups] + (*rhs_bdim <= rhs_spec[0]), groups, rhs);
      new_w = reshape_dim_into(*rhs_bdim + (rhs_spec[0] < *rhs_bdim), rhs_spec[0] + 1, new_w);
      new_w = reshape_dim_into(rhs_spec[0], rhs_spec[0], new_w);
      auto out = at::convolution(lhs, new_w, unbatched_bias, stride, padding, dilation, transposed, output_padding, groups);
      out = reshape_dim_outof(out_spec[1], groups, out);
      out = reshape_dim_outof(out_spec[1] + 1, rhs.sizes()[*rhs_bdim], out);
      out = reshape_dim_into(out_spec[1], out_spec[1] + 1, out);
      result = std::make_tuple(out, out_spec[1]);
    }
  } else if (lhs_bdim && rhs_bdim) {
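    // Both the input and the weight are batched: fold the batch dim into the
    // channel dims and multiply groups by the batch size, so each batch
    // element becomes its own set of groups and a single grouped convolution
    // pairs every input slice with its own weight slice; the batch dim is
    // split back out of the output channels at the end.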
    auto new_x = reshape_dim_into(*lhs_bdim, lhs_spec[1], lhs);
    groups *= lhs.sizes()[*lhs_bdim];
    auto dim_with_groups = transposed ? 1 : 0;
    auto new_w = reshape_dim_into(*rhs_bdim, rhs_spec[dim_with_groups], rhs);
    auto out = at::convolution(new_x, new_w, unbatched_bias, stride, padding, dilation, transposed, output_padding, groups);
    out = reshape_dim_outof(out_spec[1], lhs.sizes()[*lhs_bdim], out);
    result = std::make_tuple(out, out_spec[1]);
  } else {
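    // Neither the input nor the weight is batched: plain convolution.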
    result = std::make_tuple(at::convolution(lhs, rhs, unbatched_bias, stride, padding, dilation, transposed, output_padding, groups), nullopt);
  }
  if (separate_bias) {
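    // The bias was held out of the convolution above (it is batched, or the
    // weight is batched while a bias is present): broadcast it over the
    // spatial dims and add it here. The sum is batched with its bdim at the
    // front.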
    auto A = std::get<0>(result);
    auto A_batch_dim = std::get<1>(result);
    auto B = *bias;
    auto B_batch_dim = bias_bdim;
    A = moveBatchDimToFront(A, A_batch_dim);
    B = moveBatchDimToFront(B, B_batch_dim);
    for (size_t i = 0; i < out_spec.size() - 2; i++) {
      B = B.unsqueeze(-1);
    }
    B = maybePadToLogicalRank(B, B_batch_dim, rankWithoutBatchDim(A, A_batch_dim));

    return std::make_tuple(at::add(A, B), 0);
  } else {
    return result;
  }
}

Tensor _convolution_decomp(
    const Tensor& input_r, const Tensor& weight_r, const c10::optional<Tensor>& bias_r_opt,
    IntArrayRef stride_, IntArrayRef padding_, IntArrayRef dilation_,
    bool transposed_, IntArrayRef output_padding_, int64_t groups_,
    bool benchmark, bool deterministic, bool cudnn_enabled, bool allow_tf32) {
  // Ignore the backend-selection flags; at::convolution reads them from the
  // global context, so a call made through the normal path behaves the same.
  (void) benchmark;
  (void) deterministic;
  (void) cudnn_enabled;
  (void) allow_tf32;
  return at::convolution(
      input_r, weight_r, bias_r_opt, stride_, padding_, dilation_, transposed_, output_padding_, groups_);
}

// TODO: delete the following after confirming performance
// bool first_dim_has_size_1(const Tensor& value, int64_t bdim) {
//   if (bdim == 0) {
//     return value.size(1) == 1;
//   }
//   return value.size(0) == 1;
// }
//
// std::tuple<Tensor,int64_t,Tensor,int64_t> cudnn_conv_per_sample_grad_rule(
//     const Tensor& self, optional<int64_t> self_bdim,
//     const Tensor& grad_output, optional<int64_t> grad_output_bdim,
//     const Tensor& weight, optional<int64_t> weight_bdim,
//     IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, bool benchmark,
//     bool deterministic, bool allow_tf32, std::array<bool, 2> output_mask) {
//   TORCH_INTERNAL_ASSERT(self_bdim && grad_output_bdim && !weight_bdim);
//   // TODO: No clue if this works if the first non-batch dim isn't size 1
//   TORCH_INTERNAL_ASSERT(first_dim_has_size_1(self, *self_bdim));
//   TORCH_INTERNAL_ASSERT(self.dim() == 5);
//
//   auto bdim_size = self.size(*self_bdim);
//   auto self_ = reshape_dim_into(*self_bdim, 0, self);
//   auto in_channels = self_.size(1);
//   auto grad_output_ = reshape_dim_into(*grad_output_bdim, 0, grad_output);
//
//   auto grad_self = at::cudnn_convolution_backward_input(
//       self_.sizes(), grad_output_, weight,
//       padding, stride, dilation, groups, benchmark, deterministic, allow_tf32);
//   grad_self = reshape_dim_outof(0, bdim_size, grad_self);
//
//   // Copied from https://github.com/pytorch/opacus/blob/master/opacus/grad_sample/conv.py
//   auto A = at::im2col(self_, {weight.size(2), weight.size(3)}, dilation, padding, stride);
//   auto B = grad_output_.reshape({bdim_size, -1, A.size(-1)});
//   auto grad_sample = at::einsum("noq,npq->nop", {B, A});
//   grad_sample = grad_sample.view({
//       bdim_size, groups, -1, groups, in_channels / groups,
//       weight.size(2) * weight.size(3) });
//   grad_sample = at::einsum("ngrg...->ngr...", {grad_sample});
//   grad_sample = grad_sample.reshape(
//       {bdim_size, weight.size(0), weight.size(1), weight.size(2), weight.size(3)});
//
//   return std::make_tuple(grad_self, 0, grad_sample, 0);
// }
//
// std::tuple<Tensor,Tensor> cudnn_convolution_backward_plumbing(const Tensor & self, const Tensor & grad_output, const Tensor & weight, IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, bool benchmark, bool deterministic, bool allow_tf32, std::array<bool, 2> output_mask) {
//   auto maybe_layer = maybeCurrentDynamicLayer();
//   TORCH_INTERNAL_ASSERT(maybe_layer.has_value());
//   int64_t cur_level = maybe_layer->layerId();
//
//   Tensor self_value;
//   optional<int64_t> self_bdim;
//   std::tie(self_value, self_bdim) = unwrapTensorAtLevel(self, cur_level);
//   Tensor grad_output_value;
//   optional<int64_t> grad_output_bdim;
//   std::tie(grad_output_value, grad_output_bdim) = unwrapTensorAtLevel(grad_output, cur_level);
//   Tensor weight_value;
//   optional<int64_t> weight_bdim;
//   std::tie(weight_value, weight_bdim) = unwrapTensorAtLevel(weight, cur_level);
//
//   if (self_bdim.has_value() && self_value.dim() == 5 && first_dim_has_size_1(self_value, *self_bdim) && grad_output_bdim.has_value() && !weight_bdim.has_value()) {
//     c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
//     auto result = cudnn_conv_per_sample_grad_rule(
//         self_value, self_bdim,
//         grad_output_value, grad_output_bdim,
//         weight_value, weight_bdim,
//         padding, stride, dilation, groups,
//         benchmark, deterministic, allow_tf32, output_mask);
//     return std::make_tuple(
//         makeBatched(std::get<0>(result), std::get<1>(result), cur_level),
//         makeBatched(std::get<2>(result), std::get<3>(result), cur_level));
//   }
//
//   static auto op = c10::Dispatcher::singleton()
//     .findSchemaOrThrow("aten::cudnn_convolution_backward", "");
//   return slow_fallback<Tensor,Tensor>(op, { self, grad_output, weight, padding, stride, dilation, groups, benchmark, deterministic, allow_tf32, output_mask });
// }

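// Register the batching rule for convolution; _convolution is decomposed into
// at::convolution so that batched calls route through the same rule.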
TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
  VMAP_SUPPORT("convolution", convolution_batch_rule);
  m.impl("_convolution", _convolution_decomp);
}

}} // namespace at::functorch