File tree Expand file tree Collapse file tree 2 files changed +27
-0
lines changed Expand file tree Collapse file tree 2 files changed +27
-0
lines changed Original file line number Diff line number Diff line change @@ -189,6 +189,11 @@ def change_rv_size(
189
189
for k , v in tag .__dict__ .items ():
190
190
new_rv .tag .__dict__ .setdefault (k , v )
191
191
192
+ # Update "traditional" rng default_update, if that was set for old RV
193
+ default_update = getattr (rng , "default_update" , None )
194
+ if default_update is not None and default_update is rv_node .outputs [0 ]:
195
+ rng .default_update = new_rv_node .outputs [0 ]
196
+
192
197
if config .compute_test_value != "off" :
193
198
compute_test_value (new_rv_node )
194
199
Original file line number Diff line number Diff line change @@ -94,6 +94,28 @@ def test_change_rv_size():
94
94
assert tuple (rv_newer .shape .eval ()) == (2 ,)
95
95
96
96
97
+ def test_change_rv_size_default_update ():
98
+ rng = aesara .shared (np .random .default_rng (0 ))
99
+ x = normal (rng = rng )
100
+
101
+ # Test that "traditional" default_update is updated
102
+ rng .default_update = x .owner .outputs [0 ]
103
+ new_x = change_rv_size (x , new_size = (2 ,))
104
+ assert rng .default_update is not x .owner .outputs [0 ]
105
+ assert rng .default_update is new_x .owner .outputs [0 ]
106
+
107
+ # Test that "non-traditional" default_update is left unchanged
108
+ next_rng = aesara .shared (np .random .default_rng (1 ))
109
+ rng .default_update = next_rng
110
+ new_x = change_rv_size (x , new_size = (2 ,))
111
+ assert rng .default_update is next_rng
112
+
113
+ # Test that default_update is not set if there was none before
114
+ del rng .default_update
115
+ new_x = change_rv_size (x , new_size = (2 ,))
116
+ assert not hasattr (rng , "default_update" )
117
+
118
+
97
119
class TestBroadcasting :
98
120
def test_make_shared_replacements (self ):
99
121
"""Check if pm.make_shared_replacements preserves broadcasting."""
You can’t perform that action at this time.
0 commit comments