@@ -540,6 +540,7 @@ def get_sp_dist(jax_dist):
540540 T (dist .HalfNormal , 1.0 ),
541541 T (dist .HalfNormal , np .array ([1.0 , 2.0 ])),
542542 T (_ImproperWrapper , constraints .positive , (), (3 ,)),
543+ T (dist .InverseGamma , np .array ([3.1 ]), np .array ([[2.0 ], [3.0 ]])),
543544 T (dist .InverseGamma , np .array ([1.7 ]), np .array ([[2.0 ], [3.0 ]])),
544545 T (dist .InverseGamma , np .array ([0.5 , 1.3 ]), np .array ([[1.0 ], [3.0 ]])),
545546 T (dist .Kumaraswamy , 10.0 , np .array ([2.0 , 3.0 ])),
@@ -1568,7 +1569,7 @@ def test_cdf_and_icdf(jax_dist, sp_dist, params):
15681569 samples = d .sample (key = random .PRNGKey (0 ), sample_shape = (100 ,))
15691570 quantiles = random .uniform (random .PRNGKey (1 ), (100 ,) + d .shape ())
15701571 try :
1571- rtol = 2e-3 if jax_dist in (dist .Gamma , dist .StudentT ) else 1e-5
1572+ rtol = 2e-3 if jax_dist in (dist .Gamma , dist .LogNormal , dist . StudentT ) else 1e-5
15721573 if d .shape () == () and not d .is_discrete :
15731574 assert_allclose (
15741575 jax .vmap (jax .grad (d .cdf ))(samples ),
@@ -1585,7 +1586,7 @@ def test_cdf_and_icdf(jax_dist, sp_dist, params):
15851586 assert_allclose (d .cdf (d .icdf (quantiles )), quantiles , atol = 1e-5 , rtol = 1e-5 )
15861587 assert_allclose (d .icdf (d .cdf (samples )), samples , atol = 1e-5 , rtol = rtol )
15871588 except NotImplementedError :
1588- pass
1589+ pytest . skip ( "cdf/icdf not implemented" )
15891590
15901591 # test against scipy
15911592 if not sp_dist :
@@ -1599,7 +1600,7 @@ def test_cdf_and_icdf(jax_dist, sp_dist, params):
15991600 expected_icdf = sp_dist .ppf (quantiles )
16001601 assert_allclose (actual_icdf , expected_icdf , atol = 1e-4 , rtol = 1e-4 )
16011602 except NotImplementedError :
1602- pass
1603+ pytest . skip ( "cdf/icdf not implemented" )
16031604
16041605
16051606@pytest .mark .parametrize ("jax_dist, sp_dist, params" , CONTINUOUS + DIRECTIONAL )
0 commit comments