77from numpy .testing import assert_allclose
88import pytest
99
10- from jax import jacobian , lax , random , vmap
10+ from jax import lax , random , vmap
1111import jax .numpy as np
1212from jax .scipy .special import expit , xlog1py , xlogy
1313
1414from numpyro .distributions .util import (
1515 binary_cross_entropy_with_logits ,
1616 categorical ,
1717 cholesky_update ,
18- cumprod ,
19- cumsum ,
2018 multinomial ,
2119 poisson ,
2220 vec_to_tril_matrix ,
@@ -33,36 +31,6 @@ def test_binary_cross_entropy_with_logits(x, y):
3331 assert_allclose (actual , expect , rtol = 1e-6 )
3432
3533
36- @pytest .mark .parametrize ('shape' , [
37- (3 ,),
38- (5 , 3 ),
39- ])
40- def test_cumsum_jac (shape ):
41- rng_key = random .PRNGKey (0 )
42- x = random .normal (rng_key , shape = shape )
43-
44- def test_fn (x ):
45- return np .stack ([x [..., 0 ], x [..., 0 ] + x [..., 1 ], x [..., 0 ] + x [..., 1 ] + x [..., 2 ]], - 1 )
46-
47- assert_allclose (cumsum (x ), test_fn (x ))
48- assert_allclose (jacobian (cumsum )(x ), jacobian (test_fn )(x ))
49-
50-
51- @pytest .mark .parametrize ('shape' , [
52- (3 ,),
53- (5 , 3 ),
54- ])
55- def test_cumprod_jac (shape ):
56- rng_key = random .PRNGKey (0 )
57- x = random .uniform (rng_key , shape = shape )
58-
59- def test_fn (x ):
60- return np .stack ([x [..., 0 ], x [..., 0 ] * x [..., 1 ], x [..., 0 ] * x [..., 1 ] * x [..., 2 ]], - 1 )
61-
62- assert_allclose (cumprod (x ), test_fn (x ))
63- assert_allclose (jacobian (cumprod )(x ), jacobian (test_fn )(x ), atol = 1e-7 )
64-
65-
6634@pytest .mark .parametrize ('prim' , [
6735 xlogy ,
6836 xlog1py ,
@@ -86,19 +54,6 @@ def test_binop_batch_rule(prim):
8654 assert_allclose (actual_bx_y [i ], prim (bx [i ], y ))
8755
8856
89- @pytest .mark .parametrize ('prim' , [
90- cumsum ,
91- cumprod ,
92- ])
93- def test_unop_batch_rule (prim ):
94- rng_key = random .PRNGKey (0 )
95- bx = random .normal (rng_key , (3 , 5 ))
96-
97- actual = vmap (prim )(bx )
98- for i in range (3 ):
99- assert_allclose (actual [i ], prim (bx [i ]))
100-
101-
10257@pytest .mark .parametrize ('p, shape' , [
10358 (np .array ([0.1 , 0.9 ]), ()),
10459 (np .array ([0.2 , 0.8 ]), (2 ,)),
0 commit comments