@@ -3814,7 +3814,7 @@ def test_random_unary_inplace(self, device, use_generator, randomness, batched_i
3814
3814
lambda t , _ : t .random_ (** kwargs ),
3815
3815
lambda t , _ : t .random_ (100 , ** kwargs ),
3816
3816
lambda t , _ : t .random_ (- 5 , 100 , ** kwargs ),
3817
- # lambda t, _: t.normal_(**kwargs), TODO(samdow): fix normal_ with -1 bdim
3817
+ lambda t , _ : t .normal_ (** kwargs ),
3818
3818
lambda t , _ : t .bernoulli_ (** kwargs ),
3819
3819
lambda t , _ : t .cauchy_ (** kwargs ),
3820
3820
lambda t , _ : t .exponential_ (** kwargs ),
@@ -3851,7 +3851,7 @@ def test_random_unary_inplace(self, device, use_generator, randomness, batched_i
3851
3851
self .assertEqual (vmap_result , expected )
3852
3852
else :
3853
3853
if batched_input != "none" :
3854
- passed_expected = passed_expected [0 ]
3854
+ passed_expected = passed_expected [0 ]. clone () # bug in pytorch, normal_ on views doesn't work
3855
3855
expected = op (passed_expected , always_batched )
3856
3856
self ._assert_all_slices_equal (vmap_result )
3857
3857
for i in range (B0 ):
@@ -3923,8 +3923,7 @@ def test_random_binary_out_of_place(self, device, use_generator, randomness, bat
3923
3923
kwargs = {'generator' : generator } if use_generator else {}
3924
3924
ops = [
3925
3925
lambda t , o , _ : torch .normal (t , o , ** kwargs ),
3926
- # TODO(samdow): fix binomial
3927
- # lambda t, o, _: torch.binomial(t, (o - 0.5), **kwargs),
3926
+ lambda t , o , _ : torch .binomial (t , (o - 0.5 ), ** kwargs ),
3928
3927
]
3929
3928
3930
3929
B0 = 4
0 commit comments