|
17 | 17 |
|
18 | 18 | from pymc.distributions.dist_math import betaln, check_parameters, factln, logpow
|
19 | 19 | from pymc.distributions.distribution import Discrete
|
20 |
| -from pymc.distributions.shape_utils import rv_size_is_none |
| 20 | +from pymc.distributions.shape_utils import implicit_size_from_params, rv_size_is_none |
| 21 | +from pymc.pytensorf import normalize_rng_param |
21 | 22 | from pytensor import tensor as pt
|
| 23 | +from pytensor.tensor.random.basic import beta as beta_rng |
| 24 | +from pytensor.tensor.random.basic import geometric as geometric_rng |
22 | 25 | from pytensor.tensor.random.op import RandomVariable
|
| 26 | +from pytensor.tensor.random.utils import normalize_size_param |
23 | 27 |
|
24 | 28 |
|
25 | 29 | def log1mexp(x):
|
@@ -404,24 +408,28 @@ def dist(cls, mu1, mu2, **kwargs):
|
404 | 408 |
|
405 | 409 | class ShiftedBetaGeometricRV(RandomVariable):
|
406 | 410 | name = "sbg"
|
| 411 | + extended_signature = "[rng],[size],(),()->[rng],()" |
407 | 412 | signature = "(),()->()"
|
408 |
| - |
409 |
| - dtype = "int64" |
410 | 413 | _print_name = ("ShiftedBetaGeometric", "\\operatorname{ShiftedBetaGeometric}")
|
411 | 414 |
|
412 | 415 | @classmethod
|
413 |
| - def rng_fn(cls, rng, alpha, beta, size): |
414 |
| - if size is None: |
415 |
| - size = np.broadcast_shapes(alpha.shape, beta.shape) |
| 416 | + def rv_op(cls, alpha, beta, *, size=None, rng=None): |
| 417 | + alpha = pt.as_tensor(alpha) |
| 418 | + beta = pt.as_tensor(beta) |
| 419 | + rng = normalize_rng_param(rng) |
| 420 | + size = normalize_size_param(size) |
416 | 421 |
|
417 |
| - alpha = np.broadcast_to(alpha, size) |
418 |
| - beta = np.broadcast_to(beta, size) |
| 422 | + if rv_size_is_none(size): |
| 423 | + size = implicit_size_from_params(alpha, beta, ndims_params=cls.ndims_params) |
419 | 424 |
|
420 |
| - p = rng.beta(a=alpha, b=beta, size=size) |
| 425 | + next_rng, p = beta_rng(a=alpha, b=beta, size=size, rng=rng).owner.outputs |
421 | 426 |
|
422 |
| - samples = rng.geometric(p, size=size) |
| 427 | + draws = geometric_rng(p, size=size) |
| 428 | + draws = draws.astype("int64") |
423 | 429 |
|
424 |
| - return samples |
| 430 | + return cls(inputs=[rng, size, alpha, beta], outputs=[next_rng, draws])( |
| 431 | + rng, size, alpha, beta |
| 432 | + ) |
425 | 433 |
|
426 | 434 |
|
427 | 435 | sbg = ShiftedBetaGeometricRV()
|
@@ -521,7 +529,7 @@ def logcdf(value, alpha, beta):
|
521 | 529 | + pt.gammaln(alpha + beta)
|
522 | 530 | - pt.gammaln(alpha + beta + value)
|
523 | 531 | )
|
524 |
| - # log(1-exp()) |
| 532 | + # log(1-exp(logS)) |
525 | 533 | return pt.log1mexp(logS)
|
526 | 534 |
|
527 | 535 | def support_point(rv, size, alpha, beta):
|
|
0 commit comments