diff --git a/pymc_experimental/model/transforms/autoreparam.py b/pymc_experimental/model/transforms/autoreparam.py
index bb3996459..a27fc2dd4 100644
--- a/pymc_experimental/model/transforms/autoreparam.py
+++ b/pymc_experimental/model/transforms/autoreparam.py
@@ -1,3 +1,4 @@
+import logging
 from dataclasses import dataclass
 from functools import singledispatch
 from typing import Dict, List, Optional, Sequence, Tuple, Union
@@ -8,7 +9,6 @@
 import pytensor.tensor as pt
 import scipy.special
 from pymc.distributions import SymbolicRandomVariable
-from pymc.exceptions import NotConstantValueError
 from pymc.logprob.transforms import Transform
 from pymc.model.fgraph import (
     ModelDeterministic,
@@ -19,10 +19,12 @@
     model_from_fgraph,
     model_named,
 )
-from pymc.pytensorf import constant_fold, toposort_replace
+from pymc.pytensorf import toposort_replace
 from pytensor.graph.basic import Apply, Variable
 from pytensor.tensor.random.op import RandomVariable
 
+_log = logging.getLogger("pmx")
+
 
 @dataclass
 class VIP:
@@ -174,15 +176,19 @@
 ) -> Tuple[ModelDeterministic, ModelNamed]:
     if not isinstance(node.op, RandomVariable | SymbolicRandomVariable):
         raise TypeError("Op should be RandomVariable type")
-    rv = node.default_output()
-    try:
-        [rv_shape] = constant_fold([rv.shape])
-    except NotConstantValueError:
-        raise ValueError("Size should be static for autoreparametrization.")
+    # FIXME: This is wrong when size is None
+    _, size, *_ = node.inputs
+    eval_size = size.eval(mode="FAST_COMPILE")
+    if eval_size is not None:
+        rv_shape = tuple(eval_size)
+    else:
+        rv_shape = ()
+    lam_name = f"{name}::lam_logit__"
+    _log.debug(f"Creating {lam_name} with shape of {rv_shape}")
     logit_lam_ = pytensor.shared(
         np.zeros(rv_shape),
         shape=rv_shape,
-        name=f"{name}::lam_logit__",
+        name=lam_name,
     )
     logit_lam = model_named(logit_lam_, *dims)
     lam = pt.sigmoid(logit_lam)
diff --git a/tests/model/transforms/test_autoreparam.py b/tests/model/transforms/test_autoreparam.py
index 1d2173066..cb7176f70 100644
--- a/tests/model/transforms/test_autoreparam.py
+++ b/tests/model/transforms/test_autoreparam.py
@@ -7,20 +7,21 @@
 
 @pytest.fixture
 def model_c():
-    with pm.Model() as mod:
+    # TODO: Restructure tests so they check one dist at a time
+    with pm.Model(coords=dict(a=range(5))) as mod:
         m = pm.Normal("m")
         s = pm.LogNormal("s")
-        pm.Normal("g", m, s, shape=5)
+        pm.Normal("g", m, s, dims="a")
         pm.Exponential("e", scale=s, shape=7)
     return mod
 
 
 @pytest.fixture
 def model_nc():
-    with pm.Model() as mod:
+    with pm.Model(coords=dict(a=range(5))) as mod:
         m = pm.Normal("m")
         s = pm.LogNormal("s")
-        pm.Deterministic("g", pm.Normal("z", dims="a") * s + m)
+        pm.Deterministic("g", pm.Normal("z", dims="a") * s + m)
         pm.Deterministic("e", pm.Exponential("z_e", 1, shape=7) * s)
     return mod
 
@@ -102,3 +103,29 @@
     vip.truncate_lambda(g=0.2)
     np.testing.assert_allclose(vip.get_lambda()["g"], 1)
     np.testing.assert_allclose(vip.get_lambda()["m"], 0.9)
+
+
+@pytest.mark.xfail(reason="FIX shape computation for lambda")
+def test_lambda_shape():
+    with pm.Model(coords=dict(a=[1, 2])) as model:
+        b1 = pm.Normal("b1", dims="a")
+        b2 = pm.Normal("b2", shape=2)
+        b3 = pm.Normal("b3", size=2)
+        b4 = pm.Normal("b4", np.asarray([1, 2]))
+    model_v, vip = vip_reparametrize(model, ["b1", "b2", "b3", "b4"])
+    lams = vip.get_lambda()
+    for v in ["b1", "b2", "b3", "b4"]:
+        assert lams[v].shape == (2,), v
+
+
+@pytest.mark.xfail(reason="FIX shape computation for lambda")
+def test_lambda_shape_transformed_1d():
+    with pm.Model(coords=dict(a=[1, 2])) as model:
+        b1 = pm.Exponential("b1", 1, dims="a")
+        b2 = pm.Exponential("b2", 1, shape=2)
+        b3 = pm.Exponential("b3", 1, size=2)
+        b4 = pm.Exponential("b4", np.asarray([1, 2]))
+    model_v, vip = vip_reparametrize(model, ["b1", "b2", "b3", "b4"])
+    lams = vip.get_lambda()
+    for v in ["b1", "b2", "b3", "b4"]:
+        assert lams[v].shape == (2,), v
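For context, here is a minimal usage sketch (not part of the diff) showing how the changed code paths are exercised: it builds a model like the `model_c` fixture, reparametrizes one variable, and inspects the lambda values whose shape the FIXME comment and the new xfail tests are about. The import path is assumed from the module location in this diff.

```python
import pymc as pm

# Assumed import path, matching the module touched in this diff.
from pymc_experimental.model.transforms.autoreparam import vip_reparametrize

with pm.Model(coords=dict(a=range(5))) as model:
    m = pm.Normal("m")
    s = pm.LogNormal("s")
    pm.Normal("g", m, s, dims="a")

# Rewrite "g" so a learnable lambda interpolates between the centered and
# non-centered parametrizations; lambda is backed by the shared
# "<name>::lam_logit__" variable created in vip_reparam_node above.
model_v, vip = vip_reparametrize(model, ["g"])

lams = vip.get_lambda()
# Expected to be (5,) to match "g"; the xfail tests above show the current
# size-based computation can get this wrong, e.g. when the shape is implied
# by the parameters (the b4 case) rather than an explicit size/shape.
print(lams["g"].shape)
```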