Skip to content

Commit 65c1c68

Browse files
authored
Register Op as subclass of Distributions with rv_type defined (#6493)
This is the case of distributions that return SymbolicRandomVariables
1 parent 26048a4 commit 65c1c68

File tree

2 files changed

+15
-5
lines changed

2 files changed

+15
-5
lines changed

pymc/distributions/distribution.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -112,11 +112,14 @@ def _random(*args, **kwargs):
112112
clsdict["random"] = _random
113113

114114
rv_op = clsdict.setdefault("rv_op", None)
115-
rv_type = None
115+
rv_type = clsdict.setdefault("rv_type", None)
116116

117117
if isinstance(rv_op, RandomVariable):
118-
rv_type = type(rv_op)
119-
clsdict["rv_type"] = rv_type
118+
if rv_type is not None:
119+
assert isinstance(rv_op, rv_type)
120+
else:
121+
rv_type = type(rv_op)
122+
clsdict["rv_type"] = rv_type
120123

121124
new_cls = super().__new__(cls, name, bases, clsdict)
122125

@@ -155,8 +158,8 @@ def icdf(op, value, *dist_params, **kwargs):
155158
def moment(op, rv, rng, size, dtype, *dist_params):
156159
return class_moment(rv, size, *dist_params)
157160

158-
# Register the PyTensor `RandomVariable` type as a subclass of this
159-
# `Distribution` type.
161+
# Register the PyTensor rv_type as a subclass of this
162+
# PyMC Distribution type.
160163
new_cls.register(rv_type)
161164

162165
return new_cls

pymc/tests/distributions/test_distribution.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
import pymc as pm
2727

2828
from pymc.distributions import (
29+
Censored,
2930
DiracDelta,
3031
Flat,
3132
HalfNormal,
@@ -580,3 +581,9 @@ def test_tag_future_warning_dist():
580581
with pytest.warns(FutureWarning, match="Use model.rvs_to_values"):
581582
value_var = new_x.tag.value_var
582583
assert value_var == "1"
584+
585+
586+
def test_distribution_op_registered():
587+
"""Test that returned Ops are registered as virtual subclasses of the respective PyMC distributions."""
588+
assert isinstance(Normal.dist().owner.op, Normal)
589+
assert isinstance(Censored.dist(Normal.dist(), lower=None, upper=None).owner.op, Censored)

0 commit comments

Comments
 (0)