Skip to content

Commit 12d2c36

Browse files
committed
SymbolicRandomVariable
1 parent 3d44733 commit 12d2c36

File tree

1 file changed

+20
-12
lines changed

1 file changed

+20
-12
lines changed

pymc_extras/distributions/discrete.py

Lines changed: 20 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,13 @@
1717

1818
from pymc.distributions.dist_math import betaln, check_parameters, factln, logpow
1919
from pymc.distributions.distribution import Discrete
20-
from pymc.distributions.shape_utils import rv_size_is_none
20+
from pymc.distributions.shape_utils import implicit_size_from_params, rv_size_is_none
21+
from pymc.pytensorf import normalize_rng_param
2122
from pytensor import tensor as pt
23+
from pytensor.tensor.random.basic import beta as beta_rng
24+
from pytensor.tensor.random.basic import geometric as geometric_rng
2225
from pytensor.tensor.random.op import RandomVariable
26+
from pytensor.tensor.random.utils import normalize_size_param
2327

2428

2529
def log1mexp(x):
@@ -404,24 +408,28 @@ def dist(cls, mu1, mu2, **kwargs):
404408

405409
class ShiftedBetaGeometricRV(RandomVariable):
406410
name = "sbg"
411+
extended_signature = "[rng],[size],(),()->[rng],()"
407412
signature = "(),()->()"
408-
409-
dtype = "int64"
410413
_print_name = ("ShiftedBetaGeometric", "\\operatorname{ShiftedBetaGeometric}")
411414

412415
@classmethod
413-
def rng_fn(cls, rng, alpha, beta, size):
414-
if size is None:
415-
size = np.broadcast_shapes(alpha.shape, beta.shape)
416+
def rv_op(cls, alpha, beta, *, size=None, rng=None):
417+
alpha = pt.as_tensor(alpha)
418+
beta = pt.as_tensor(beta)
419+
rng = normalize_rng_param(rng)
420+
size = normalize_size_param(size)
416421

417-
alpha = np.broadcast_to(alpha, size)
418-
beta = np.broadcast_to(beta, size)
422+
if rv_size_is_none(size):
423+
size = implicit_size_from_params(alpha, beta, ndims_params=cls.ndims_params)
419424

420-
p = rng.beta(a=alpha, b=beta, size=size)
425+
next_rng, p = beta_rng(a=alpha, b=beta, size=size, rng=rng).owner.outputs
421426

422-
samples = rng.geometric(p, size=size)
427+
draws = geometric_rng(p, size=size)
428+
draws = draws.astype("int64")
423429

424-
return samples
430+
return cls(inputs=[rng, size, alpha, beta], outputs=[next_rng, draws])(
431+
rng, size, alpha, beta
432+
)
425433

426434

427435
sbg = ShiftedBetaGeometricRV()
@@ -521,7 +529,7 @@ def logcdf(value, alpha, beta):
521529
+ pt.gammaln(alpha + beta)
522530
- pt.gammaln(alpha + beta + value)
523531
)
524-
# log(1-exp())
532+
# log(1-exp(logS))
525533
return pt.log1mexp(logS)
526534

527535
def support_point(rv, size, alpha, beta):

0 commit comments

Comments
 (0)