diff --git a/pytensor/link/numba/dispatch/random.py b/pytensor/link/numba/dispatch/random.py index e20d99c605..50c9bf4578 100644 --- a/pytensor/link/numba/dispatch/random.py +++ b/pytensor/link/numba/dispatch/random.py @@ -64,7 +64,6 @@ def numba_core_rv_funcify(op: Op, node: Apply) -> Callable: @numba_core_rv_funcify.register(ptr.LaplaceRV) @numba_core_rv_funcify.register(ptr.BinomialRV) @numba_core_rv_funcify.register(ptr.NegBinomialRV) -@numba_core_rv_funcify.register(ptr.MultinomialRV) @numba_core_rv_funcify.register(ptr.PermutationRV) @numba_core_rv_funcify.register(ptr.IntegersRV) def numba_core_rv_default(op, node): @@ -132,6 +131,15 @@ def random(rng, b, scale): return random +@numba_core_rv_funcify.register(ptr.InvGammaRV) +def numba_core_InvGammaRV(op, node): + @numba_basic.numba_njit + def random(rng, shape, scale): + return 1 / rng.gamma(shape, 1 / scale) + + return random + + @numba_core_rv_funcify.register(ptr.CategoricalRV) def core_CategoricalRV(op, node): @numba_basic.numba_njit @@ -142,6 +150,29 @@ def random_fn(rng, p): return random_fn +@numba_core_rv_funcify.register(ptr.MultinomialRV) +def core_MultinomialRV(op, node): + dtype = op.dtype + + @numba_basic.numba_njit + def random_fn(rng, n, p): + n_cat = p.shape[0] + draws = np.zeros(n_cat, dtype=dtype) + remaining_p = np.float64(1.0) + remaining_n = n + for i in range(n_cat - 1): + draws[i] = rng.binomial(remaining_n, p[i] / remaining_p) + remaining_n -= draws[i] + if remaining_n <= 0: + break + remaining_p -= p[i] + if remaining_n > 0: + draws[n_cat - 1] = remaining_n + return draws + + return random_fn + + @numba_core_rv_funcify.register(ptr.MvNormalRV) def core_MvNormalRV(op, node): method = op.method diff --git a/pytensor/tensor/random/basic.py b/pytensor/tensor/random/basic.py index 6d6a4ee270..10d5343511 100644 --- a/pytensor/tensor/random/basic.py +++ b/pytensor/tensor/random/basic.py @@ -1219,7 +1219,7 @@ def rng_fn_scipy(cls, rng, loc, scale, size): halfcauchy = HalfCauchyRV() -class InvGammaRV(ScipyRandomVariable): +class InvGammaRV(RandomVariable): r"""An inverse-gamma continuous random variable. The probability density function for `invgamma` in terms of its shape @@ -1266,8 +1266,8 @@ def __call__(self, shape, scale, size=None, **kwargs): return super().__call__(shape, scale, size=size, **kwargs) @classmethod - def rng_fn_scipy(cls, rng, shape, scale, size): - return stats.invgamma.rvs(shape, scale=scale, size=size, random_state=rng) + def rng_fn(cls, rng, shape, scale, size): + return 1 / rng.gamma(shape, 1 / scale, size) invgamma = InvGammaRV() diff --git a/tests/link/numba/test_random.py b/tests/link/numba/test_random.py index f52b1e2800..9443775a39 100644 --- a/tests/link/numba/test_random.py +++ b/tests/link/numba/test_random.py @@ -514,6 +514,31 @@ def test_multivariate_normal(): ], (pt.as_tensor([2, 1])), ), + ( + ptr.invgamma, + [ + ( + pt.dvector("shape"), + np.array([1.0, 2.0], dtype=np.float64), + ), + ( + pt.dvector("scale"), + np.array([0.5, 3.0], dtype=np.float64), + ), + ], + (2,), + ), + ( + ptr.multinomial, + [ + ( + pt.lvector("n"), + np.array([1, 10, 1000], dtype=np.int64), + ), + (pt.dvector("p"), np.array([0.3, 0.7], dtype=np.float64)), + ], + None, + ), ], ids=str, ) @@ -627,15 +652,11 @@ def test_unaligned_RandomVariable(rv_op, dist_args, base_size, cdf_name, params_ def test_DirichletRV(a, size, cm): a, a_val = a rng = shared(np.random.default_rng(29402)) - g = ptr.dirichlet(a, size=size, rng=rng) - g_fn = function([a], g, mode=numba_mode) + next_rng, g = ptr.dirichlet(a, size=size, rng=rng).owner.outputs + g_fn = function([a], g, mode=numba_mode, updates={rng: next_rng}) with cm: - all_samples = [] - for i in range(1000): - samples = g_fn(a_val) - all_samples.append(samples) - + all_samples = [g_fn(a_val) for _ in range(1000)] exp_res = a_val / a_val.sum(-1) res = np.mean(all_samples, axis=tuple(range(0, a_val.ndim - 1))) assert np.allclose(res, exp_res, atol=1e-4)