Skip to content

Commit 5488666

Browse files
committed
Fix circular import
1 parent bf453ff commit 5488666

File tree

2 files changed

+12
-6
lines changed

2 files changed

+12
-6
lines changed

pymc_extras/model/marginal/distributions.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,14 @@
1717
from pytensor.tensor import TensorVariable
1818

1919
from pymc_extras.distributions import DiscreteMarkovChain
20-
from pymc_extras.model.marginal.graph_analysis import get_support_axes
20+
21+
22+
def get_support_axes(op) -> tuple[tuple[int, ...], ...]:
23+
if hasattr(op, "support_axes"):
24+
return op.support_axes
25+
else:
26+
# For vanilla RVs, the support axes are the last ndim_supp
27+
return (tuple(range(-op.ndim_supp, 0)),)
2128

2229

2330
class MarginalRV(OpFromGraph, MeasurableOp):

pymc_extras/model/marginal/graph_analysis.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
from pymc import SymbolicRandomVariable
77
from pytensor.compile import SharedVariable
8+
from pytensor.compile.builders import OpFromGraph
89
from pytensor.graph import Constant, Variable, ancestors
910
from pytensor.graph.basic import io_toposort
1011
from pytensor.tensor import TensorType, TensorVariable
@@ -16,8 +17,6 @@
1617
from pytensor.tensor.subtensor import AdvancedSubtensor, Subtensor, get_idx_list
1718
from pytensor.tensor.type_other import NoneTypeT
1819

19-
from pymc_extras.model.marginal.distributions import MarginalRV
20-
2120

2221
def static_shape_ancestors(vars):
2322
"""Identify ancestors Shape Ops of static shapes (therefore constant in a valid graph)."""
@@ -63,7 +62,7 @@ def find_conditional_dependent_rvs(dependable_rv, all_rvs):
6362

6463

6564
def get_support_axes(op) -> tuple[tuple[int, ...], ...]:
66-
if isinstance(op, MarginalRV):
65+
if hasattr(op, "support_axes"):
6766
return op.support_axes
6867
else:
6968
# For vanilla RVs, the support axes are the last ndim_supp
@@ -146,7 +145,7 @@ def _subgraph_batch_dim_connection(var_dims: VAR_DIMS, input_vars, output_vars)
146145
output_dims = tuple(None if i == "x" else input_dims[i] for i in node.op.new_order)
147146
var_dims[node.outputs[0]] = output_dims
148147

149-
elif isinstance(node.op, MarginalRV) or (
148+
elif (isinstance(node.op, OpFromGraph) and hasattr(node.op, "support_axes")) or (
150149
isinstance(node.op, SymbolicRandomVariable) and node.op.extended_signature is None
151150
):
152151
# MarginalRV and SymbolicRandomVariables without signature are a wild-card,
@@ -160,7 +159,7 @@ def _subgraph_batch_dim_connection(var_dims: VAR_DIMS, input_vars, output_vars)
160159
)
161160

162161
support_axes = iter(get_support_axes(op))
163-
if isinstance(op, MarginalRV):
162+
if hasattr(op, "support_axes"):
164163
# The first output is the marginalized variable for which we don't compute support axes
165164
support_axes = itertools.chain(((),), support_axes)
166165
for i, (out, inner_out) in enumerate(zip(node.outputs, inner_outputs)):

0 commit comments

Comments
 (0)