33 changes: 32 additions & 1 deletion pytensor/link/numba/dispatch/random.py
@@ -64,7 +64,6 @@
@numba_core_rv_funcify.register(ptr.LaplaceRV)
@numba_core_rv_funcify.register(ptr.BinomialRV)
@numba_core_rv_funcify.register(ptr.NegBinomialRV)
@numba_core_rv_funcify.register(ptr.MultinomialRV)
@numba_core_rv_funcify.register(ptr.PermutationRV)
@numba_core_rv_funcify.register(ptr.IntegersRV)
def numba_core_rv_default(op, node):
@@ -132,6 +131,15 @@
return random


@numba_core_rv_funcify.register(ptr.InvGammaRV)
def numba_core_InvGammaRV(op, node):
@numba_basic.numba_njit
def random(rng, shape, scale):
return 1 / rng.gamma(shape, 1 / scale)

return random
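The kernel above relies on the identity that if G ~ Gamma(alpha, scale=1/beta) then 1/G ~ InvGamma(alpha, scale=beta), so taking the reciprocal of `rng.gamma(shape, 1 / scale)` yields inverse-gamma draws without needing SciPy inside the numba-compiled function. A minimal sanity check of that identity (not part of this diff, plain NumPy/SciPy):

import numpy as np
from scipy import stats

rng = np.random.default_rng(0)
shape, scale = 3.0, 2.0

# Reciprocal-of-gamma draws, as in the numba kernel above.
samples = 1 / rng.gamma(shape, 1 / scale, size=50_000)

# Compare against scipy's invgamma CDF with a Kolmogorov-Smirnov test.
ks = stats.kstest(samples, stats.invgamma(shape, scale=scale).cdf)
print(ks.pvalue)  # expected to be well above typical rejection thresholds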


@numba_core_rv_funcify.register(ptr.CategoricalRV)
def core_CategoricalRV(op, node):
@numba_basic.numba_njit
@@ -142,6 +150,29 @@
return random_fn


@numba_core_rv_funcify.register(ptr.MultinomialRV)
def core_MultinomialRV(op, node):
dtype = op.dtype

@numba_basic.numba_njit
def random_fn(rng, n, p):
n_cat = p.shape[0]
draws = np.zeros(n_cat, dtype=dtype)
remaining_p = np.float64(1.0)
remaining_n = n

for i in range(n_cat - 1):
draws[i] = rng.binomial(remaining_n, p[i] / remaining_p)
remaining_n -= draws[i]

if remaining_n <= 0:
break
remaining_p -= p[i]

if remaining_n > 0:
draws[n_cat - 1] = remaining_n
return draws

return random_fn
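The multinomial kernel above uses the standard conditional-binomial decomposition: category i receives Binomial(remaining_n, p[i] / remaining_p) counts given what has already been assigned, and the last category absorbs whatever remains. A quick empirical check of that scheme against NumPy's own sampler (not part of this diff):

import numpy as np

def multinomial_via_binomials(rng, n, p):
    # Same conditional-binomial scheme as core_MultinomialRV above.
    draws = np.zeros(len(p), dtype=np.int64)
    remaining_n, remaining_p = n, 1.0
    for i in range(len(p) - 1):
        draws[i] = rng.binomial(remaining_n, p[i] / remaining_p)
        remaining_n -= draws[i]
        if remaining_n <= 0:
            break
        remaining_p -= p[i]
    if remaining_n > 0:
        draws[-1] = remaining_n
    return draws

rng = np.random.default_rng(0)
p = np.array([0.2, 0.5, 0.3])
ours = np.array([multinomial_via_binomials(rng, 100, p) for _ in range(20_000)])
ref = rng.multinomial(100, p, size=20_000)
print(ours.mean(0), ref.mean(0))  # both should be close to 100 * p = [20, 50, 30]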


@numba_core_rv_funcify.register(ptr.MvNormalRV)
def core_MvNormalRV(op, node):
method = op.method
6 changes: 3 additions & 3 deletions pytensor/tensor/random/basic.py
@@ -1219,7 +1219,7 @@ def rng_fn_scipy(cls, rng, loc, scale, size):
halfcauchy = HalfCauchyRV()


class InvGammaRV(ScipyRandomVariable):
class InvGammaRV(RandomVariable):
r"""An inverse-gamma continuous random variable.

The probability density function for `invgamma` in terms of its shape
@@ -1266,8 +1266,8 @@ def __call__(self, shape, scale, size=None, **kwargs):
return super().__call__(shape, scale, size=size, **kwargs)

@classmethod
def rng_fn_scipy(cls, rng, shape, scale, size):
return stats.invgamma.rvs(shape, scale=scale, size=size, random_state=rng)
def rng_fn(cls, rng, shape, scale, size):
return 1 / rng.gamma(shape, 1 / scale, size)


invgamma = InvGammaRV()
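With the switch from `ScipyRandomVariable` to `RandomVariable`, `invgamma` no longer routes through `scipy.stats.invgamma.rvs`; `rng_fn` draws directly from NumPy's gamma sampler and takes the reciprocal. A minimal usage sketch (not part of this diff; the explicit rng-update pattern mirrors the reworked Dirichlet test below):

import numpy as np
import pytensor.tensor.random as ptr
from pytensor import function, shared

rng = shared(np.random.default_rng(123))
next_rng, x = ptr.invgamma(3.0, 2.0, size=(4,), rng=rng).owner.outputs
f = function([], x, updates={rng: next_rng})
print(f())  # four inverse-gamma draws with shape=3, scale=2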
35 changes: 28 additions & 7 deletions tests/link/numba/test_random.py
@@ -514,6 +514,31 @@ def test_multivariate_normal():
],
(pt.as_tensor([2, 1])),
),
(
ptr.invgamma,
[
(
pt.dvector("shape"),
np.array([1.0, 2.0], dtype=np.float64),
),
(
pt.dvector("scale"),
np.array([0.5, 3.0], dtype=np.float64),
),
],
(2,),
),
(
ptr.multinomial,
[
(
pt.lvector("n"),
np.array([1, 10, 1000], dtype=np.int64),
),
(pt.dvector("p"), np.array([0.3, 0.7], dtype=np.float64)),
],
None,
),
],
ids=str,
)
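The new parametrizations cover vector-valued invgamma parameters with an explicit size of (2,), and a batched multinomial in which three values of `n` broadcast against a single `p` while `size=None` leaves the batch shape implicit. A sketch of the expected broadcasting for the multinomial case (not part of this diff):

import numpy as np
import pytensor.tensor.random as ptr
from pytensor import function, shared

rng = shared(np.random.default_rng(0))
n = np.array([1, 10, 1000], dtype=np.int64)
p = np.array([0.3, 0.7])
next_rng, g = ptr.multinomial(n, p, rng=rng).owner.outputs
f = function([], g, updates={rng: next_rng})
draws = f()
print(draws.shape)         # expected (3, 2): one count vector per value of n
print(draws.sum(axis=-1))  # expected [1, 10, 1000]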
@@ -627,15 +652,11 @@ def test_unaligned_RandomVariable(rv_op, dist_args, base_size, cdf_name, params_
def test_DirichletRV(a, size, cm):
a, a_val = a
rng = shared(np.random.default_rng(29402))
g = ptr.dirichlet(a, size=size, rng=rng)
g_fn = function([a], g, mode=numba_mode)
next_rng, g = ptr.dirichlet(a, size=size, rng=rng).owner.outputs
g_fn = function([a], g, mode=numba_mode, updates={rng: next_rng})

with cm:
all_samples = []
for i in range(1000):
samples = g_fn(a_val)
all_samples.append(samples)

all_samples = [g_fn(a_val) for _ in range(1000)]
exp_res = a_val / a_val.sum(-1)
res = np.mean(all_samples, axis=tuple(range(0, a_val.ndim - 1)))
assert np.allclose(res, exp_res, atol=1e-4)
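The same explicit rng-update pattern can be used to spot-check the new Numba multinomial kernel end to end. A hedged sketch, assuming the Numba backend is installed and using the predefined mode="NUMBA" rather than the test module's `numba_mode` helper (not part of the test suite):

import numpy as np
import pytensor.tensor.random as ptr
from pytensor import function, shared

rng = shared(np.random.default_rng(42))
p = np.array([0.2, 0.5, 0.3])
next_rng, g = ptr.multinomial(1000, p, rng=rng).owner.outputs
g_fn = function([], g, mode="NUMBA", updates={rng: next_rng})

samples = np.array([g_fn() for _ in range(2_000)])
print(samples.mean(axis=0), 1000 * p)  # sample means should be close to n * p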