@@ -113,6 +113,7 @@ def __init__(self, rate, *, validate_args=None):
113113 dist .BernoulliProbs : lambda probs : osp .bernoulli (p = probs ),
114114 dist .BernoulliLogits : lambda logits : osp .bernoulli (p = _to_probs_bernoulli (logits )),
115115 dist .Beta : lambda con1 , con0 : osp .beta (con1 , con0 ),
116+ dist .BetaProportion : lambda mu , kappa : osp .beta (mu * kappa , (1 - mu ) * kappa ),
116117 dist .BinomialProbs : lambda probs , total_count : osp .binom (n = total_count , p = probs ),
117118 dist .BinomialLogits : lambda logits , total_count : osp .binom (
118119 n = total_count , p = _to_probs_bernoulli (logits )
@@ -149,6 +150,10 @@ def __init__(self, rate, *, validate_args=None):
149150 dist .VonMises : lambda loc , conc : osp .vonmises (
150151 loc = np .array (loc , dtype = np .float64 ), kappa = np .array (conc , dtype = np .float64 )
151152 ),
153+ dist .Weibull : lambda scale , conc : osp .weibull_min (
154+ c = conc ,
155+ scale = scale ,
156+ ),
152157 _TruncatedNormal : _truncnorm_to_scipy ,
153158}
154159
@@ -164,6 +169,9 @@ def get_sp_dist(jax_dist):
164169 T (dist .Beta , 0.2 , 1.1 ),
165170 T (dist .Beta , 1.0 , jnp .array ([2.0 , 2.0 ])),
166171 T (dist .Beta , 1.0 , jnp .array ([[1.0 , 1.0 ], [2.0 , 2.0 ]])),
172+ T (dist .BetaProportion , 0.2 , 10.0 ),
173+ T (dist .BetaProportion , 0.51 , jnp .array ([2.0 , 1.0 ])),
174+ T (dist .BetaProportion , 0.5 , jnp .array ([[4.0 , 4.0 ], [2.0 , 2.0 ]])),
167175 T (dist .Chi2 , 2.0 ),
168176 T (dist .Chi2 , jnp .array ([0.3 , 1.3 ])),
169177 T (dist .Cauchy , 0.0 , 1.0 ),
@@ -301,6 +309,9 @@ def get_sp_dist(jax_dist):
301309 T (dist .Uniform , 0.0 , 2.0 ),
302310 T (dist .Uniform , 1.0 , jnp .array ([2.0 , 3.0 ])),
303311 T (dist .Uniform , jnp .array ([0.0 , 0.0 ]), jnp .array ([[2.0 ], [3.0 ]])),
312+ T (dist .Weibull , 0.2 , 1.1 ),
313+ T (dist .Weibull , 2.8 , jnp .array ([2.0 , 2.0 ])),
314+ T (dist .Weibull , 1.8 , jnp .array ([[1.0 , 1.0 ], [2.0 , 2.0 ]])),
304315]
305316
306317DIRECTIONAL = [
@@ -346,6 +357,25 @@ def get_sp_dist(jax_dist):
346357 T (dist .MultinomialProbs , jnp .array ([0.2 , 0.7 , 0.1 ]), 10 ),
347358 T (dist .MultinomialProbs , jnp .array ([0.2 , 0.7 , 0.1 ]), jnp .array ([5 , 8 ])),
348359 T (dist .MultinomialLogits , jnp .array ([- 1.0 , 3.0 ]), jnp .array ([[5 ], [8 ]])),
360+ T (dist .NegativeBinomialProbs , 10 , 0.2 ),
361+ T (dist .NegativeBinomialProbs , 10 , jnp .array ([0.2 , 0.6 ])),
362+ T (dist .NegativeBinomialProbs , jnp .array ([4.2 , 10.7 , 2.1 ]), 0.2 ),
363+ T (
364+ dist .NegativeBinomialProbs ,
365+ jnp .array ([4.2 , 10.7 , 2.1 ]),
366+ jnp .array ([0.2 , 0.6 , 0.5 ]),
367+ ),
368+ T (dist .NegativeBinomialLogits , 10 , - 2.1 ),
369+ T (dist .NegativeBinomialLogits , 10 , jnp .array ([- 5.2 , 2.1 ])),
370+ T (dist .NegativeBinomialLogits , jnp .array ([4.2 , 10.7 , 2.1 ]), - 5.2 ),
371+ T (
372+ dist .NegativeBinomialLogits ,
373+ jnp .array ([4.2 , 7.7 , 2.1 ]),
374+ jnp .array ([4.2 , 0.7 , 2.1 ]),
375+ ),
376+ T (dist .NegativeBinomial2 , 0.3 , 10 ),
377+ T (dist .NegativeBinomial2 , jnp .array ([10.2 , 7 , 31 ]), 10 ),
378+ T (dist .NegativeBinomial2 , jnp .array ([10.2 , 7 , 31 ]), jnp .array ([10.2 , 20.7 , 2.1 ])),
349379 T (dist .OrderedLogistic , - 2 , jnp .array ([- 10.0 , 4.0 , 9.0 ])),
350380 T (dist .OrderedLogistic , jnp .array ([- 4 , 3 , 4 , 5 ]), jnp .array ([- 1.5 ])),
351381 T (dist .Poisson , 2.0 ),
@@ -631,7 +661,7 @@ def fn(args):
631661 # finite diff approximation
632662 expected_grad = (fn_rhs - fn_lhs ) / (2.0 * eps )
633663 assert jnp .shape (actual_grad [i ]) == jnp .shape (repara_params [i ])
634- assert_allclose (jnp .sum (actual_grad [i ]), expected_grad , rtol = 0.02 )
664+ assert_allclose (jnp .sum (actual_grad [i ]), expected_grad , rtol = 0.02 , atol = 0.03 )
635665
636666
637667@pytest .mark .parametrize (
@@ -699,7 +729,7 @@ def log_likelihood(*params):
699729
700730 expected = log_likelihood (* params )
701731 actual = jax .jit (log_likelihood )(* params )
702- assert_allclose (actual , expected , atol = 1e -5 )
732+ assert_allclose (actual , expected , atol = 2e -5 )
703733
704734
705735@pytest .mark .parametrize (
@@ -823,6 +853,8 @@ def test_gof(jax_dist, sp_dist, params):
823853 pytest .xfail ("incorrect submanifold scaling" )
824854
825855 num_samples = 10000
856+ if "BetaProportion" in jax_dist .__name__ :
857+ num_samples = 20000
826858 rng_key = random .PRNGKey (0 )
827859 d = jax_dist (* params )
828860 samples = d .sample (key = rng_key , sample_shape = (num_samples ,))
0 commit comments