Skip to content

fix(distributions.conjugate): support total_count_max in DirichletMultinomial distribution #2016

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file added ar2.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
3 changes: 3 additions & 0 deletions graph
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
digraph {
x [label=x fillcolor=white shape=ellipse style=filled]
}
16 changes: 11 additions & 5 deletions numpyro/distributions/conjugate.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,16 +107,19 @@ class DirichletMultinomial(Distribution):
:param numpy.ndarray concentration: concentration parameter (alpha) for the
Dirichlet distribution.
:param numpy.ndarray total_count: number of Categorical trials.
:param int total_count_max: the maximum number of trials, i.e. `max(total_count)`
"""

arg_constraints = {
"concentration": constraints.independent(constraints.positive, 1),
"total_count": constraints.nonnegative_integer,
}
pytree_data_fields = ("concentration", "_dirichlet")
pytree_aux_fields = ("total_count",)
pytree_aux_fields = ("total_count", "total_count_max")

def __init__(self, concentration, total_count=1, *, validate_args=None):
def __init__(
self, concentration, total_count=1, *, total_count_max=None, validate_args=None
):
if jnp.ndim(concentration) < 1:
raise ValueError(
"`concentration` parameter must be at least one-dimensional."
Expand All @@ -128,6 +131,7 @@ def __init__(self, concentration, total_count=1, *, validate_args=None):
concentration_shape = batch_shape + jnp.shape(concentration)[-1:]
(self.concentration,) = promote_shapes(concentration, shape=concentration_shape)
(self.total_count,) = promote_shapes(total_count, shape=batch_shape)
self.total_count_max = total_count_max
concentration = jnp.broadcast_to(self.concentration, concentration_shape)
self._dirichlet = Dirichlet(concentration)
super().__init__(
Expand All @@ -140,9 +144,11 @@ def sample(self, key, sample_shape=()):
assert is_prng_key(key)
key_dirichlet, key_multinom = random.split(key)
probs = self._dirichlet.sample(key_dirichlet, sample_shape)
return MultinomialProbs(total_count=self.total_count, probs=probs).sample(
key_multinom
)
return MultinomialProbs(
total_count=self.total_count,
probs=probs,
total_count_max=self.total_count_max,
).sample(key_multinom)

@validate_sample
def log_prob(self, value):
Expand Down
Binary file added ssbvm_mixture.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
17 changes: 17 additions & 0 deletions test/test_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3416,6 +3416,23 @@ def f(x):
assert_allclose(x, y, rtol=1e-6)


def test_dirichlet_multinomial_abstract_total_count():
probs = jnp.array([0.2, 0.5, 0.3])
key = random.PRNGKey(0)

def f(x):
total_count = x.sum(-1)
return dist.DirichletMultinomial(
concentration=probs,
total_count=total_count,
total_count_max=10, # fails on 0.18.0
).sample(key)

x = dist.DirichletMultinomial(concentration=probs, total_count=10).sample(key)
y = jax.jit(f)(x)
assert_allclose(x, y, rtol=1e-6)


def test_normal_log_cdf():
# test if log_cdf method agrees with jax.scipy.stats.norm.logcdf
# and if exp(log_cdf) agrees with cdf
Expand Down
8 changes: 7 additions & 1 deletion test/test_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,13 @@
"hsgp.py --num-samples 10 --num-warmup 10 --num-chains 2",
"minipyro.py",
"mortality.py --num-samples 10 --num-warmup 10 --num-chains 2",
"neutra.py --num-samples 100 --num-warmup 100",
pytest.param(
"neutra.py --num-samples 100 --num-warmup 100",
marks=pytest.mark.skipif(
"CI" in os.environ,
reason="This example fails on the CI runner with message 'died with <Signals.SIGSEGV: 11>.'",
),
),
"ode.py --num-samples 100 --num-warmup 100 --num-chains 1",
pytest.param(
"prodlda.py --num-steps 10 --hidden 10 --nn-framework flax",
Expand Down
Binary file added wordclouds.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.