Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file added ar2.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
3 changes: 3 additions & 0 deletions graph
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
digraph {
x [label=x fillcolor=white shape=ellipse style=filled]
}
16 changes: 11 additions & 5 deletions numpyro/distributions/conjugate.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,16 +107,19 @@ class DirichletMultinomial(Distribution):
:param numpy.ndarray concentration: concentration parameter (alpha) for the
Dirichlet distribution.
:param numpy.ndarray total_count: number of Categorical trials.
:param int total_count_max: the maximum number of trials, i.e. `max(total_count)`
"""

arg_constraints = {
"concentration": constraints.independent(constraints.positive, 1),
"total_count": constraints.nonnegative_integer,
}
pytree_data_fields = ("concentration", "_dirichlet")
pytree_aux_fields = ("total_count",)
pytree_aux_fields = ("total_count", "total_count_max")

def __init__(self, concentration, total_count=1, *, validate_args=None):
def __init__(
self, concentration, total_count=1, *, total_count_max=None, validate_args=None
):
if jnp.ndim(concentration) < 1:
raise ValueError(
"`concentration` parameter must be at least one-dimensional."
Expand All @@ -128,6 +131,7 @@ def __init__(self, concentration, total_count=1, *, validate_args=None):
concentration_shape = batch_shape + jnp.shape(concentration)[-1:]
(self.concentration,) = promote_shapes(concentration, shape=concentration_shape)
(self.total_count,) = promote_shapes(total_count, shape=batch_shape)
self.total_count_max = total_count_max
concentration = jnp.broadcast_to(self.concentration, concentration_shape)
self._dirichlet = Dirichlet(concentration)
super().__init__(
Expand All @@ -140,9 +144,11 @@ def sample(self, key, sample_shape=()):
assert is_prng_key(key)
key_dirichlet, key_multinom = random.split(key)
probs = self._dirichlet.sample(key_dirichlet, sample_shape)
return MultinomialProbs(total_count=self.total_count, probs=probs).sample(
key_multinom
)
return MultinomialProbs(
total_count=self.total_count,
probs=probs,
total_count_max=self.total_count_max,
).sample(key_multinom)

@validate_sample
def log_prob(self, value):
Expand Down
Binary file added ssbvm_mixture.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
17 changes: 17 additions & 0 deletions test/test_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3416,6 +3416,23 @@ def f(x):
assert_allclose(x, y, rtol=1e-6)


def test_dirichlet_multinomial_abstract_total_count():
probs = jnp.array([0.2, 0.5, 0.3])
key = random.PRNGKey(0)

def f(x):
total_count = x.sum(-1)
return dist.DirichletMultinomial(
concentration=probs,
total_count=total_count,
total_count_max=10, # fails on 0.18.0
).sample(key)

x = dist.DirichletMultinomial(concentration=probs, total_count=10).sample(key)
y = jax.jit(f)(x)
assert_allclose(x, y, rtol=1e-6)


def test_normal_log_cdf():
# test if log_cdf method agrees with jax.scipy.stats.norm.logcdf
# and if exp(log_cdf) agrees with cdf
Expand Down
8 changes: 7 additions & 1 deletion test/test_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,13 @@
"hsgp.py --num-samples 10 --num-warmup 10 --num-chains 2",
"minipyro.py",
"mortality.py --num-samples 10 --num-warmup 10 --num-chains 2",
"neutra.py --num-samples 100 --num-warmup 100",
pytest.param(
"neutra.py --num-samples 100 --num-warmup 100",
marks=pytest.mark.skipif(
"CI" in os.environ,
reason="This example fails on the CI runner with message 'died with <Signals.SIGSEGV: 11>.'",
),
),
"ode.py --num-samples 100 --num-warmup 100 --num-chains 1",
pytest.param(
"prodlda.py --num-steps 10 --hidden 10 --nn-framework flax",
Expand Down
Binary file added wordclouds.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.