Skip to content

Commit f5ce614

Browse files
author
Samantha Andow
authored
Fix normal_ and bernoulli (#670)
* normal_ fix
* fix binomial test
1 parent c17bf9a commit f5ce614

File tree

2 files changed

+11
-8
lines changed

2 files changed

+11
-8
lines changed

functorch/csrc/BatchRulesBinaryOps.cpp

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -292,14 +292,18 @@ std::tuple<Tensor,optional<int64_t>> cdist_backward_batch_rule(
292292
return std::make_tuple(out, out_bdim);
293293
}
294294

295+
Tensor binomial_wrapper(const Tensor& count, const Tensor& prob, c10::optional<Generator> gen) {
296+
return at::binomial(count, prob.contiguous(), gen); // Bug in PyTorch, prob shouldn't need to be contiguous
297+
}
298+
295299
TORCH_LIBRARY_IMPL(aten, FuncTorchVmapMode, m) {
296300
#define BINARY_RANDOM_POINTWISE(op) \
297-
m.impl(#op, BINARY_RANDOM_POINTWISE_BATCH_RULE(ATEN_FN(op)));
298-
#define BINARY_RANDOM_POINTWISE2(op, overload) \
299-
m.impl(#op"."#overload, BINARY_RANDOM_POINTWISE_BATCH_RULE(ATEN_FN2(op, overload)));
301+
m.impl(#op, BINARY_RANDOM_POINTWISE_BATCH_RULE(ATEN_FN(op)));
302+
#define BINARY_RANDOM_POINTWISE2(op, overload) \
303+
m.impl(#op"."#overload, BINARY_RANDOM_POINTWISE_BATCH_RULE(ATEN_FN2(op, overload)));
300304

301305
BINARY_RANDOM_POINTWISE2(normal, Tensor_Tensor);
302-
BINARY_RANDOM_POINTWISE(binomial);
306+
m.impl("binomial", BINARY_RANDOM_POINTWISE_BATCH_RULE(at::functorch::binomial_wrapper));
303307
}
304308

305309
TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {

test/test_vmap.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3814,7 +3814,7 @@ def test_random_unary_inplace(self, device, use_generator, randomness, batched_i
38143814
lambda t, _: t.random_(**kwargs),
38153815
lambda t, _: t.random_(100, **kwargs),
38163816
lambda t, _: t.random_(-5, 100, **kwargs),
3817-
# lambda t, _: t.normal_(**kwargs), TODO(samdow): fix normal_ with -1 bdim
3817+
lambda t, _: t.normal_(**kwargs),
38183818
lambda t, _: t.bernoulli_(**kwargs),
38193819
lambda t, _: t.cauchy_(**kwargs),
38203820
lambda t, _: t.exponential_(**kwargs),
@@ -3851,7 +3851,7 @@ def test_random_unary_inplace(self, device, use_generator, randomness, batched_i
38513851
self.assertEqual(vmap_result, expected)
38523852
else:
38533853
if batched_input != "none":
3854-
passed_expected = passed_expected[0]
3854+
passed_expected = passed_expected[0].clone() # bug in pytorch, normal_ on views doesn't work
38553855
expected = op(passed_expected, always_batched)
38563856
self._assert_all_slices_equal(vmap_result)
38573857
for i in range(B0):
@@ -3923,8 +3923,7 @@ def test_random_binary_out_of_place(self, device, use_generator, randomness, bat
39233923
kwargs = {'generator': generator} if use_generator else {}
39243924
ops = [
39253925
lambda t, o, _: torch.normal(t, o, **kwargs),
3926-
# TODO(samdow): fix binomial
3927-
# lambda t, o, _: torch.binomial(t, (o - 0.5), **kwargs),
3926+
lambda t, o, _: torch.binomial(t, (o - 0.5), **kwargs),
39283927
]
39293928

39303929
B0 = 4

0 commit comments

Comments (0)