Skip to content

Commit 025918f

Browse files
committed
Fix expand change_dist_size of SymbolicRandomVariables with size=None
1 parent a386c7a commit 025918f

File tree

3 files changed

+36
-3
lines changed

3 files changed

+36
-3
lines changed

pymc/distributions/distribution.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -400,7 +400,7 @@ def change_symbolic_rv_size(op: SymbolicRandomVariable, rv, new_size, expand) ->
400400

401401
params = op.dist_params(rv.owner)
402402

403-
if expand:
403+
if expand and not rv_size_is_none(size):
404404
new_size = tuple(new_size) + tuple(size)
405405

406406
return op.rv_op(*params, size=new_size)

tests/distributions/test_distribution.py

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,8 @@
2323
import scipy.stats as st
2424

2525
from pytensor import shared
26-
from pytensor.tensor import TensorVariable
26+
from pytensor.tensor import NoneConst, TensorVariable
27+
from pytensor.tensor.random.utils import normalize_size_param
2728

2829
import pymc as pm
2930

@@ -43,7 +44,7 @@
4344
)
4445
from pymc.distributions.shape_utils import change_dist_size
4546
from pymc.logprob.basic import conditional_logp, logp
46-
from pymc.pytensorf import compile
47+
from pymc.pytensorf import compile, normalize_rng_param
4748
from pymc.testing import (
4849
BaseTestDistributionRandom,
4950
I,
@@ -210,6 +211,27 @@ def test_recreate_with_different_rng_inputs(self):
210211
new_next_rng, new_x = x.owner.op(*inputs)
211212
assert op.update(new_x.owner) == {new_rng: new_next_rng}
212213

214+
def test_change_dist_size_none(self):
215+
class TestRV(SymbolicRandomVariable):
216+
extended_signature = "[rng],[size]->[rng],(n)"
217+
218+
@classmethod
219+
def rv_op(cls, size=None, rng=None):
220+
rng = normalize_rng_param(rng)
221+
size = normalize_size_param(size)
222+
next_rng, draws = Normal.dist(size=size, rng=rng).owner.outputs
223+
return cls(inputs=[rng, size], outputs=[next_rng, draws])(rng, size)
224+
225+
size = NoneConst
226+
rv = TestRV.rv_op(size=size)
227+
assert rv.type.shape == ()
228+
229+
resized_rv = change_dist_size(rv, new_size=5)
230+
assert resized_rv.type.shape == (5,)
231+
232+
resized_rv = change_dist_size(rv, new_size=5, expand=True)
233+
assert resized_rv.type.shape == (5,)
234+
213235

214236
def test_tag_future_warning_dist():
215237
# Test no unexpected warnings

tests/distributions/test_shape_utils.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -427,6 +427,17 @@ def test_change_rv_size():
427427
assert tuple(rv_newer.shape.eval()) == (2,)
428428

429429

430+
def test_change_rv_size_expand_none_size():
431+
x = pt.random.normal()
432+
size = x.owner.op.size_param(x.owner)
433+
assert rv_size_is_none(size)
434+
new_x = change_dist_size(x, new_size=(2,), expand=True)
435+
new_size = new_x.owner.op.size_param(new_x.owner)
436+
assert not rv_size_is_none(new_size)
437+
assert new_size.data == [2]
438+
assert new_x.type.shape == (2,)
439+
440+
430441
def test_change_rv_size_default_update():
431442
rng = pytensor.shared(np.random.default_rng(0))
432443
x = normal(rng=rng)

0 commit comments

Comments
 (0)