@@ -11,6 +11,43 @@
 // NB: most activation functions fit pointwise unary or binary rules.
 // These are only the ones that need special batch rules, grouped here for organization.
 namespace at { namespace functorch {
+std::tuple<Tensor,optional<int64_t>>
+glu_batch_rule(const Tensor& self, optional<int64_t> self_bdim, int64_t dim) {
+  // We repeat glu's error message here because a 0-dim input shows up as
+  // 1-dim once batched; glu would fail anyway (a 0-dim tensor has "size" 1,
+  // which can't be evenly halved), but this gives the nicer error message.
+  TORCH_CHECK(self.dim() > 1, "glu does not support 0-dimensional tensors");
+
+  const auto rank = rankWithoutBatchDim(self, self_bdim);
+  const auto dim_ = maybe_wrap_dim(dim, rank) + 1;
+
+  const auto self_ = moveBatchDimToFront(self, self_bdim);
+
+  const auto res = at::glu(self_, dim_);
+  return std::make_tuple(res, 0);
+}
+
+std::tuple<Tensor,optional<int64_t>> glu_backward_batch_rule(
+    const Tensor& grad_output, optional<int64_t> grad_output_bdim,
+    const Tensor& self, optional<int64_t> self_bdim, int64_t dim) {
+  if (self_bdim) {
+    // We repeat glu's error message here because a 0-dim input shows up as
+    // 1-dim once batched; glu would fail anyway (a 0-dim tensor has "size" 1,
+    // which can't be evenly halved), but this gives the nicer error message.
+    TORCH_CHECK(self.dim() > 1, "glu does not support 0-dimensional tensors");
+  }
+
+  const auto rank = rankWithoutBatchDim(self, self_bdim);
+  const auto dim_ = maybe_wrap_dim(dim, rank) + 1;
+
+  const auto batch_size = get_bdim_size2(grad_output, grad_output_bdim, self, self_bdim);
+  const auto grad_output_ = ensure_has_bdim(moveBatchDimToFront(grad_output, grad_output_bdim), grad_output_bdim.has_value(), batch_size);
+  const auto self_ = ensure_has_bdim(moveBatchDimToFront(self, self_bdim), self_bdim.has_value(), batch_size);
+
+  const auto res = at::glu_backward(grad_output_, self_, dim_);
+  return std::make_tuple(res, 0);
+}
+
 std::tuple<Tensor,optional<int64_t>> prelu_batch_rule(
     const Tensor& input, optional<int64_t> input_bdim,
     const Tensor& weight, optional<int64_t> weight_bdim) {
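Both rules bump the logical `dim` by one because `moveBatchDimToFront` parks the vmap batch dimension at physical dim 0, so a user-facing dim `d` lands at physical dim `d + 1`. A minimal standalone sketch of that arithmetic, with made-up shapes (not part of this patch):

```cpp
#include <ATen/ATen.h>

int main() {
  // Logical per-example input: shape [4, 6]; vmap stacks a batch of 3 at dim 0.
  auto batched = at::randn({3, 4, 6});
  // glu over logical dim 1 (size 6) becomes physical dim 2 once the batch
  // dim sits in front; glu halves that dim: [3, 4, 6] -> [3, 4, 3].
  auto out = at::glu(batched, /*dim=*/2);
  TORCH_CHECK(out.sizes() == at::IntArrayRef({3, 4, 3}));
  return 0;
}
```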
|
@@ -175,6 +212,8 @@ std::tuple<Tensor,optional<int64_t>,Tensor,optional<int64_t>> prelu_backward_bat
 }
 
 TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
+  VMAP_SUPPORT(glu_backward, glu_backward_batch_rule);
+  VMAP_SUPPORT(glu, glu_batch_rule);
   VMAP_SUPPORT(prelu, prelu_batch_rule);
   VMAP_SUPPORT(prelu_backward, prelu_backward_batch_rule);
 }
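In the backward rule, only one of `grad_output` and `self` may carry a batch dimension: `get_bdim_size2` reads the batch size off whichever operand has one, and `ensure_has_bdim` materializes a matching leading dim on the other so `at::glu_backward` sees two physically batched tensors. A hedged sketch of that alignment step, again with made-up shapes (the `unsqueeze`/`expand` pair is an assumption about what the helper roughly does, not its exact implementation):

```cpp
#include <ATen/ATen.h>

int main() {
  const int64_t batch_size = 3;
  // Pretend `self` is batched ([3, 4, 6]) but `grad_output` is not ([4, 3];
  // glu halved dim 2 from 6 to 3 in the forward pass).
  auto self = at::randn({3, 4, 6});
  auto grad_output = at::randn({4, 3});
  // Roughly what ensure_has_bdim accomplishes for the unbatched operand:
  // add a leading dim of size 1, then expand it to the batch size (no copy).
  auto grad_output_ = grad_output.unsqueeze(0).expand({batch_size, 4, 3});
  auto grad_input = at::glu_backward(grad_output_, self, /*dim=*/2);
  TORCH_CHECK(grad_input.sizes() == self.sizes());
  return 0;
}
```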