@@ -50,24 +50,22 @@ def test_random_bernoulli(size, p):
5050
5151
5252@pytest .mark .parametrize (
53- "size,n,p" ,
53+ "size,n,p,update " ,
5454 [
55- (None , 10 , 0.5 ),
56- ((1000 ,), 10 , 0.5 ),
57- ((1000 , 4 ), 10 , 0.5 ),
58- ((1000 , 2 ), np .array ([10 , 40 ]), np .array ([0.5 , 0.3 ])),
55+ ((1000 ,), 10 , 0.5 , False ),
56+ ((1000 , 4 ), 10 , 0.5 , False ),
57+ ((1000 , 2 ), np .array ([10 , 40 ]), np .array ([0.5 , 0.3 ]), True ),
5958 ],
6059)
61- def test_binomial (n , p , size ):
60+ def test_binomial (size , n , p , update ):
6261 rng = shared (np .random .default_rng (123 ))
63- g = pt .random .binomial (n , p , size = size , rng = rng )
64- g_fn = function ([], g , mode = pytorch_mode )
62+ rv = pt .random .binomial (n , p , size = size , rng = rng )
63+ next_rng , * _ = rv .owner .inputs
64+ g_fn = function (
65+ [], rv , mode = pytorch_mode , updates = {rng : next_rng } if update else None
66+ )
6567 samples = g_fn ()
66- if size :
67- np .testing .assert_allclose (samples .mean (axis = 0 ), n * p , rtol = 0.1 )
68- np .testing .assert_allclose (
69- samples .std (axis = 0 ), np .sqrt (n * p * (1 - p )), rtol = 0.2
70- )
71- else :
72- ...
73- # TODO: define test
68+ if not update :
69+ np .testing .assert_allclose (samples , g_fn (), rtol = 0.1 )
70+ np .testing .assert_allclose (samples .mean (axis = 0 ), n * p , rtol = 0.1 )
71+ np .testing .assert_allclose (samples .std (axis = 0 ), np .sqrt (n * p * (1 - p )), rtol = 0.2 )
0 commit comments