55
66from pymc import SymbolicRandomVariable
77from pytensor .compile import SharedVariable
8+ from pytensor .compile .builders import OpFromGraph
89from pytensor .graph import Constant , Variable , ancestors
910from pytensor .graph .basic import io_toposort
1011from pytensor .tensor import TensorType , TensorVariable
1617from pytensor .tensor .subtensor import AdvancedSubtensor , Subtensor , get_idx_list
1718from pytensor .tensor .type_other import NoneTypeT
1819
19- from pymc_extras .model .marginal .distributions import MarginalRV
20-
2120
2221def static_shape_ancestors (vars ):
2322 """Identify ancestors Shape Ops of static shapes (therefore constant in a valid graph)."""
@@ -63,7 +62,7 @@ def find_conditional_dependent_rvs(dependable_rv, all_rvs):
6362
6463
6564def get_support_axes (op ) -> tuple [tuple [int , ...], ...]:
66- if isinstance (op , MarginalRV ):
65+ if hasattr (op , "support_axes" ):
6766 return op .support_axes
6867 else :
6968 # For vanilla RVs, the support axes are the last ndim_supp
@@ -146,7 +145,7 @@ def _subgraph_batch_dim_connection(var_dims: VAR_DIMS, input_vars, output_vars)
146145 output_dims = tuple (None if i == "x" else input_dims [i ] for i in node .op .new_order )
147146 var_dims [node .outputs [0 ]] = output_dims
148147
149- elif isinstance (node .op , MarginalRV ) or (
148+ elif ( isinstance (node .op , OpFromGraph ) and hasattr ( node . op , "support_axes" ) ) or (
150149 isinstance (node .op , SymbolicRandomVariable ) and node .op .extended_signature is None
151150 ):
152151 # MarginalRV and SymbolicRandomVariables without signature are a wild-card,
@@ -160,7 +159,7 @@ def _subgraph_batch_dim_connection(var_dims: VAR_DIMS, input_vars, output_vars)
160159 )
161160
162161 support_axes = iter (get_support_axes (op ))
163- if isinstance (op , MarginalRV ):
162+ if hasattr (op , "support_axes" ):
164163 # The first output is the marginalized variable for which we don't compute support axes
165164 support_axes = itertools .chain (((),), support_axes )
166165 for i , (out , inner_out ) in enumerate (zip (node .outputs , inner_outputs )):
0 commit comments