Skip to content

Commit 0c129d2

Browse files
authored
Add Gumbel distribution. (#550)
* Add Gumbel distribution. * Fix algebra * Add missing - sign * Add autogenerated documentation for Gumbel * Add Kolmogorov-Smirnov test of samples from distributions for those we do not directly get from upstream
1 parent f2aefcf commit 0c129d2

File tree

4 files changed

+71
-0
lines changed

4 files changed

+71
-0
lines changed

docs/source/distributions.rst

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,14 @@ Gamma
8585
:show-inheritance:
8686
:member-order: bysource
8787

88+
Gumbel
89+
-----
90+
.. autoclass:: numpyro.distributions.continuous.Gumbel
91+
:members:
92+
:undoc-members:
93+
:show-inheritance:
94+
:member-order: bysource
95+
8896
GaussianRandomWalk
8997
------------------
9098
.. autoclass:: numpyro.distributions.continuous.GaussianRandomWalk

numpyro/distributions/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
Exponential,
1313
Gamma,
1414
GaussianRandomWalk,
15+
Gumbel,
1516
HalfCauchy,
1617
HalfNormal,
1718
InverseGamma,
@@ -73,6 +74,7 @@
7374
'Gamma',
7475
'GammaPoisson',
7576
'GaussianRandomWalk',
77+
'Gumbel',
7678
'HalfCauchy',
7779
'HalfNormal',
7880
'Independent',

numpyro/distributions/continuous.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,9 @@
4848
from numpyro.util import copy_docs_from
4949

5050

51+
EULER_MASCHERONI = 0.5772156649015328606065120900824024310421
52+
53+
5154
@copy_docs_from(Distribution)
5255
class Beta(Distribution):
5356
arg_constraints = {'concentration1': constraints.positive, 'concentration0': constraints.positive}
@@ -326,6 +329,39 @@ def variance(self):
326329
return np.where(self.concentration <= 2, np.inf, a)
327330

328331

332+
@copy_docs_from(Distribution)
333+
class Gumbel(Distribution):
334+
arg_constraints = {'loc': constraints.real, 'scale': constraints.positive}
335+
support = constraints.real
336+
reparametrized_params = ['loc', 'scale']
337+
338+
def __init__(self, loc=0., scale=1., validate_args=None):
339+
self.loc, self.scale = promote_shapes(loc, scale)
340+
batch_shape = lax.broadcast_shapes(np.shape(loc), np.shape(scale))
341+
342+
super(Gumbel, self).__init__(batch_shape=batch_shape,
343+
validate_args=validate_args)
344+
345+
def sample(self, key, sample_shape=()):
346+
standard_gumbel_sample = random.gumbel(key, shape=sample_shape + self.batch_shape + self.event_shape)
347+
return self.loc + self.scale * standard_gumbel_sample
348+
349+
@validate_sample
350+
def log_prob(self, value):
351+
z = (value - self.loc) / self.scale
352+
return -(z + np.exp(-z)) - np.log(self.scale)
353+
354+
@property
355+
def mean(self):
356+
return np.broadcast_to(self.loc + self.scale * EULER_MASCHERONI,
357+
self.batch_shape)
358+
359+
@property
360+
def variance(self):
361+
return np.broadcast_to(np.pi**2 / 6. * self.scale**2,
362+
self.batch_shape)
363+
364+
329365
@copy_docs_from(Distribution)
330366
class LKJ(TransformedDistribution):
331367
r"""

test/test_distributions.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@ def _lowrank_mvn_to_scipy(loc, cov_fac, cov_diag):
6868
dist.Dirichlet: lambda conc: osp.dirichlet(conc),
6969
dist.Exponential: lambda rate: osp.expon(scale=np.reciprocal(rate)),
7070
dist.Gamma: lambda conc, rate: osp.gamma(conc, scale=1./rate),
71+
dist.Gumbel: lambda loc, scale: osp.gumbel_r(loc=loc, scale=scale),
7172
dist.HalfCauchy: lambda scale: osp.halfcauchy(scale=scale),
7273
dist.HalfNormal: lambda scale: osp.halfnorm(scale=scale),
7374
dist.InverseGamma: lambda conc, rate: osp.invgamma(conc, scale=rate),
@@ -103,6 +104,9 @@ def _lowrank_mvn_to_scipy(loc, cov_fac, cov_diag):
103104
T(dist.Gamma, np.array([0.5, 1.3]), np.array([[1.], [3.]])),
104105
T(dist.GaussianRandomWalk, 0.1, 10),
105106
T(dist.GaussianRandomWalk, np.array([0.1, 0.3, 0.25]), 10),
107+
T(dist.Gumbel, 0., 1.),
108+
T(dist.Gumbel, 0.5, 2.),
109+
T(dist.Gumbel, np.array([0., 0.5]), np.array([1., 2.])),
106110
T(dist.HalfCauchy, 1.),
107111
T(dist.HalfCauchy, np.array([1., 2.])),
108112
T(dist.HalfNormal, 1.),
@@ -937,3 +941,24 @@ def test_unpack_transform():
937941
z = transform.inv(y)
938942
assert_allclose(y['key'], x)
939943
assert_allclose(z, x)
944+
945+
946+
@pytest.mark.parametrize('jax_dist, sp_dist, params', CONTINUOUS)
947+
def test_generated_sample_distribution(jax_dist, sp_dist, params,
948+
N_sample=100_000,
949+
key=random.PRNGKey(11)):
950+
""" On samplers that we do not get directly from JAX, (e.g. we only get
951+
Gumbel(0,1) but also provide samplers for Gumbel(loc, scale)), also test
952+
agreement in the empirical distribution of generated samples between our
953+
samplers and those from SciPy.
954+
"""
955+
956+
if jax_dist not in [dist.Gumbel]:
957+
pytest.skip("{} sampling method taken from upstream, no need to"
958+
"test generated samples.".format(jax_dist.__name__))
959+
960+
jax_dist = jax_dist(*params)
961+
if sp_dist and not jax_dist.event_shape and not jax_dist.batch_shape:
962+
our_samples = jax_dist.sample(key, (N_sample,))
963+
ks_result = osp.kstest(our_samples, sp_dist(*params).cdf)
964+
assert ks_result.pvalue > 0.05

0 commit comments

Comments
 (0)