
Commit ecb9c3c

Author: Samantha Andow

Glu batching rule (forward + backward) (#665)

* glu forward
* glu backwards
1 parent 7f6dbe9 commit ecb9c3c

File tree

4 files changed: +39 -6 lines

functorch/csrc/BatchRulesActivation.cpp

Lines changed: 39 additions & 0 deletions
@@ -11,6 +11,43 @@
 // NB: most activation functions fit pointwise unary or binary rules.
 // These are only the ones that have special batch rules to help with organization
 namespace at { namespace functorch {
+std::tuple<Tensor,optional<int64_t>>
+glu_batch_rule(const Tensor& self, optional<int64_t> self_bdim, int64_t dim) {
+  // repeated error message from glu because 0D -> 1D when batched
+  // this can't pass anyway because a 0-dimensional tensor has "size" 1, which
+  // can't be evenly halved, but give a nicer error message here.
+  TORCH_CHECK(self.dim() > 1, "glu does not support 0-dimensional tensors");
+
+  const auto rank = rankWithoutBatchDim(self, self_bdim);
+  const auto dim_ = maybe_wrap_dim(dim, rank) + 1;
+
+  const auto self_ = moveBatchDimToFront(self, self_bdim);
+
+  const auto res = at::glu(self_, dim_);
+  return std::make_tuple(res, 0);
+}
+
+std::tuple<Tensor,optional<int64_t>> glu_backward_batch_rule(
+    const Tensor& grad_output, optional<int64_t> grad_output_bdim,
+    const Tensor& self, optional<int64_t> self_bdim, int64_t dim) {
+  if (self_bdim) {
+    // repeated error message from glu because 0D -> 1D when batched
+    // this can't pass anyway because a 0-dimensional tensor has "size" 1, which
+    // can't be evenly halved, but give a nicer error message here.
+    TORCH_CHECK(self.dim() > 1, "glu does not support 0-dimensional tensors");
+  }
+
+  const auto rank = rankWithoutBatchDim(self, self_bdim);
+  const auto dim_ = maybe_wrap_dim(dim, rank) + 1;
+
+  const auto batch_size = get_bdim_size2(grad_output, grad_output_bdim, self, self_bdim);
+  const auto grad_output_ = ensure_has_bdim(moveBatchDimToFront(grad_output, grad_output_bdim), grad_output_bdim.has_value(), batch_size);
+  const auto self_ = ensure_has_bdim(moveBatchDimToFront(self, self_bdim), self_bdim.has_value(), batch_size);
+
+  const auto res = at::glu_backward(grad_output_, self_, dim_);
+  return std::make_tuple(res, 0);
+}
+
 std::tuple<Tensor,optional<int64_t>> prelu_batch_rule(
     const Tensor& input, optional<int64_t> input_bdim,
     const Tensor& weight, optional<int64_t> weight_bdim) {

@@ -175,6 +212,8 @@ std::tuple<Tensor,optional<int64_t>,Tensor,optional<int64_t>> prelu_backward_bat
 }

 TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
+  VMAP_SUPPORT(glu_backward, glu_backward_batch_rule);
+  VMAP_SUPPORT(glu, glu_batch_rule);
   VMAP_SUPPORT(prelu, prelu_batch_rule)
   VMAP_SUPPORT(prelu_backward, prelu_backward_batch_rule)
 }
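The forward rule above moves the batch dimension to the front and bumps the logical dim by one before calling at::glu. A minimal Python sketch of the behavior this enables, assuming the functorch package built from this repo (tensor shapes are made up for illustration):

import torch
import torch.nn.functional as F
from functorch import vmap

x = torch.randn(3, 4, 6)  # batch of 3 examples, each of shape (4, 6)

# Per-example reference: glu over dim 1 halves it (size 6 -> two halves of 3).
expected = torch.stack([F.glu(xi, dim=1) for xi in x])

# Batched call: the rule moves the batch dim to the front and calls at::glu
# with the dim shifted by one (per-example dim=1 becomes dim=2 on the batch).
out = vmap(lambda xi: F.glu(xi, dim=1))(x)

assert torch.allclose(out, expected)  # out.shape == (3, 4, 3)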

functorch/csrc/BatchRulesUnaryOps.cpp

Lines changed: 0 additions & 1 deletion
@@ -113,7 +113,6 @@ TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
   UNARY_POINTWISE_ALL(expm1);
   UNARY_POINTWISE_ALL(floor);
   UNARY_POINTWISE_ALL(frac);
-  UNARY_POINTWISE(glu);
   UNARY_POINTWISE(isfinite);
   UNARY_POINTWISE(isnan);
   UNARY_POINTWISE(isinf);
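glu is removed from the unary pointwise registrations because it is not pointwise: it splits the input in half along dim and gates the first half with the sigmoid of the second, so the result depends on which dimension is halved and batching must adjust that dimension. A short sketch of the decomposition (shapes chosen arbitrarily):

import torch
import torch.nn.functional as F

x = torch.randn(2, 8)

# glu halves the chosen dim; the first half is gated by the sigmoid of the second.
a, b = x.chunk(2, dim=-1)
manual = a * torch.sigmoid(b)

assert torch.allclose(F.glu(x, dim=-1), manual)  # result shape (2, 4), not (2, 8)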

test/test_ops.py

Lines changed: 0 additions & 3 deletions
@@ -610,7 +610,6 @@ def vjp_of_vjp(*args_and_cotangents):
     skip('qr'),  # Nondetermistic
     xfail('_masked.prod'),  # calls aten::item
     xfail('stft'),
-    xfail('nn.functional.glu'),
     xfail('nn.functional.fractional_max_pool3d'),
     xfail('as_strided'),
     xfail('nn.functional.fractional_max_pool2d'),

@@ -954,7 +953,6 @@ def test():
     xfail('nn.functional.huber_loss'),
     xfail('nn.functional.poisson_nll_loss'),
     xfail('nn.functional.bilinear'),
-    xfail('nn.functional.glu'),
     xfail('nn.functional.fractional_max_pool3d'),
     xfail('as_strided'),
     xfail('linalg.solve_triangular'),

@@ -1018,7 +1016,6 @@ def test():
     xfail('masked_select'),
     skip('nn.functional.fractional_max_pool3d'),  # generator works on cpu, fails on cuda
     xfail('__rpow__'),  # https://github.com/pytorch/functorch/issues/617
-    xfail('nn.functional.glu'),
     xfail('as_strided'),
     skip('nn.functional.fractional_max_pool2d'),  # generator works on cpu, fails on cuda
     skip('solve'),
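The xfails removed here cover gradient transforms composed with vmap; with the glu_backward batch rule registered, per-example gradients through glu can be computed in one batched call. A hedged sketch using functorch's grad and vmap (the loss below is just an arbitrary scalar reduction):

import torch
import torch.nn.functional as F
from functorch import grad, vmap

x = torch.randn(5, 4, 6)

# Per-example gradient of an arbitrary scalar function of glu; under vmap this
# dispatches to the glu_backward batch rule added in BatchRulesActivation.cpp.
per_example_grads = vmap(grad(lambda xi: F.glu(xi, dim=-1).sum()))(x)

assert per_example_grads.shape == x.shape  # one gradient per example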

test/test_vmap.py

Lines changed: 0 additions & 2 deletions
@@ -3115,7 +3115,6 @@ class TestVmapOperatorsOpInfo(TestCase):
     xfail('nn.functional.fractional_max_pool2d'),
     xfail('nn.functional.embedding_bag'),
     xfail('nonzero'),
-    xfail('nn.functional.glu'),
     xfail('nn.functional.rrelu'),  # random?
     xfail('__rpow__'),  # https://github.com/pytorch/functorch/issues/617
     xfail('bernoulli', ''),

@@ -3251,7 +3250,6 @@ def test_vmap_exhaustive(self, device, dtype, op):
     xfail('nn.functional.fractional_max_pool2d'),
     xfail('stft'),
     xfail('linalg.solve_triangular'),
-    xfail('nn.functional.glu'),
     xfail('isclose'),
     xfail('nn.functional.fractional_max_pool3d'),
     xfail('nn.functional.bilinear'),
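The exhaustive vmap tests compare vmap output against a per-example loop for each OpInfo sample; a stripped-down sketch of that check for glu, here with the batch dimension not at the front (shapes and in_dims are illustrative, not taken from the test suite):

import torch
import torch.nn.functional as F
from functorch import vmap

x = torch.randn(4, 3, 6)  # batch dim is dim 1 (size 3)

# vmap with in_dims=1: the batch dim is moved to the front before the glu
# batch rule runs, then placed back according to out_dims=1.
out = vmap(lambda xi: F.glu(xi, dim=-1), in_dims=1, out_dims=1)(x)

expected = torch.stack([F.glu(x[:, i], dim=-1) for i in range(x.size(1))], dim=1)
assert torch.allclose(out, expected)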

0 commit comments