Skip to content

Commit 94c04a1

Browse files
Merge branch 'main' into pre-commit-update
2 parents e6bf0e3 + 7ce4ac8 commit 94c04a1

File tree

2 files changed

+45
-12
lines changed

2 files changed

+45
-12
lines changed

pymc_experimental/model/transforms/autoreparam.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from collections.abc import Sequence
2+
import logging
23
from dataclasses import dataclass
34
from functools import singledispatch
45

@@ -9,7 +10,6 @@
910
import scipy.special
1011

1112
from pymc.distributions import SymbolicRandomVariable
12-
from pymc.exceptions import NotConstantValueError
1313
from pymc.logprob.transforms import Transform
1414
from pymc.model.fgraph import (
1515
ModelDeterministic,
@@ -20,10 +20,12 @@
2020
model_from_fgraph,
2121
model_named,
2222
)
23-
from pymc.pytensorf import constant_fold, toposort_replace
23+
from pymc.pytensorf import toposort_replace
2424
from pytensor.graph.basic import Apply, Variable
2525
from pytensor.tensor.random.op import RandomVariable
2626

27+
_log = logging.getLogger("pmx")
28+
2729

2830
@dataclass
2931
class VIP:
@@ -175,15 +177,19 @@ def vip_reparam_node(
175177
) -> tuple[ModelDeterministic, ModelNamed]:
176178
if not isinstance(node.op, RandomVariable | SymbolicRandomVariable):
177179
raise TypeError("Op should be RandomVariable type")
178-
rv = node.default_output()
179-
try:
180-
[rv_shape] = constant_fold([rv.shape])
181-
except NotConstantValueError:
182-
raise ValueError("Size should be static for autoreparametrization.")
180+
# FIXME: This is wrong when size is None
181+
_, size, *_ = node.inputs
182+
eval_size = size.eval(mode="FAST_COMPILE")
183+
if eval_size is not None:
184+
rv_shape = tuple(eval_size)
185+
else:
186+
rv_shape = ()
187+
lam_name = f"{name}::lam_logit__"
188+
_log.debug(f"Creating {lam_name} with shape of {rv_shape}")
183189
logit_lam_ = pytensor.shared(
184190
np.zeros(rv_shape),
185191
shape=rv_shape,
186-
name=f"{name}::lam_logit__",
192+
name=lam_name,
187193
)
188194
logit_lam = model_named(logit_lam_, *dims)
189195
lam = pt.sigmoid(logit_lam)

tests/model/transforms/test_autoreparam.py

Lines changed: 31 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,20 +7,21 @@
77

88
@pytest.fixture
99
def model_c():
10-
with pm.Model() as mod:
10+
# TODO: Restructure tests so they check one dist at a time
11+
with pm.Model(coords=dict(a=range(5))) as mod:
1112
m = pm.Normal("m")
1213
s = pm.LogNormal("s")
13-
pm.Normal("g", m, s, shape=5)
14+
pm.Normal("g", m, s, dims="a")
1415
pm.Exponential("e", scale=s, shape=7)
1516
return mod
1617

1718

1819
@pytest.fixture
1920
def model_nc():
20-
with pm.Model() as mod:
21+
with pm.Model(coords=dict(a=range(5))) as mod:
2122
m = pm.Normal("m")
2223
s = pm.LogNormal("s")
23-
pm.Deterministic("g", pm.Normal("z", shape=5) * s + m)
24+
pm.Deterministic("g", pm.Normal("z", dims="a") * s + m)
2425
pm.Deterministic("e", pm.Exponential("z_e", 1, shape=7) * s)
2526
return mod
2627

@@ -102,3 +103,29 @@ def test_set_truncate(model_c: pm.Model):
102103
vip.truncate_lambda(g=0.2)
103104
np.testing.assert_allclose(vip.get_lambda()["g"], 1)
104105
np.testing.assert_allclose(vip.get_lambda()["m"], 0.9)
106+
107+
108+
@pytest.mark.xfail(reason="FIX shape computation for lambda")
109+
def test_lambda_shape():
110+
with pm.Model(coords=dict(a=[1, 2])) as model:
111+
b1 = pm.Normal("b1", dims="a")
112+
b2 = pm.Normal("b2", shape=2)
113+
b3 = pm.Normal("b3", size=2)
114+
b4 = pm.Normal("b4", np.asarray([1, 2]))
115+
model_v, vip = vip_reparametrize(model, ["b1", "b2", "b3", "b4"])
116+
lams = vip.get_lambda()
117+
for v in ["b1", "b2", "b3", "b4"]:
118+
assert lams[v].shape == (2,), v
119+
120+
121+
@pytest.mark.xfail(reason="FIX shape computation for lambda")
122+
def test_lambda_shape_transformed_1d():
123+
with pm.Model(coords=dict(a=[1, 2])) as model:
124+
b1 = pm.Exponential("b1", 1, dims="a")
125+
b2 = pm.Exponential("b2", 1, shape=2)
126+
b3 = pm.Exponential("b3", 1, size=2)
127+
b4 = pm.Exponential("b4", np.asarray([1, 2]))
128+
model_v, vip = vip_reparametrize(model, ["b1", "b2", "b3", "b4"])
129+
lams = vip.get_lambda()
130+
for v in ["b1", "b2", "b3", "b4"]:
131+
assert lams[v].shape == (2,), v

0 commit comments

Comments (0)