@@ -107,6 +107,7 @@ class DirichletMultinomial(Distribution):
107
107
:param numpy.ndarray concentration: concentration parameter (alpha) for the
108
108
Dirichlet distribution.
109
109
:param numpy.ndarray total_count: number of Categorical trials.
110
+ :param int total_count_max: the maximum number of trials, i.e. `max(total_count)`
110
111
"""
111
112
112
113
arg_constraints = {
@@ -116,7 +117,9 @@ class DirichletMultinomial(Distribution):
116
117
pytree_data_fields = ("concentration" , "_dirichlet" )
117
118
pytree_aux_fields = ("total_count" ,)
118
119
119
- def __init__ (self , concentration , total_count = 1 , * , validate_args = None ):
120
+ def __init__ (
121
+ self , concentration , total_count = 1 , * , total_count_max = None , validate_args = None
122
+ ):
120
123
if jnp .ndim (concentration ) < 1 :
121
124
raise ValueError (
122
125
"`concentration` parameter must be at least one-dimensional."
@@ -128,6 +131,7 @@ def __init__(self, concentration, total_count=1, *, validate_args=None):
128
131
concentration_shape = batch_shape + jnp .shape (concentration )[- 1 :]
129
132
(self .concentration ,) = promote_shapes (concentration , shape = concentration_shape )
130
133
(self .total_count ,) = promote_shapes (total_count , shape = batch_shape )
134
+ self .total_count_max = total_count_max
131
135
concentration = jnp .broadcast_to (self .concentration , concentration_shape )
132
136
self ._dirichlet = Dirichlet (concentration )
133
137
super ().__init__ (
@@ -140,9 +144,11 @@ def sample(self, key, sample_shape=()):
140
144
assert is_prng_key (key )
141
145
key_dirichlet , key_multinom = random .split (key )
142
146
probs = self ._dirichlet .sample (key_dirichlet , sample_shape )
143
- return MultinomialProbs (total_count = self .total_count , probs = probs ).sample (
144
- key_multinom
145
- )
147
+ return MultinomialProbs (
148
+ total_count = self .total_count ,
149
+ probs = probs ,
150
+ total_count_max = self .total_count_max ,
151
+ ).sample (key_multinom )
146
152
147
153
@validate_sample
148
154
def log_prob (self , value ):
0 commit comments