5
5
6
6
from pymc import SymbolicRandomVariable
7
7
from pytensor .compile import SharedVariable
8
+ from pytensor .compile .builders import OpFromGraph
8
9
from pytensor .graph import Constant , Variable , ancestors
9
10
from pytensor .graph .basic import io_toposort
10
11
from pytensor .tensor import TensorType , TensorVariable
16
17
from pytensor .tensor .subtensor import AdvancedSubtensor , Subtensor , get_idx_list
17
18
from pytensor .tensor .type_other import NoneTypeT
18
19
19
- from pymc_extras .model .marginal .distributions import MarginalRV
20
-
21
20
22
21
def static_shape_ancestors (vars ):
23
22
"""Identify ancestors Shape Ops of static shapes (therefore constant in a valid graph)."""
@@ -63,7 +62,7 @@ def find_conditional_dependent_rvs(dependable_rv, all_rvs):
63
62
64
63
65
64
def get_support_axes (op ) -> tuple [tuple [int , ...], ...]:
66
- if isinstance (op , MarginalRV ):
65
+ if hasattr (op , "support_axes" ):
67
66
return op .support_axes
68
67
else :
69
68
# For vanilla RVs, the support axes are the last ndim_supp
@@ -146,7 +145,7 @@ def _subgraph_batch_dim_connection(var_dims: VAR_DIMS, input_vars, output_vars)
146
145
output_dims = tuple (None if i == "x" else input_dims [i ] for i in node .op .new_order )
147
146
var_dims [node .outputs [0 ]] = output_dims
148
147
149
- elif isinstance (node .op , MarginalRV ) or (
148
+ elif ( isinstance (node .op , OpFromGraph ) and hasattr ( node . op , "support_axes" ) ) or (
150
149
isinstance (node .op , SymbolicRandomVariable ) and node .op .extended_signature is None
151
150
):
152
151
# MarginalRV and SymbolicRandomVariables without signature are a wild-card,
@@ -160,7 +159,7 @@ def _subgraph_batch_dim_connection(var_dims: VAR_DIMS, input_vars, output_vars)
160
159
)
161
160
162
161
support_axes = iter (get_support_axes (op ))
163
- if isinstance (op , MarginalRV ):
162
+ if hasattr (op , "support_axes" ):
164
163
# The first output is the marginalized variable for which we don't compute support axes
165
164
support_axes = itertools .chain (((),), support_axes )
166
165
for i , (out , inner_out ) in enumerate (zip (node .outputs , inner_outputs )):
0 commit comments