
Commit c17bf9a
silu batch rule
1 parent cca6486

File tree
2 files changed: 13 additions, 0 deletions

functorch/csrc/BatchRulesBinaryOps.cpp

Lines changed: 1 addition & 0 deletions

@@ -422,6 +422,7 @@ TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
   BINARY_POINTWISE(softshrink_backward);
   BINARY_POINTWISE(tanh_backward);
   BINARY_POINTWISE(threshold_backward);
+  BINARY_POINTWISE(silu_backward);

   using TensorScalarInplaceT = Tensor& (Tensor::*)(const Tensor&, const Scalar&) const;
   using ScalarScalarInplaceT = Tensor& (Tensor::*)(const Scalar&, const Scalar&) const;
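Registering silu_backward with BINARY_POINTWISE lets vmap push the batch dimension through the op directly instead of unrolling it per example. A minimal sketch of the intended semantics, assuming functorch is installed and using the aten op name from the diff:

import torch
from functorch import vmap

# silu_backward(grad_output, self) is pointwise in both arguments.
op = torch.ops.aten.silu_backward

grad_output = torch.randn(7, 3)
x = torch.randn(7, 3)

# Batched call: the new rule handles the leading batch dim (size 7).
batched = vmap(op)(grad_output, x)

# Reference: apply the op per example and stack the results.
looped = torch.stack([op(g, xi) for g, xi in zip(grad_output, x)])

assert torch.allclose(batched, looped)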

test/test_vmap.py

Lines changed: 12 additions & 0 deletions

@@ -1420,6 +1420,18 @@ def test_copy_(self):
         with self.assertRaisesRegex(RuntimeError, 'inplace'):
             vmap(Tensor.copy_, in_dims=(None, 0))(x, y)

+    def test_silu_backward(self):
+        test = self._vmap_test
+        device = 'cpu'
+        getter = TensorFactory.randp1
+        B0 = 7
+        op = torch.ops.aten.silu_backward
+
+        # Single vmap: op(Tensor, Tensor)
+        test(op, (getter([B0, 3], device), getter([B0, 3], device)))
+        test(op, (getter([], device), getter([B0], device)), in_dims=(None, 0))
+        test(op, (getter([2, B0], device), getter([2], device)), in_dims=(1, None))
+
     @parametrize('case', [
         subtest(_make_case(torch.add), name='add'),
         subtest(_make_case(lambda x, y: x + y), name='add_dunder'),
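For reference, silu_backward computes the gradient of silu(x) = x * sigmoid(x), i.e. grad_output * sigmoid(x) * (1 + x * (1 - sigmoid(x))). A quick sanity check against that closed form (a sketch, not part of the test suite):

import torch

def silu_backward_ref(grad_output, x):
    # d/dx [x * sigmoid(x)] = sigmoid(x) * (1 + x * (1 - sigmoid(x)))
    s = torch.sigmoid(x)
    return grad_output * s * (1 + x * (1 - s))

g = torch.randn(4, 5)
x = torch.randn(4, 5)
assert torch.allclose(torch.ops.aten.silu_backward(g, x), silu_backward_ref(g, x))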
