@@ -107,6 +107,7 @@ class DirichletMultinomial(Distribution):
107107 :param numpy.ndarray concentration: concentration parameter (alpha) for the
108108 Dirichlet distribution.
109109 :param numpy.ndarray total_count: number of Categorical trials.
110+ :param int total_count_max: the maximum number of trials, i.e. `max(total_count)`
110111 """
111112
112113 arg_constraints = {
@@ -116,7 +117,9 @@ class DirichletMultinomial(Distribution):
116117 pytree_data_fields = ("concentration" , "_dirichlet" )
117118 pytree_aux_fields = ("total_count" ,)
118119
119- def __init__ (self , concentration , total_count = 1 , * , validate_args = None ):
120+ def __init__ (
121+ self , concentration , total_count = 1 , * , total_count_max = None , validate_args = None
122+ ):
120123 if jnp .ndim (concentration ) < 1 :
121124 raise ValueError (
122125 "`concentration` parameter must be at least one-dimensional."
@@ -128,6 +131,7 @@ def __init__(self, concentration, total_count=1, *, validate_args=None):
128131 concentration_shape = batch_shape + jnp .shape (concentration )[- 1 :]
129132 (self .concentration ,) = promote_shapes (concentration , shape = concentration_shape )
130133 (self .total_count ,) = promote_shapes (total_count , shape = batch_shape )
134+ self .total_count_max = total_count_max
131135 concentration = jnp .broadcast_to (self .concentration , concentration_shape )
132136 self ._dirichlet = Dirichlet (concentration )
133137 super ().__init__ (
@@ -140,9 +144,11 @@ def sample(self, key, sample_shape=()):
140144 assert is_prng_key (key )
141145 key_dirichlet , key_multinom = random .split (key )
142146 probs = self ._dirichlet .sample (key_dirichlet , sample_shape )
143- return MultinomialProbs (total_count = self .total_count , probs = probs ).sample (
144- key_multinom
145- )
147+ return MultinomialProbs (
148+ total_count = self .total_count ,
149+ probs = probs ,
150+ total_count_max = self .total_count_max ,
151+ ).sample (key_multinom )
146152
147153 @validate_sample
148154 def log_prob (self , value ):
0 commit comments