// Copyright (c) Facebook, Inc. and its affiliates.
// All rights reserved.
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include <functorch/csrc/BatchRulesHelper.h>
#include <functorch/csrc/PlumbingHelper.h>
#include <ATen/Operators.h>

// NB: most activation functions fit the pointwise unary or binary batch rules.
// This file only contains the ones that need a special batch rule; they live here to help with organization.
namespace at { namespace functorch {
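// prelu(x, w) = max(0, x) + w * min(0, x), where w is either a single scalar or has one value
// per channel of x (dim 1 of the logical input). The batch rule below reuses that decomposition:
// it reshapes the (possibly batched) weight so it broadcasts against the batched input and then
// applies the pointwise ops directly.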
std::tuple<Tensor,optional<int64_t>> prelu_batch_rule(
    const Tensor& input, optional<int64_t> input_bdim,
    const Tensor& weight, optional<int64_t> weight_bdim) {
  if (!weight_bdim && weight.dim() == 0) {
    return std::make_tuple(at::prelu(input, weight), input_bdim);
  }

  const auto input_ = moveBatchDimToFront(input, input_bdim);
  auto weight_flatten = moveBatchDimToFront(weight, weight_bdim);

  if (weight_flatten.dim() > 1) {
    // for an input of shape [N, C, ...], weight can be a non-vector,
    // but its total number of elements must equal C
    weight_flatten = at::flatten(weight_flatten, weight_bdim.has_value() ? 1 : 0, -1);
  }

  const int64_t input_logical_rank = rankWithoutBatchDim(input, input_bdim);
  VmapDimVector new_shape(weight_flatten.sizes().begin(), weight_flatten.sizes().end());
  const int64_t final_size = weight_bdim ? (input_logical_rank + 1) : input_logical_rank;
  new_shape.reserve(final_size);

  if (weight_flatten.dim() == 2 || !weight_bdim) {
    // if weight (without batching) is not a scalar, its size must match the "channel dimension"
    // of input. To do the decomposition, we pad the weight with ones so that it lines up with
    // that channel dimension.

    // copies the checks from prelu for the case where the weight (without vmap) is not a scalar
    TORCH_CHECK(input_logical_rank > 0, "Not allow zero-dim input tensor.");

    int64_t channel_size = 1; // channel_size defaults to 1
    if (input_logical_rank > 1) {
      const auto channel_dim = input_bdim ? 2 : 1;
      channel_size = input_.size(channel_dim);
    }
    const auto weight_num = weight_flatten.size(-1);
    TORCH_CHECK(channel_size == weight_num,
      "Mismatch of parameter numbers and input channel size. Found parameter numbers = ", weight_num,
      " and channel size = ", channel_size, ".");

    // pads on the left so that the flattened weight lines up with the channel dimension of input
    if (!weight_bdim) {
      new_shape.insert(new_shape.begin(), 1);
    } else {
      new_shape.insert(new_shape.begin() + 1, 1);
    }
  }

  for (int64_t i = new_shape.size(); i < final_size; i++) {
    new_shape.push_back(1);
  }
  TORCH_INTERNAL_ASSERT(new_shape.size() == final_size);
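  // e.g. (illustrative shapes): logical input [N, C, H, W] batched as input_ = [B, N, C, H, W]
  // with logical weight [C] batched as weight_flatten = [B, C] gives new_shape = [B, 1, C, 1, 1],
  // which lines the channel entries of weight_padded up with dim 2 of input_ for broadcasting.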
  const auto weight_padded = weight_flatten.view(new_shape);
  auto zero_tensor = at::zeros(1, input.options());

  // decomposition of prelu: prelu(x, w) = max(0, x) + w * min(0, x)
  auto res = at::maximum(zero_tensor, input_) + weight_padded * at::minimum(zero_tensor, input_);
  return std::make_tuple(res, 0);
}

VmapDimVector ensure_shape_with_bdim(const Tensor& input, const bool has_bdim, const int64_t batch_size) {
  // helper function that gets the size of input, ensuring that it has a batch dim, without expanding input
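  // e.g. with batch_size B: an input of shape [C] and has_bdim == false yields [B, C];
  // if has_bdim is true, the input's own sizes are returned unchanged.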
  if (has_bdim) {
    // sad to have to copy, but we got garbage if we tried to return an IntArrayRef straight from input.sizes()
    VmapDimVector new_shape(input.sizes().begin(), input.sizes().end());
    return new_shape;
  }
  VmapDimVector new_shape(1, batch_size);
  new_shape.reserve(input.dim() + 1);
  new_shape.insert(new_shape.end(), input.sizes().begin(), input.sizes().end());
  return new_shape;
}

VmapDimVector shape_maybe_with_bdim(const Tensor& input, const bool need_bdim, const bool has_bdim, const int64_t batch_size) {
  // if need_bdim, returns the input's shape with a guaranteed batch dim; otherwise returns the input's logical shape (no batch dim)
  if (need_bdim) {
    return ensure_shape_with_bdim(input, has_bdim, batch_size);
  } else if (has_bdim) { // !need_bdim && has_bdim
    VmapDimVector new_shape(input.sizes().begin() + 1, input.sizes().end());
    return new_shape;
  } else { // !need_bdim && !has_bdim
    VmapDimVector new_shape(input.sizes().begin(), input.sizes().end());
    return new_shape;
  }
}

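// The gradients computed below follow directly from the prelu decomposition above:
//   d prelu(x, w) / dx = 1 if x > 0 else w  ->  where(x > 0, grad_out, w * grad_out)
//   d prelu(x, w) / dw = 0 if x > 0 else x  ->  where(x > 0, 0, x * grad_out)
// Each collector is then reduced with sum_to_size to the requested (padded) gradient shape.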
std::tuple<Tensor, Tensor> prelu_backward_batched(
    const Tensor& grad_out, const Tensor& self, const Tensor& weight,
    const VmapDimVector& self_grad_shape, const VmapDimVector& weight_grad_padded_shape, const VmapDimVector& weight_grad_shape) {
  // helper function that produces a batched gradient for prelu using a decomposition inspired by the AOTAutograd ones
  const auto input_grad_collector = at::where(self > 0, grad_out, weight * grad_out);
  const auto input_grad = native::sum_to_size(input_grad_collector, self_grad_shape);
  const auto weight_grad_collector = at::where(self > 0, at::zeros(1, self.options()), self * grad_out);
  const auto weight_grad_collector_2 = native::sum_to_size(weight_grad_collector, weight_grad_padded_shape);
  const auto weight_grad = weight_grad_collector_2.view(weight_grad_shape);
  return std::make_tuple(input_grad, weight_grad);
}

std::tuple<Tensor,optional<int64_t>,Tensor,optional<int64_t>> prelu_backward_batch_rule(
    const Tensor& grad_out, optional<int64_t> grad_out_bdim,
    const Tensor& self, optional<int64_t> self_bdim,
    const Tensor& weight, optional<int64_t> weight_bdim) {
  const auto batch_size = get_bdim_size3(grad_out, grad_out_bdim, self, self_bdim, weight, weight_bdim);
  const auto grad_out_ = moveBatchDimToFront(grad_out, grad_out_bdim);
  const auto self_ = moveBatchDimToFront(self, self_bdim);
  const auto self_size_with_bdim = ensure_shape_with_bdim(self_, self_bdim.has_value(), batch_size);
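  // Case: the logical weight is an unbatched 0-dim scalar. The batched weight grad is then one
  // scalar per batch element, so its final shape is [batch_size], and the padded shape that
  // sum_to_size reduces to is [batch_size, 1, ..., 1] (one 1 per logical dim of self).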
  if (!weight_bdim && weight.dim() == 0) {
    VmapDimVector weight_grad_shape(1, batch_size);
    VmapDimVector weight_grad_shape_padded(self_bdim.has_value() ? self.dim() : self.dim() + 1, 1);
    weight_grad_shape_padded[0] = batch_size;
    const auto grads = prelu_backward_batched(grad_out_, self_, weight, self_size_with_bdim, weight_grad_shape_padded, weight_grad_shape);
    return std::make_tuple(std::get<0>(grads), 0, std::get<1>(grads), 0);
  }
  const auto weight_ = moveBatchDimToFront(weight, weight_bdim);
  auto weight_flatten = weight_;
  if (weight_flatten.dim() > 1) {
    // for an input of shape [N, C, ...], weight can be a non-vector,
    // but its total number of elements must equal C
    weight_flatten = at::flatten(weight_flatten, weight_bdim.has_value() ? 1 : 0, -1);
  }

  const int64_t self_logical_rank = rankWithoutBatchDim(self, self_bdim);
  VmapDimVector new_shape(weight_flatten.sizes().begin(), weight_flatten.sizes().end());
  const int64_t final_size = weight_bdim ? (self_logical_rank + 1) : self_logical_rank;
  new_shape.reserve(final_size);

  if (weight_flatten.dim() == 2 || !weight_bdim) {
    // if weight (without batching) is not a scalar, its size must match the "channel dimension"
    // of input. To do the decomposition, we pad the weight with ones so that it lines up with
    // that channel dimension.

    // copies the checks from prelu for the case where the weight (without vmap) is not a scalar
    TORCH_CHECK(self_logical_rank > 0, "Not allow zero-dim input tensor.");

    int64_t channel_size = 1; // channel_size defaults to 1
    if (self_logical_rank > 1) {
      channel_size = self_.size(self_bdim.has_value() ? 2 : 1);
    }

    const auto weight_num = weight_flatten.size(-1);
    TORCH_CHECK(channel_size == weight_num,
      "Mismatch of parameter numbers and input channel size. Found parameter numbers = ", weight_num,
      " and channel size = ", channel_size, ".");

    // pads on the left so that the flattened weight lines up with the channel dimension of input
    if (!weight_bdim) {
      new_shape.insert(new_shape.begin(), 1);
    } else {
      new_shape.insert(new_shape.begin() + 1, 1);
    }
  }

  for (int64_t i = new_shape.size(); i < final_size; i++) {
    new_shape.push_back(1);
  }
  // weight grad does not depend on the weight values. It is batched iff grad_out or self is batched
  const auto weight_grad_is_batched = grad_out_bdim.has_value() || self_bdim.has_value();

  const auto weight_padded = weight_flatten.view(new_shape);
  const auto weight_grad_shape = shape_maybe_with_bdim(weight_, weight_grad_is_batched, weight_bdim.has_value(), batch_size);
  const auto weight_padded_grad_shape = shape_maybe_with_bdim(weight_padded, weight_grad_is_batched, weight_bdim.has_value(), batch_size);

  const auto grads = prelu_backward_batched(grad_out_, self_, weight_padded, self_size_with_bdim, weight_padded_grad_shape, weight_grad_shape);
  return std::make_tuple(std::get<0>(grads), 0, std::get<1>(grads), (weight_grad_is_batched ? optional<int64_t>(0) : nullopt));
}
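
// Register the batch rules so that vmap'd calls to prelu and prelu_backward dispatch to them.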
TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
  VMAP_SUPPORT(prelu, prelu_batch_rule)
  VMAP_SUPPORT(prelu_backward, prelu_backward_batch_rule)
}
}} // namespace at::functorch