Skip to content

Commit 82a1a24

Browse files
committed
Revert "refactor: Use infer_static_shape from pytensor"
This reverts commit a9f582c.
1 parent a9f582c commit 82a1a24

File tree

1 file changed

+6
-4
lines changed

1 file changed

+6
-4
lines changed

pymc_experimental/model/transforms/autoreparam.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121
)
2222
from pymc.pytensorf import toposort_replace
2323
from pytensor.graph.basic import Apply, Variable
24-
from pytensor.tensor.basic import infer_static_shape
2524
from pytensor.tensor.random.op import RandomVariable
2625

2726
_log = logging.getLogger("pmx")
@@ -177,9 +176,12 @@ def vip_reparam_node(
177176
) -> Tuple[ModelDeterministic, ModelNamed]:
178177
if not isinstance(node.op, RandomVariable | SymbolicRandomVariable):
179178
raise TypeError("Op should be RandomVariable type")
180-
rv = node.default_output()
181-
rv_shape_t, _ = infer_static_shape(rv.shape)
182-
rv_shape = pt.as_tensor(rv_shape_t).eval(mode="FAST_COMPILE")
179+
_, size, *_ = node.inputs
180+
eval_size = size.eval(mode="FAST_COMPILE")
181+
if eval_size is not None:
182+
rv_shape = tuple(eval_size)
183+
else:
184+
rv_shape = ()
183185
lam_name = f"{name}::lam_logit__"
184186
_log.debug(f"Creating {lam_name} with shape of {rv_shape}")
185187
logit_lam_ = pytensor.shared(

0 commit comments

Comments
 (0)