|
23 | 23 | import scipy.stats as st
|
24 | 24 |
|
25 | 25 | from pytensor import shared
|
26 |
| -from pytensor.tensor import TensorVariable |
| 26 | +from pytensor.tensor import NoneConst, TensorVariable |
| 27 | +from pytensor.tensor.random.utils import normalize_size_param |
27 | 28 |
|
28 | 29 | import pymc as pm
|
29 | 30 |
|
|
43 | 44 | )
|
44 | 45 | from pymc.distributions.shape_utils import change_dist_size
|
45 | 46 | from pymc.logprob.basic import conditional_logp, logp
|
46 |
| -from pymc.pytensorf import compile |
| 47 | +from pymc.pytensorf import compile, normalize_rng_param |
47 | 48 | from pymc.testing import (
|
48 | 49 | BaseTestDistributionRandom,
|
49 | 50 | I,
|
@@ -210,6 +211,27 @@ def test_recreate_with_different_rng_inputs(self):
|
210 | 211 | new_next_rng, new_x = x.owner.op(*inputs)
|
211 | 212 | assert op.update(new_x.owner) == {new_rng: new_next_rng}
|
212 | 213 |
|
| 214 | + def test_change_dist_size_none(self): |
| 215 | + class TestRV(SymbolicRandomVariable): |
| 216 | + extended_signature = "[rng],[size]->[rng],(n)" |
| 217 | + |
| 218 | + @classmethod |
| 219 | + def rv_op(cls, size=None, rng=None): |
| 220 | + rng = normalize_rng_param(rng) |
| 221 | + size = normalize_size_param(size) |
| 222 | + next_rng, draws = Normal.dist(size=size, rng=rng).owner.outputs |
| 223 | + return cls(inputs=[rng, size], outputs=[next_rng, draws])(rng, size) |
| 224 | + |
| 225 | + size = NoneConst |
| 226 | + rv = TestRV.rv_op(size=size) |
| 227 | + assert rv.type.shape == () |
| 228 | + |
| 229 | + resized_rv = change_dist_size(rv, new_size=5) |
| 230 | + assert resized_rv.type.shape == (5,) |
| 231 | + |
| 232 | + resized_rv = change_dist_size(rv, new_size=5, expand=True) |
| 233 | + assert resized_rv.type.shape == (5,) |
| 234 | + |
213 | 235 |
|
214 | 236 | def test_tag_future_warning_dist():
|
215 | 237 | # Test no unexpected warnings
|
|
0 commit comments