File tree Expand file tree Collapse file tree 1 file changed +6
-4
lines changed
pymc_experimental/model/transforms Expand file tree Collapse file tree 1 file changed +6
-4
lines changed Original file line number Diff line number Diff line change 21
21
)
22
22
from pymc .pytensorf import toposort_replace
23
23
from pytensor .graph .basic import Apply , Variable
24
- from pytensor .tensor .basic import infer_static_shape
25
24
from pytensor .tensor .random .op import RandomVariable
26
25
27
26
_log = logging .getLogger ("pmx" )
@@ -177,9 +176,12 @@ def vip_reparam_node(
177
176
) -> Tuple [ModelDeterministic , ModelNamed ]:
178
177
if not isinstance (node .op , RandomVariable | SymbolicRandomVariable ):
179
178
raise TypeError ("Op should be RandomVariable type" )
180
- rv = node .default_output ()
181
- rv_shape_t , _ = infer_static_shape (rv .shape )
182
- rv_shape = pt .as_tensor (rv_shape_t ).eval (mode = "FAST_COMPILE" )
179
+ _ , size , * _ = node .inputs
180
+ eval_size = size .eval (mode = "FAST_COMPILE" )
181
+ if eval_size is not None :
182
+ rv_shape = tuple (eval_size )
183
+ else :
184
+ rv_shape = ()
183
185
lam_name = f"{ name } ::lam_logit__"
184
186
_log .debug (f"Creating { lam_name } with shape of { rv_shape } " )
185
187
logit_lam_ = pytensor .shared (
You can’t perform that action at this time.
0 commit comments