Skip to content

Commit 72534c7

Browse files
committed
Constant fold original RV.shape in graph of joined PartialObservedRV
Not doing this would lead to no default updates for draws of such RVs, due to the same RNG being reused twice.
1 parent 904a0ea commit 72534c7

File tree

4 files changed

+30
-3
lines changed

4 files changed

+30
-3
lines changed

pymc/distributions/distribution.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1337,7 +1337,8 @@ def create_partial_observed_rv(
13371337
ndim_supp=rv.owner.op.ndim_supp,
13381338
)(rv, mask)
13391339

1340-
joined_rv = pt.empty(rv.shape, dtype=rv.type.dtype)
1340+
[rv_shape] = constant_fold([rv.shape], raise_not_constant=False)
1341+
joined_rv = pt.empty(rv_shape, dtype=rv.type.dtype)
13411342
joined_rv = pt.set_subtensor(joined_rv[mask], unobserved_rv)
13421343
joined_rv = pt.set_subtensor(joined_rv[antimask], observed_rv)
13431344

pymc/pytensorf.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,6 @@
5656
RandomGeneratorSharedVariable,
5757
RandomStateSharedVariable,
5858
)
59-
from pytensor.tensor.rewriting.basic import topo_constant_folding
6059
from pytensor.tensor.rewriting.shape import ShapeFeature
6160
from pytensor.tensor.sharedvar import SharedVariable, TensorSharedVariable
6261
from pytensor.tensor.subtensor import AdvancedIncSubtensor, AdvancedIncSubtensor1
@@ -1015,7 +1014,8 @@ def constant_fold(
10151014
"""
10161015
fg = FunctionGraph(outputs=xs, features=[ShapeFeature()], clone=True)
10171016

1018-
folded_xs = rewrite_graph(fg, custom_rewrite=topo_constant_folding).outputs
1017+
# By default, rewrite_graph includes canonicalize which includes constant-folding as the final rewrite
1018+
folded_xs = rewrite_graph(fg).outputs
10191019

10201020
if raise_not_constant and not all(isinstance(folded_x, Constant) for folded_x in folded_xs):
10211021
raise NotConstantValueError

tests/distributions/test_distribution.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1061,3 +1061,17 @@ def test_wrong_mask(self):
10611061
invalid_mask = np.zeros((1, 5), dtype=bool)
10621062
with pytest.raises(ValueError, match="mask can't have more dims than rv"):
10631063
create_partial_observed_rv(rv, invalid_mask)
1064+
1065+
@pytest.mark.filterwarnings("error")
1066+
def test_default_updates(self):
1067+
mask = np.array([True, True, False])
1068+
rv = pm.Normal.dist(shape=(3,))
1069+
(obs_rv, _), (unobs_rv, _), joined_rv = create_partial_observed_rv(rv, mask)
1070+
1071+
draws_obs_rv, draws_unobs_rv, draws_joined_rv = pm.draw(
1072+
[obs_rv, unobs_rv, joined_rv], draws=2
1073+
)
1074+
1075+
assert np.all(draws_obs_rv[0] != draws_obs_rv[1])
1076+
assert np.all(draws_unobs_rv[0] != draws_unobs_rv[1])
1077+
assert np.all(draws_joined_rv[0] != draws_joined_rv[1])

tests/sampling/test_jax.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030

3131
import pymc as pm
3232

33+
from pymc import ImputationWarning
3334
from pymc.distributions.multivariate import PosDefMatrix
3435
from pymc.sampling.jax import (
3536
_get_batched_jittered_initial_points,
@@ -459,3 +460,14 @@ def test_idata_contains_stats(sampler_name: str):
459460
for stat_var, stat_var_dims in stat_vars.items():
460461
assert stat_var in stats.variables
461462
assert stats.get(stat_var).values.shape == stat_var_dims
463+
464+
465+
def test_sample_partially_observed():
466+
with pm.Model() as m:
467+
with pytest.warns(ImputationWarning):
468+
x = pm.Normal("x", observed=np.array([0, 1, np.nan]))
469+
idata = pm.sample(nuts_sampler="numpyro", chains=1, draws=10, tune=10)
470+
471+
assert idata.observed_data["x_observed"].shape == (2,)
472+
assert idata.posterior["x_unobserved"].shape == (1, 10, 1)
473+
assert idata.posterior["x"].shape == (1, 10, 3)

0 commit comments

Comments
 (0)