1+ import warnings
2+
13from collections .abc import Sequence
24
35import numpy as np
46import pytensor .tensor as pt
57
68from pymc .distributions import Bernoulli , Categorical , DiscreteUniform
9+ from pymc .distributions .distribution import _support_point , support_point
710from pymc .logprob .abstract import MeasurableOp , _logprob
811from pymc .logprob .basic import conditional_logp , logp
9- from pymc .pytensorf import constant_fold
12+ from pymc .model .fgraph import ModelVar
13+ from pymc .pytensorf import constant_fold , StringType
1014from pytensor import Variable
1115from pytensor .compile .builders import OpFromGraph
1216from pytensor .compile .mode import Mode
13- from pytensor .graph import Op , vectorize_graph
17+ from pytensor .graph import FunctionGraph , Op , vectorize_graph
18+ from pytensor .graph .basic import equal_computations , Apply
1419from pytensor .graph .replace import clone_replace , graph_replace
1520from pytensor .scan import map as scan_map
1621from pytensor .scan import scan
1722from pytensor .tensor import TensorVariable
23+ from pytensor .tensor .random .type import RandomType
1824
1925from pymc_extras .distributions import DiscreteMarkovChain
2026
2127
2228class MarginalRV (OpFromGraph , MeasurableOp ):
2329 """Base class for Marginalized RVs"""
2430
25- def __init__ (self , * args , dims_connections : tuple [tuple [int | None ]], ** kwargs ) -> None :
31+ def __init__ (self , * args , dims_connections : tuple [tuple [int | None ], ...], dims : tuple [ Variable , ... ], ** kwargs ) -> None :
2632 self .dims_connections = dims_connections
33+ self .dims = dims
2734 super ().__init__ (* args , ** kwargs )
2835
2936 @property
@@ -43,6 +50,74 @@ def support_axes(self) -> tuple[tuple[int]]:
4350 )
4451 return tuple (support_axes_vars )
4552
53+ def __eq__ (self , other ):
54+ # Just to allow easy testing of equivalent models,
55+ # This can be removed once https://github.com/pymc-devs/pytensor/issues/1114 is fixed
56+ if type (self ) is not type (other ):
57+ return False
58+
59+ return equal_computations (
60+ self .inner_outputs ,
61+ other .inner_outputs ,
62+ self .inner_inputs ,
63+ other .inner_inputs ,
64+ )
65+
66+ def __hash__ (self ):
67+ # Just to allow easy testing of equivalent models,
68+ # This can be removed once https://github.com/pymc-devs/pytensor/issues/1114 is fixed
69+ return hash ((type (self ), len (self .inner_inputs ), len (self .inner_outputs )))
70+
71+
@_support_point.register
def support_point_marginal_rv(op: MarginalRV, rv, *inputs):
    """Support point for a marginalized RV.

    The support point of a marginalized RV is the support point of the inner RV,
    conditioned on the marginalized RV taking its support point.

    Parameters
    ----------
    op
        The ``MarginalRV`` Op whose inner graph is inspected.
    rv
        One of the outputs of the node that applies ``op``; its position among
        ``rv.owner.outputs`` selects the matching inner output.
    *inputs
        Outer inputs, paired positionally with ``op.inner_inputs``.

    Returns
    -------
    The support point graph for ``rv``, with inner inputs replaced by ``inputs``.
    """
    outer_outputs = rv.owner.outputs

    # Inner counterpart of the requested outer output
    inner_rv = op.inner_outputs[outer_outputs.index(rv)]

    # Among the remaining inner outputs that are not RNG-typed, the first one is taken
    # as the marginalized RV and the rest as other dependent RVs.
    candidate_inner_rvs = [
        out
        for out in op.inner_outputs
        if not (out is inner_rv or isinstance(out.type, RandomType))
    ]
    marginalized_inner_rv, *other_dependent_inner_rvs = candidate_inner_rvs

    # Replace references to inner rvs by dummy variables (including the marginalized RV).
    # This is necessary because the inner RVs may depend on each other.
    marginalized_dummy = marginalized_inner_rv.clone()
    dependent_to_dummy = {
        dependent_rv: dependent_rv.clone() for dependent_rv in other_dependent_inner_rvs
    }
    inner_rv = clone_replace(
        inner_rv,
        replace={marginalized_inner_rv: marginalized_dummy} | dependent_to_dummy,
    )

    # Support points of the (dummy-fed) inner RV and of the marginalized RV
    inner_rv_support_point = support_point(inner_rv)
    marginalized_support_point = support_point(marginalized_inner_rv)

    # 1. The marginalized RV dummy is replaced by its support point
    replacements = [(marginalized_dummy, marginalized_support_point)]
    # 2. Dummies of the other dependent RVs are replaced by the respective outer outputs.
    #    PyMC will replace them by their support points later
    replacements.extend(
        (dummy, outer_outputs[op.inner_outputs.index(dependent_rv)])
        for dependent_rv, dummy in dependent_to_dummy.items()
    )
    # 3. Inner inputs are replaced by the outer inputs
    replacements.extend(zip(op.inner_inputs, inputs))

    fgraph = FunctionGraph(outputs=[inner_rv_support_point], clone=False)
    fgraph.replace_all(replacements, import_missing=True)
    [rv_support_point] = fgraph.outputs
    return rv_support_point
120+
46121
47122class MarginalFiniteDiscreteRV (MarginalRV ):
48123 """Base class for Marginalized Finite Discrete RVs"""
@@ -132,12 +207,27 @@ def inline_ofg_outputs(op: OpFromGraph, inputs: Sequence[Variable]) -> tuple[Var
132207 Whereas `OpFromGraph` "wraps" a graph inside a single Op, this function "unwraps"
133208 the inner graph.
134209 """
135- return clone_replace (
210+ return graph_replace (
136211 op .inner_outputs ,
137212 replace = tuple (zip (op .inner_inputs , inputs )),
213+ strict = False ,
138214 )
139215
140216
class NonSeparableLogpWarning(UserWarning):
    """Warning category for joint logp terms that are assigned to a single value."""
219+
220+
def warn_non_separable_logp(values):
    """Warn that the joint logp of several dependent values is not split up.

    Emits a ``NonSeparableLogpWarning`` naming ``values[0]`` when ``values`` holds
    more than one entry; does nothing otherwise.
    """
    if len(values) <= 1:
        return

    message = (
        "There are multiple dependent variables in a FiniteDiscreteMarginalRV. "
        f"Their joint logp terms will be assigned to the first value: {values[0]}."
    )
    # stacklevel=2 attributes the warning to the caller of this helper
    warnings.warn(message, NonSeparableLogpWarning, stacklevel=2)
229+
230+
# Placeholder logp term returned for every value variable after the first one; the
# joint logp is returned for the first value, and PyMC still expects one logp per value.
DUMMY_ZERO = pt.constant(0, name="dummy_zero")
142232
143233
@@ -199,6 +289,7 @@ def logp_fn(marginalized_rv_const, *non_sequences):
199289 # Align logp with non-collapsed batch dimensions of first RV
200290 joint_logp = align_logp_dims (dims = op .dims_connections [0 ], logp = joint_logp )
201291
292+ warn_non_separable_logp (values )
202293 # We have to add dummy logps for the remaining value variables, otherwise PyMC will raise
203294 dummy_logps = (DUMMY_ZERO ,) * (len (values ) - 1 )
204295 return joint_logp , * dummy_logps
@@ -272,5 +363,6 @@ def step_alpha(logp_emission, log_alpha, log_P):
272363
273364 # If there are multiple emission streams, we have to add dummy logps for the remaining value variables. The first
274365 # return is the joint probability of everything together, but PyMC still expects one logp for each emission stream.
366+ warn_non_separable_logp (values )
275367 dummy_logps = (DUMMY_ZERO ,) * (len (values ) - 1 )
276368 return joint_logp , * dummy_logps
0 commit comments