@@ -64,7 +64,6 @@ def numba_core_rv_funcify(op: Op, node: Apply) -> Callable:
64
64
@numba_core_rv_funcify .register (ptr .LaplaceRV )
65
65
@numba_core_rv_funcify .register (ptr .BinomialRV )
66
66
@numba_core_rv_funcify .register (ptr .NegBinomialRV )
67
- @numba_core_rv_funcify .register (ptr .MultinomialRV )
68
67
@numba_core_rv_funcify .register (ptr .PermutationRV )
69
68
@numba_core_rv_funcify .register (ptr .IntegersRV )
70
69
def numba_core_rv_default (op , node ):
@@ -132,6 +131,15 @@ def random(rng, b, scale):
132
131
return random
133
132
134
133
134
+ @numba_core_rv_funcify .register (ptr .InvGammaRV )
135
+ def numba_core_InvGammaRV (op , node ):
136
+ @numba_basic .numba_njit
137
+ def random (rng , shape , scale ):
138
+ return 1 / rng .gamma (shape , 1 / scale )
139
+
140
+ return random
141
+
142
+
135
143
@numba_core_rv_funcify .register (ptr .CategoricalRV )
136
144
def core_CategoricalRV (op , node ):
137
145
@numba_basic .numba_njit
@@ -142,6 +150,29 @@ def random_fn(rng, p):
142
150
return random_fn
143
151
144
152
153
+ @numba_core_rv_funcify .register (ptr .MultinomialRV )
154
+ def core_MultinomialRV (op , node ):
155
+ dtype = op .dtype
156
+
157
+ @numba_basic .numba_njit
158
+ def random_fn (rng , n , p ):
159
+ n_cat = p .shape [0 ]
160
+ draws = np .zeros (n_cat , dtype = dtype )
161
+ remaining_p = np .float64 (1.0 )
162
+ remaining_n = n
163
+ for i in range (n_cat - 1 ):
164
+ draws [i ] = rng .binomial (remaining_n , p [i ] / remaining_p )
165
+ remaining_n -= draws [i ]
166
+ if remaining_n <= 0 :
167
+ break
168
+ remaining_p -= p [i ]
169
+ if remaining_n > 0 :
170
+ draws [n_cat - 1 ] = remaining_n
171
+ return draws
172
+
173
+ return random_fn
174
+
175
+
145
176
@numba_core_rv_funcify .register (ptr .MvNormalRV )
146
177
def core_MvNormalRV (op , node ):
147
178
method = op .method
0 commit comments