Skip to content

Commit 72f255d

Browse files
committed
Do not set default_updates in RandomVariables returned by PyMC distributions
1 parent 4984ef3 commit 72f255d

File tree

2 files changed

+15
-39
lines changed

2 files changed

+15
-39
lines changed

pymc/distributions/distribution.py

Lines changed: 0 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,6 @@
3030
from aesara.tensor.basic import as_tensor_variable
3131
from aesara.tensor.elemwise import Elemwise
3232
from aesara.tensor.random.op import RandomVariable
33-
from aesara.tensor.random.var import RandomStateSharedVariable
3433
from aesara.tensor.var import TensorVariable
3534
from typing_extensions import TypeAlias
3635

@@ -358,23 +357,6 @@ def dist(
358357
replicate_shape = cast(StrongShape, shape[:-1])
359358
rv_out = change_rv_size(rv_var=rv_out, new_size=replicate_shape, expand=True)
360359

361-
rng = kwargs.pop("rng", None)
362-
if (
363-
rv_out.owner
364-
and isinstance(rv_out.owner.op, RandomVariable)
365-
and isinstance(rng, RandomStateSharedVariable)
366-
and not getattr(rng, "default_update", None)
367-
):
368-
# This tells `aesara.function` that the shared RNG variable
369-
# is mutable, which--in turn--tells the `FunctionGraph`
370-
# `Supervisor` feature to allow in-place updates on the variable.
371-
# Without it, the `RandomVariable`s could not be optimized to allow
372-
# in-place RNG updates, forcing all sample results from compiled
373-
# functions to be the same on repeated evaluations.
374-
new_rng = rv_out.owner.outputs[0]
375-
rv_out.update = (rng, new_rng)
376-
rng.default_update = new_rng
377-
378360
rv_out.logp = _make_nice_attr_error("rv.logp(x)", "pm.logp(rv, x)")
379361
rv_out.logcdf = _make_nice_attr_error("rv.logcdf(x)", "pm.logcdf(rv, x)")
380362
rv_out.random = _make_nice_attr_error("rv.random()", "rv.eval()")
@@ -589,27 +571,6 @@ def dist(
589571
replicate_shape = cast(StrongShape, shape[:-1])
590572
graph = cls.change_size(rv=graph, new_size=replicate_shape, expand=True)
591573

592-
rngs = kwargs.pop("rngs", None)
593-
if rngs is not None:
594-
graph_rvs = cls.graph_rvs(graph)
595-
assert len(rngs) == len(graph_rvs)
596-
for rng, rv_out in zip(rngs, graph_rvs):
597-
if (
598-
rv_out.owner
599-
and isinstance(rv_out.owner.op, RandomVariable)
600-
and isinstance(rng, RandomStateSharedVariable)
601-
and not getattr(rng, "default_update", None)
602-
):
603-
# This tells `aesara.function` that the shared RNG variable
604-
# is mutable, which--in turn--tells the `FunctionGraph`
605-
# `Supervisor` feature to allow in-place updates on the variable.
606-
# Without it, the `RandomVariable`s could not be optimized to allow
607-
# in-place RNG updates, forcing all sample results from compiled
608-
# functions to be the same on repeated evaluations.
609-
new_rng = rv_out.owner.outputs[0]
610-
rv_out.update = (rng, new_rng)
611-
rng.default_update = new_rng
612-
613574
# TODO: Create new attr error stating that these are not available for DerivedDistribution
614575
# rv_out.logp = _make_nice_attr_error("rv.logp(x)", "pm.logp(rv, x)")
615576
# rv_out.logcdf = _make_nice_attr_error("rv.logcdf(x)", "pm.logcdf(rv, x)")

pymc/tests/test_shape_handling.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -470,3 +470,18 @@ def test_lazy_flavors(self):
470470
def test_invalid_flavors(self):
471471
with pytest.raises(ValueError, match="Passing both"):
472472
pm.Normal.dist(0, 1, shape=(3,), size=(3,))
473+
474+
def test_size_from_dims_rng_update(self):
475+
"""Test that when setting size from dims we update the rng properly
476+
See https://github.com/pymc-devs/pymc/issues/5653
477+
"""
478+
with pm.Model(coords=dict(x_dim=range(2))):
479+
x = pm.Normal("x", dims=("x_dim",))
480+
481+
fn = pm.aesaraf.compile_pymc([], x)
482+
# Check that both function outputs (rng and draws) come from the same Apply node
483+
assert fn.maker.fgraph.outputs[0].owner is fn.maker.fgraph.outputs[1].owner
484+
485+
# Confirm that the rng is properly offset, otherwise the second value of the first
486+
# draw, would match the first value of the second draw
487+
assert fn()[1] != fn()[0]

0 commit comments

Comments
 (0)