Skip to content

Commit f8581ab

Browse files
committed
Allow creating CustomDist inside another CustomDist
1 parent 13e7c88 commit f8581ab

File tree

3 files changed

+29
-5
lines changed

3 files changed

+29
-5
lines changed

pymc/distributions/distribution.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@
5252
from pymc.logprob.abstract import MeasurableVariable, _icdf, _logcdf, _logprob
5353
from pymc.logprob.basic import logp
5454
from pymc.logprob.rewriting import logprob_rewrites_db
55-
from pymc.model import BlockModelAccess
55+
from pymc.model import new_or_existing_block_model_access
5656
from pymc.printing import str_for_dist
5757
from pymc.pytensorf import collect_default_updates, convert_observed_data, floatX
5858
from pymc.util import UNSET, _add_future_warning_tag
@@ -645,7 +645,7 @@ def rv_op(
645645
size = normalize_size_param(size)
646646
dummy_size_param = size.type()
647647
dummy_dist_params = [dist_param.type() for dist_param in dist_params]
648-
with BlockModelAccess(
648+
with new_or_existing_block_model_access(
649649
error_msg_on_access="Model variables cannot be created in the dist function. Use the `.dist` API"
650650
):
651651
dummy_rv = dist(*dummy_dist_params, dummy_size_param)
@@ -1048,7 +1048,7 @@ def is_symbolic_random(self, random, dist_params):
10481048
# Try calling random with symbolic inputs
10491049
try:
10501050
size = normalize_size_param(None)
1051-
with BlockModelAccess(
1051+
with new_or_existing_block_model_access(
10521052
error_msg_on_access="Model variables cannot be created in the random function. Use the `.dist` API to create such variables."
10531053
):
10541054
out = random(*dist_params, size)

pymc/model.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,7 @@ def __init__(
144144
cls._context_class = context_class
145145
super().__init__(name, bases, nmspc)
146146

147-
def get_context(cls, error_if_none=True) -> Optional[T]:
147+
def get_context(cls, error_if_none=True, allow_block_model_access=False) -> Optional[T]:
148148
"""Return the most recently pushed context object of type ``cls``
149149
on the stack, or ``None``. If ``error_if_none`` is True (default),
150150
raise a ``TypeError`` instead of returning ``None``."""
@@ -156,7 +156,7 @@ def get_context(cls, error_if_none=True) -> Optional[T]:
156156
if error_if_none:
157157
raise TypeError(f"No {cls} on context stack")
158158
return None
159-
if isinstance(candidate, BlockModelAccess):
159+
if isinstance(candidate, BlockModelAccess) and not allow_block_model_access:
160160
raise BlockModelAccessError(candidate.error_msg_on_access)
161161
return candidate
162162

@@ -1892,6 +1892,14 @@ def __init__(self, *args, error_msg_on_access="Model access is blocked", **kwarg
18921892
self.error_msg_on_access = error_msg_on_access
18931893

18941894

1895+
def new_or_existing_block_model_access(*args, **kwargs):
1896+
"""Return a BlockModelAccess in the stack or create a new one if none is found."""
1897+
model = Model.get_context(error_if_none=False, allow_block_model_access=True)
1898+
if isinstance(model, BlockModelAccess):
1899+
return model
1900+
return BlockModelAccess(*args, **kwargs)
1901+
1902+
18951903
def set_data(new_data, model=None, *, coords=None):
18961904
"""Sets the value of one or more data container variables. Note that the shape is also
18971905
dynamic, it is updated when the value is changed. See the examples below for two common

tests/distributions/test_distribution.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -599,6 +599,22 @@ def dist(size):
599599

600600
assert pm.CustomDist.dist(dist=dist)
601601

602+
def test_nested_custom_dist(self):
603+
"""Test we can create CustomDist that creates another CustomDist"""
604+
605+
def dist(size=None):
606+
def inner_dist(size=None):
607+
return pm.Normal.dist(size=size)
608+
609+
inner_dist = pm.CustomDist.dist(dist=inner_dist, size=size)
610+
return pt.exp(inner_dist)
611+
612+
rv = pm.CustomDist.dist(dist=dist)
613+
np.testing.assert_allclose(
614+
pm.logp(rv, 1.0).eval(),
615+
pm.logp(pm.LogNormal.dist(), 1.0).eval(),
616+
)
617+
602618

603619
class TestSymbolicRandomVariable:
604620
def test_inline(self):

0 commit comments

Comments
 (0)