Skip to content

Commit 9404041

Browse files
authored
More distributions (#1039)
* Add Weibull and Betaproportion * Add NegativeBinomials * Add ZeroInflatedNegativeBinomial * Fix some unit tests and bugs * Reviews * Simplify Weibull * Fix Weibull constraint * Improve computation of log_prob for NBLogits * Add dispatch method for NegativeBinomial; fix unit tests * Add distributions to doc
1 parent 155524a commit 9404041

File tree

5 files changed

+253
-5
lines changed

5 files changed

+253
-5
lines changed

docs/source/distributions.rst

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,15 @@ Beta
7777
:show-inheritance:
7878
:member-order: bysource
7979

80+
BetaProportion
81+
--------------
82+
.. autoclass:: numpyro.distributions.continuous.BetaProportion
83+
:members:
84+
:undoc-members:
85+
:show-inheritance:
86+
:member-order: bysource
87+
88+
8089
Cauchy
8190
------
8291
.. autoclass:: numpyro.distributions.continuous.Cauchy
@@ -253,6 +262,14 @@ Uniform
253262
:show-inheritance:
254263
:member-order: bysource
255264

265+
Weibull
266+
-------
267+
.. autoclass:: numpyro.distributions.continuous.Weibull
268+
:members:
269+
:undoc-members:
270+
:show-inheritance:
271+
:member-order: bysource
272+
256273

257274
Discrete Distributions
258275
======================
@@ -389,6 +406,34 @@ OrderedLogistic
389406
:show-inheritance:
390407
:member-order: bysource
391408

409+
NegativeBinomial
410+
----------------
411+
.. autofunction:: numpyro.distributions.conjugate.NegativeBinomial
412+
413+
NegativeBinomialLogits
414+
----------------------
415+
.. autoclass:: numpyro.distributions.conjugate.NegativeBinomialLogits
416+
:members:
417+
:undoc-members:
418+
:show-inheritance:
419+
:member-order: bysource
420+
421+
NegativeBinomialProbs
422+
---------------------
423+
.. autoclass:: numpyro.distributions.conjugate.NegativeBinomialProbs
424+
:members:
425+
:undoc-members:
426+
:show-inheritance:
427+
:member-order: bysource
428+
429+
NegativeBinomial2
430+
-----------------
431+
.. autoclass:: numpyro.distributions.conjugate.NegativeBinomial2
432+
:members:
433+
:undoc-members:
434+
:show-inheritance:
435+
:member-order: bysource
436+
392437
Poisson
393438
-------
394439
.. autoclass:: numpyro.distributions.discrete.Poisson
@@ -417,6 +462,10 @@ ZeroInflatedPoisson
417462
:show-inheritance:
418463
:member-order: bysource
419464

465+
ZeroInflatedNegativeBinomial2
466+
-----------------------------
467+
.. autofunction:: numpyro.distributions.conjugate.ZeroInflatedNegativeBinomial2
468+
420469

421470
Directional Distributions
422471
=========================

numpyro/distributions/__init__.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,15 @@
55
BetaBinomial,
66
DirichletMultinomial,
77
GammaPoisson,
8+
NegativeBinomial2,
9+
NegativeBinomialLogits,
10+
NegativeBinomialProbs,
11+
ZeroInflatedNegativeBinomial2,
812
)
913
from numpyro.distributions.continuous import (
1014
LKJ,
1115
Beta,
16+
BetaProportion,
1217
Cauchy,
1318
Chi2,
1419
Dirichlet,
@@ -30,6 +35,7 @@
3035
SoftLaplace,
3136
StudentT,
3237
Uniform,
38+
Weibull,
3339
)
3440
from numpyro.distributions.directional import ProjectedNormal, VonMises
3541
from numpyro.distributions.discrete import (
@@ -88,6 +94,7 @@
8894
"BernoulliProbs",
8995
"Beta",
9096
"BetaBinomial",
97+
"BetaProportion",
9198
"Binomial",
9299
"BinomialLogits",
93100
"BinomialProbs",
@@ -127,6 +134,9 @@
127134
"MultivariateNormal",
128135
"LowRankMultivariateNormal",
129136
"Normal",
137+
"NegativeBinomialProbs",
138+
"NegativeBinomialLogits",
139+
"NegativeBinomial2",
130140
"OrderedLogistic",
131141
"Pareto",
132142
"Poisson",
@@ -144,6 +154,8 @@
144154
"Uniform",
145155
"Unit",
146156
"VonMises",
157+
"Weibull",
147158
"ZeroInflatedDistribution",
148159
"ZeroInflatedPoisson",
160+
"ZeroInflatedNegativeBinomial2",
149161
]

numpyro/distributions/conjugate.py

Lines changed: 84 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,18 @@
11
# Copyright Contributors to the Pyro project.
22
# SPDX-License-Identifier: Apache-2.0
33

4-
from jax import lax, random
4+
from jax import lax, nn, random
55
import jax.numpy as jnp
6-
from jax.scipy.special import betaln, gammaln
6+
from jax.scipy.special import betainc, betaln, gammaln
77

88
from numpyro.distributions import constraints
99
from numpyro.distributions.continuous import Beta, Dirichlet, Gamma
10-
from numpyro.distributions.discrete import BinomialProbs, MultinomialProbs, Poisson
10+
from numpyro.distributions.discrete import (
11+
BinomialProbs,
12+
MultinomialProbs,
13+
Poisson,
14+
ZeroInflatedDistribution,
15+
)
1116
from numpyro.distributions.distribution import Distribution
1217
from numpyro.distributions.util import is_prng_key, promote_shapes, validate_sample
1318

@@ -209,3 +214,79 @@ def mean(self):
209214
@property
210215
def variance(self):
211216
return self.concentration / jnp.square(self.rate) * (1 + self.rate)
217+
218+
def cdf(self, value):
219+
bt = betainc(self.concentration, value + 1.0, self.rate / (self.rate + 1.0))
220+
return bt
221+
222+
223+
def NegativeBinomial(total_count, probs=None, logits=None, validate_args=None):
224+
if probs is not None:
225+
return NegativeBinomialProbs(total_count, probs, validate_args=validate_args)
226+
elif logits is not None:
227+
return NegativeBinomialLogits(total_count, logits, validate_args=validate_args)
228+
else:
229+
raise ValueError("One of `probs` or `logits` must be specified.")
230+
231+
232+
class NegativeBinomialProbs(GammaPoisson):
233+
arg_constraints = {
234+
"total_count": constraints.positive,
235+
"probs": constraints.unit_interval,
236+
}
237+
support = constraints.nonnegative_integer
238+
239+
def __init__(self, total_count, probs, validate_args=None):
240+
self.total_count, self.probs = promote_shapes(total_count, probs)
241+
concentration = total_count
242+
rate = 1.0 / probs - 1.0
243+
super().__init__(concentration, rate, validate_args=validate_args)
244+
245+
246+
class NegativeBinomialLogits(GammaPoisson):
247+
arg_constraints = {
248+
"total_count": constraints.positive,
249+
"logits": constraints.real,
250+
}
251+
support = constraints.nonnegative_integer
252+
253+
def __init__(self, total_count, logits, validate_args=None):
254+
self.total_count, self.logits = promote_shapes(total_count, logits)
255+
concentration = total_count
256+
rate = jnp.exp(-logits)
257+
super().__init__(concentration, rate, validate_args=validate_args)
258+
259+
@validate_sample
260+
def log_prob(self, value):
261+
return -(
262+
self.total_count * nn.softplus(self.logits)
263+
+ value * nn.softplus(-self.logits)
264+
+ _log_beta_1(self.total_count, value)
265+
)
266+
267+
268+
class NegativeBinomial2(GammaPoisson):
269+
"""
270+
Another parameterization of GammaPoisson with `rate` is replaced by `mean`.
271+
"""
272+
273+
arg_constraints = {
274+
"mean": constraints.positive,
275+
"concentration": constraints.positive,
276+
}
277+
support = constraints.nonnegative_integer
278+
279+
def __init__(self, mean, concentration, validate_args=None):
280+
rate = concentration / mean
281+
super().__init__(concentration, rate, validate_args=validate_args)
282+
283+
284+
def ZeroInflatedNegativeBinomial2(
285+
mean, concentration, *, gate=None, gate_logits=None, validate_args=None
286+
):
287+
return ZeroInflatedDistribution(
288+
NegativeBinomial2(mean, concentration, validate_args=validate_args),
289+
gate=gate,
290+
gate_logits=gate_logits,
291+
validate_args=validate_args,
292+
)

numpyro/distributions/continuous.py

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1426,3 +1426,77 @@ def infer_shapes(low=(), high=()):
14261426
batch_shape = lax.broadcast_shapes(low, high)
14271427
event_shape = ()
14281428
return batch_shape, event_shape
1429+
1430+
1431+
class Weibull(Distribution):
1432+
arg_constraints = {
1433+
"scale": constraints.positive,
1434+
"concentration": constraints.positive,
1435+
}
1436+
support = constraints.positive
1437+
reparametrized_params = ["scale", "concentration"]
1438+
1439+
def __init__(self, scale, concentration, validate_args=None):
1440+
self.concentration, self.scale = promote_shapes(concentration, scale)
1441+
batch_shape = lax.broadcast_shapes(jnp.shape(concentration), jnp.shape(scale))
1442+
super().__init__(batch_shape=batch_shape, validate_args=validate_args)
1443+
1444+
def sample(self, key, sample_shape=()):
1445+
assert is_prng_key(key)
1446+
return random.weibull_min(
1447+
key,
1448+
scale=self.scale,
1449+
concentration=self.concentration,
1450+
shape=sample_shape + self.batch_shape,
1451+
)
1452+
1453+
@validate_sample
1454+
def log_prob(self, value):
1455+
ll = -jnp.power(value / self.scale, self.concentration)
1456+
ll += jnp.log(self.concentration)
1457+
ll += (self.concentration - 1.0) * jnp.log(value)
1458+
ll -= self.concentration * jnp.log(self.scale)
1459+
return ll
1460+
1461+
def cdf(self, value):
1462+
return 1 - jnp.exp(-((value / self.scale) ** self.concentration))
1463+
1464+
@property
1465+
def mean(self):
1466+
return self.scale * jnp.exp(gammaln(1.0 + 1.0 / self.concentration))
1467+
1468+
@property
1469+
def variance(self):
1470+
return self.scale ** 2 * (
1471+
jnp.exp(gammaln(1.0 + 2.0 / self.concentration))
1472+
- jnp.exp(gammaln(1.0 + 1.0 / self.concentration)) ** 2
1473+
)
1474+
1475+
1476+
class BetaProportion(Beta):
1477+
"""
1478+
The BetaProportion distribution is a reparameterization of the conventional
1479+
Beta distribution in terms of a the variate mean and a
1480+
precision parameter.
1481+
1482+
**Reference:**
1483+
`Beta regression for modelling rates and proportion`, Ferrari Silvia, and
1484+
Francisco Cribari-Neto. Journal of Applied Statistics 31.7 (2004): 799-815.
1485+
"""
1486+
1487+
arg_constraints = {
1488+
"mean": constraints.unit_interval,
1489+
"concentration": constraints.positive,
1490+
}
1491+
reparametrized_params = ["mean", "concentration"]
1492+
support = constraints.unit_interval
1493+
1494+
def __init__(self, mean, concentration, validate_args=None):
1495+
self.concentration = jnp.broadcast_to(
1496+
concentration, lax.broadcast_shapes(jnp.shape(concentration))
1497+
)
1498+
super().__init__(
1499+
mean * concentration,
1500+
(1.0 - mean) * concentration,
1501+
validate_args=validate_args,
1502+
)

test/test_distributions.py

Lines changed: 34 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,7 @@ def __init__(self, rate, *, validate_args=None):
113113
dist.BernoulliProbs: lambda probs: osp.bernoulli(p=probs),
114114
dist.BernoulliLogits: lambda logits: osp.bernoulli(p=_to_probs_bernoulli(logits)),
115115
dist.Beta: lambda con1, con0: osp.beta(con1, con0),
116+
dist.BetaProportion: lambda mu, kappa: osp.beta(mu * kappa, (1 - mu) * kappa),
116117
dist.BinomialProbs: lambda probs, total_count: osp.binom(n=total_count, p=probs),
117118
dist.BinomialLogits: lambda logits, total_count: osp.binom(
118119
n=total_count, p=_to_probs_bernoulli(logits)
@@ -149,6 +150,10 @@ def __init__(self, rate, *, validate_args=None):
149150
dist.VonMises: lambda loc, conc: osp.vonmises(
150151
loc=np.array(loc, dtype=np.float64), kappa=np.array(conc, dtype=np.float64)
151152
),
153+
dist.Weibull: lambda scale, conc: osp.weibull_min(
154+
c=conc,
155+
scale=scale,
156+
),
152157
_TruncatedNormal: _truncnorm_to_scipy,
153158
}
154159

@@ -164,6 +169,9 @@ def get_sp_dist(jax_dist):
164169
T(dist.Beta, 0.2, 1.1),
165170
T(dist.Beta, 1.0, jnp.array([2.0, 2.0])),
166171
T(dist.Beta, 1.0, jnp.array([[1.0, 1.0], [2.0, 2.0]])),
172+
T(dist.BetaProportion, 0.2, 10.0),
173+
T(dist.BetaProportion, 0.51, jnp.array([2.0, 1.0])),
174+
T(dist.BetaProportion, 0.5, jnp.array([[4.0, 4.0], [2.0, 2.0]])),
167175
T(dist.Chi2, 2.0),
168176
T(dist.Chi2, jnp.array([0.3, 1.3])),
169177
T(dist.Cauchy, 0.0, 1.0),
@@ -301,6 +309,9 @@ def get_sp_dist(jax_dist):
301309
T(dist.Uniform, 0.0, 2.0),
302310
T(dist.Uniform, 1.0, jnp.array([2.0, 3.0])),
303311
T(dist.Uniform, jnp.array([0.0, 0.0]), jnp.array([[2.0], [3.0]])),
312+
T(dist.Weibull, 0.2, 1.1),
313+
T(dist.Weibull, 2.8, jnp.array([2.0, 2.0])),
314+
T(dist.Weibull, 1.8, jnp.array([[1.0, 1.0], [2.0, 2.0]])),
304315
]
305316

306317
DIRECTIONAL = [
@@ -346,6 +357,25 @@ def get_sp_dist(jax_dist):
346357
T(dist.MultinomialProbs, jnp.array([0.2, 0.7, 0.1]), 10),
347358
T(dist.MultinomialProbs, jnp.array([0.2, 0.7, 0.1]), jnp.array([5, 8])),
348359
T(dist.MultinomialLogits, jnp.array([-1.0, 3.0]), jnp.array([[5], [8]])),
360+
T(dist.NegativeBinomialProbs, 10, 0.2),
361+
T(dist.NegativeBinomialProbs, 10, jnp.array([0.2, 0.6])),
362+
T(dist.NegativeBinomialProbs, jnp.array([4.2, 10.7, 2.1]), 0.2),
363+
T(
364+
dist.NegativeBinomialProbs,
365+
jnp.array([4.2, 10.7, 2.1]),
366+
jnp.array([0.2, 0.6, 0.5]),
367+
),
368+
T(dist.NegativeBinomialLogits, 10, -2.1),
369+
T(dist.NegativeBinomialLogits, 10, jnp.array([-5.2, 2.1])),
370+
T(dist.NegativeBinomialLogits, jnp.array([4.2, 10.7, 2.1]), -5.2),
371+
T(
372+
dist.NegativeBinomialLogits,
373+
jnp.array([4.2, 7.7, 2.1]),
374+
jnp.array([4.2, 0.7, 2.1]),
375+
),
376+
T(dist.NegativeBinomial2, 0.3, 10),
377+
T(dist.NegativeBinomial2, jnp.array([10.2, 7, 31]), 10),
378+
T(dist.NegativeBinomial2, jnp.array([10.2, 7, 31]), jnp.array([10.2, 20.7, 2.1])),
349379
T(dist.OrderedLogistic, -2, jnp.array([-10.0, 4.0, 9.0])),
350380
T(dist.OrderedLogistic, jnp.array([-4, 3, 4, 5]), jnp.array([-1.5])),
351381
T(dist.Poisson, 2.0),
@@ -631,7 +661,7 @@ def fn(args):
631661
# finite diff approximation
632662
expected_grad = (fn_rhs - fn_lhs) / (2.0 * eps)
633663
assert jnp.shape(actual_grad[i]) == jnp.shape(repara_params[i])
634-
assert_allclose(jnp.sum(actual_grad[i]), expected_grad, rtol=0.02)
664+
assert_allclose(jnp.sum(actual_grad[i]), expected_grad, rtol=0.02, atol=0.03)
635665

636666

637667
@pytest.mark.parametrize(
@@ -699,7 +729,7 @@ def log_likelihood(*params):
699729

700730
expected = log_likelihood(*params)
701731
actual = jax.jit(log_likelihood)(*params)
702-
assert_allclose(actual, expected, atol=1e-5)
732+
assert_allclose(actual, expected, atol=2e-5)
703733

704734

705735
@pytest.mark.parametrize(
@@ -823,6 +853,8 @@ def test_gof(jax_dist, sp_dist, params):
823853
pytest.xfail("incorrect submanifold scaling")
824854

825855
num_samples = 10000
856+
if "BetaProportion" in jax_dist.__name__:
857+
num_samples = 20000
826858
rng_key = random.PRNGKey(0)
827859
d = jax_dist(*params)
828860
samples = d.sample(key=rng_key, sample_shape=(num_samples,))

0 commit comments

Comments
 (0)