Skip to content

Commit 5392468

Browse files
committed
Update "traditional" rng default_update in change_rv_size
A "traditional" update is `rng.default_update = rv_constructor(rng=rng).owner.outputs[0]`
1 parent b9cdd5c commit 5392468

File tree

2 files changed

+27
-0
lines changed

2 files changed

+27
-0
lines changed

pymc/aesaraf.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -189,6 +189,11 @@ def change_rv_size(
189189
for k, v in tag.__dict__.items():
190190
new_rv.tag.__dict__.setdefault(k, v)
191191

192+
# Update "traditional" rng default_update, if that was set for old RV
193+
default_update = getattr(rng, "default_update", None)
194+
if default_update is not None and default_update is rv_node.outputs[0]:
195+
rng.default_update = new_rv_node.outputs[0]
196+
192197
if config.compute_test_value != "off":
193198
compute_test_value(new_rv_node)
194199

pymc/tests/test_aesaraf.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,28 @@ def test_change_rv_size():
9494
assert tuple(rv_newer.shape.eval()) == (2,)
9595

9696

97+
def test_change_rv_size_default_update():
98+
rng = aesara.shared(np.random.default_rng(0))
99+
x = normal(rng=rng)
100+
101+
# Test that "traditional" default_update is updated
102+
rng.default_update = x.owner.outputs[0]
103+
new_x = change_rv_size(x, new_size=(2,))
104+
assert rng.default_update is not x.owner.outputs[0]
105+
assert rng.default_update is new_x.owner.outputs[0]
106+
107+
# Test that "non-traditional" default_update is left unchanged
108+
next_rng = aesara.shared(np.random.default_rng(1))
109+
rng.default_update = next_rng
110+
new_x = change_rv_size(x, new_size=(2,))
111+
assert rng.default_update is next_rng
112+
113+
# Test that default_update is not set if there was none before
114+
del rng.default_update
115+
new_x = change_rv_size(x, new_size=(2,))
116+
assert not hasattr(rng, "default_update")
117+
118+
97119
class TestBroadcasting:
98120
def test_make_shared_replacements(self):
99121
"""Check if pm.make_shared_replacements preserves broadcasting."""

0 commit comments

Comments
 (0)