Skip to content

Commit 675c2b3

Browse files
author
Brendan Cooley
committed
feat(distributions.conjugate): support total_count_max in DirichletMultinomial
1 parent ab1f0dc commit 675c2b3

File tree

2 files changed

+27
-4
lines changed

2 files changed

+27
-4
lines changed

numpyro/distributions/conjugate.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,7 @@ class DirichletMultinomial(Distribution):
107107
:param numpy.ndarray concentration: concentration parameter (alpha) for the
108108
Dirichlet distribution.
109109
:param numpy.ndarray total_count: number of Categorical trials.
110+
:param int total_count_max: the maximum number of trials, i.e. `max(total_count)`
110111
"""
111112

112113
arg_constraints = {
@@ -116,7 +117,9 @@ class DirichletMultinomial(Distribution):
116117
pytree_data_fields = ("concentration", "_dirichlet")
117118
pytree_aux_fields = ("total_count",)
118119

119-
def __init__(self, concentration, total_count=1, *, validate_args=None):
120+
def __init__(
121+
self, concentration, total_count=1, *, total_count_max=None, validate_args=None
122+
):
120123
if jnp.ndim(concentration) < 1:
121124
raise ValueError(
122125
"`concentration` parameter must be at least one-dimensional."
@@ -128,6 +131,7 @@ def __init__(self, concentration, total_count=1, *, validate_args=None):
128131
concentration_shape = batch_shape + jnp.shape(concentration)[-1:]
129132
(self.concentration,) = promote_shapes(concentration, shape=concentration_shape)
130133
(self.total_count,) = promote_shapes(total_count, shape=batch_shape)
134+
self.total_count_max = total_count_max
131135
concentration = jnp.broadcast_to(self.concentration, concentration_shape)
132136
self._dirichlet = Dirichlet(concentration)
133137
super().__init__(
@@ -140,9 +144,11 @@ def sample(self, key, sample_shape=()):
140144
assert is_prng_key(key)
141145
key_dirichlet, key_multinom = random.split(key)
142146
probs = self._dirichlet.sample(key_dirichlet, sample_shape)
143-
return MultinomialProbs(total_count=self.total_count, probs=probs).sample(
144-
key_multinom
145-
)
147+
return MultinomialProbs(
148+
total_count=self.total_count,
149+
probs=probs,
150+
total_count_max=self.total_count_max,
151+
).sample(key_multinom)
146152

147153
@validate_sample
148154
def log_prob(self, value):

test/test_distributions.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3416,6 +3416,23 @@ def f(x):
34163416
assert_allclose(x, y, rtol=1e-6)
34173417

34183418

3419+
def test_dirichlet_multinomial_abstract_total_count():
3420+
probs = jnp.array([0.2, 0.5, 0.3])
3421+
key = random.PRNGKey(0)
3422+
3423+
def f(x):
3424+
total_count = x.sum(-1)
3425+
return dist.DirichletMultinomial(
3426+
concentration=probs,
3427+
total_count=total_count,
3428+
total_count_max=10, # fails on 0.18.0
3429+
).sample(key)
3430+
3431+
x = dist.DirichletMultinomial(concentration=probs, total_count=10).sample(key)
3432+
y = jax.jit(f)(x)
3433+
assert_allclose(x, y, rtol=1e-6)
3434+
3435+
34193436
def test_normal_log_cdf():
34203437
# test if log_cdf method agrees with jax.scipy.stats.norm.logcdf
34213438
# and if exp(log_cdf) agrees with cdf

0 commit comments

Comments
 (0)