import pytensor.tensor as pt

from pymc.distributions import Bernoulli, Categorical, DiscreteUniform, SymbolicRandomVariable
-from pymc.logprob.abstract import _logprob
+from pymc.logprob.abstract import _logprob, MeasurableOp
from pymc.logprob.basic import conditional_logp, logp
from pymc.pytensorf import constant_fold
+from pytensor.compile.builders import OpFromGraph
from pytensor.compile.mode import Mode
-from pytensor.graph import vectorize_graph
+from pytensor.graph import vectorize_graph, Op
from pytensor.graph.replace import clone_replace, graph_replace
from pytensor.scan import map as scan_map
from pytensor.scan import scan

from pymc_experimental.distributions import DiscreteMarkovChain


-class MarginalRV(SymbolicRandomVariable):
+class MarginalRV(OpFromGraph, MeasurableOp):
    """Base class for Marginalized RVs"""

+    def __init__(self, *args, dims_connections: tuple[tuple[int | None]], **kwargs) -> None:
+        self.dims_connections = dims_connections
+        super().__init__(*args, **kwargs)


-class FiniteDiscreteMarginalRV(MarginalRV):
-    """Base class for Finite Discrete Marginalized RVs"""

+class MarginalFiniteDiscreteRV(MarginalRV):
+    """Base class for Marginalized Finite Discrete RVs"""


-class DiscreteMarginalMarkovChainRV(MarginalRV):
-    """Base class for Discrete Marginal Markov Chain RVs"""
+
+class MarginalDiscreteMarkovChainRV(MarginalRV):
+    """Base class for Marginalized Discrete Markov Chain RVs"""


def get_domain_of_finite_discrete_rv(rv: TensorVariable) -> tuple[int, ...]:
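The body of `get_domain_of_finite_discrete_rv` falls between the hunks shown here. As a rough guide, a hedged sketch of the domains it is expected to enumerate for the finite discrete distributions imported above (the parameter values below are made up for illustration and are not part of the PR):

```python
import pymc as pm

bern = pm.Bernoulli.dist(p=0.3)
cat = pm.Categorical.dist(p=[0.2, 0.5, 0.3])
dunif = pm.DiscreteUniform.dist(lower=2, upper=5)

# Expected domains enumerated for marginalization (assumed, not shown in this diff):
# get_domain_of_finite_discrete_rv(bern)  -> (0, 1)
# get_domain_of_finite_discrete_rv(cat)   -> (0, 1, 2)
# get_domain_of_finite_discrete_rv(dunif) -> (2, 3, 4, 5)
```

Anything outside the supported ops hits the `NotImplementedError` raised at the top of the next hunk.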
@@ -48,24 +53,69 @@ def get_domain_of_finite_discrete_rv(rv: TensorVariable) -> tuple[int, ...]:
    raise NotImplementedError(f"Cannot compute domain for op {op}")


-def _add_reduce_batch_dependent_logps(
-    marginalized_type: TensorType, dependent_logps: Sequence[TensorVariable]
-):
-    """Add the logps of dependent RVs while reducing extra batch dims relative to `marginalized_type`."""
+def _reduce_batch_dependent_logps(
+    marginalized_op: MarginalRV,
+    marginalized_logp: TensorVariable,
+    dependent_ops: Sequence[Op],
+    dependent_logps: Sequence[TensorVariable],
+) -> TensorVariable:
+    """Combine the logps of dependent RVs with the marginalized logp.
+
+    This requires reducing extra batch dims and transposing when they are not aligned.
+
+    idx = pm.Bernoulli("idx", p=0.7, shape=(3, 2))  # 0, 1
+    pm.Normal("dep1", mu=idx.T[..., None] * 2, shape=(3, 2, 5))
+    pm.Normal("dep2", mu=idx * 2, shape=(7, 2, 3))
+
+    marginalize(idx)
+    dims_connections = [(1, 0, None), (None, 0, 1)]
+    """
+
+    dims_connections = marginalized_op.dims_connections
+
+    reduced_logps = [marginalized_logp]
+    for dependent_op, dependent_logp, dims_connection in zip(dependent_ops, dependent_logps, dims_connections):
+        if dependent_logp.type.ndim > 0:
+            # Find which support axes implied by the MarginalRV need to be reduced.
+            # Some may have already been reduced by the logp expression of the dependent RV, for non-univariate RVs.
+            if isinstance(dependent_op, MarginalRV):
+                dep_dims_connection = dependent_op.dims_connections[0]
+                dep_supp_axes = {-i for i, dim in enumerate(reversed(dep_dims_connection), start=1) if dim == ()}
+            else:
+                # For vanilla RVs, the support axes are the last ndim_supp axes.
+                dep_supp_axes = set(range(-dependent_op.ndim_supp, 0))
+
+            # Dependent RV support axes are already collapsed in the logp, so we ignore them.
+            supp_axes = [
+                -i
+                for i, dim in enumerate(reversed(dims_connection), start=1)
+                if (dim == () and -i not in dep_supp_axes)
+            ]
+
+            dependent_logp = dependent_logp.sum(supp_axes)
+            assert dependent_logp.type.ndim == marginalized_logp.type.ndim

-    mbcast = marginalized_type.broadcastable
-    reduced_logps = []
-    for dependent_logp in dependent_logps:
-        dbcast = dependent_logp.type.broadcastable
-        dim_diff = len(dbcast) - len(mbcast)
-        mbcast_aligned = mbcast + (True,) * dim_diff
-        vbcast_axis = [i for i, (m, v) in enumerate(zip(mbcast_aligned, dbcast)) if m and not v]
-        reduced_logps.append(dependent_logp.sum(vbcast_axis))
-    return pt.add(*reduced_logps)
+            # Finally, align the dependent logp batch dimensions with those of the marginalized logp.
+            dims_alignment = [dim[0] for dim in dims_connection if dim != ()]
+            dependent_logp = dependent_logp.transpose(*dims_alignment)

+        reduced_logps.append(dependent_logp)

-@_logprob.register(FiniteDiscreteMarginalRV)
-def finite_discrete_marginal_rv_logp(op, values, *inputs, **kwargs):
+    reduced_logp = pt.add(*reduced_logps)
+
+    if reduced_logp.type.ndim > 0:
+        # Transpose the reduced logp into the dimension order of the first dependent RV.
+        first_dims_alignment = [dim[0] for dim in dims_connections[0] if dim != ()]
+        reduced_logp = reduced_logp.transpose(*first_dims_alignment)
+
+    return reduced_logp
+
+
+dummy_zero = pt.constant(0, name="dummy_zero")
+
+
+@_logprob.register(MarginalFiniteDiscreteRV)
+def finite_discrete_marginal_rv_logp(op: MarginalFiniteDiscreteRV, values, *inputs, **kwargs):
    # Clone the inner RV graph of the Marginalized RV
    marginalized_rvs_node = op.make_node(*inputs)
    marginalized_rv, *inner_rvs = clone_replace(
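A minimal, self-contained sketch of what the reduction and transposition inside `_reduce_batch_dependent_logps` amount to, with shapes assumed purely for illustration (they are not taken from the PR): a dependent logp whose batch axes arrive transposed relative to the marginalized logp, plus one extra axis along which the marginalized variable was broadcast. The unmatched axis is summed away and the remaining axes are realigned before the logps are added.

```python
import numpy as np
import pytensor.tensor as pt

rng = np.random.default_rng(0)

# Assumed shapes: marginalized logp has batch shape (3, 2).
marginalized_logp = pt.as_tensor(rng.normal(size=(3, 2)))

# Dependent logp: axis 0 maps to marginalized dim 1, axis 1 to dim 0,
# axis 2 has no counterpart (the marginalized RV was broadcast along it).
dependent_logp = pt.as_tensor(rng.normal(size=(2, 3, 5)))

reduced = dependent_logp.sum(axis=-1)    # sum the unmatched axis -> (2, 3)
aligned = reduced.transpose(1, 0)        # realign the batch axes -> (3, 2)

joint_term = marginalized_logp + aligned
print(joint_term.eval().shape)           # (3, 2)
```

The `dims_connections` attribute is what tells the helper, per dependent RV, which of its axes map to which marginalized dims and which must be collapsed.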
@@ -81,8 +131,11 @@ def finite_discrete_marginal_rv_logp(op, values, *inputs, **kwargs):

    # Reduce logp dimensions corresponding to broadcasted variables
    marginalized_logp = logps_dict.pop(marginalized_vv)
-    joint_logp = marginalized_logp + _add_reduce_batch_dependent_logps(
-        marginalized_rv.type, logps_dict.values()
+    joint_logp = _reduce_batch_dependent_logps(
+        marginalized_op=op,
+        marginalized_logp=marginalized_logp,
+        dependent_ops=[inner_rv.owner.op for inner_rv in inner_rvs],
+        dependent_logps=[logps_dict[value] for value in values],
    )

    # Compute the joint_logp for all possible n values of the marginalized RV. We assume
@@ -119,21 +172,20 @@ def logp_fn(marginalized_rv_const, *non_sequences):
        mode=Mode().including("local_remove_check_parameter"),
    )

-    joint_logps = pt.logsumexp(joint_logps, axis=0)
+    joint_logp = pt.logsumexp(joint_logps, axis=0)

    # We have to add dummy logps for the remaining value variables, otherwise PyMC will raise
-    return joint_logps, *(pt.constant(0),) * (len(values) - 1)
+    return joint_logp, *((dummy_zero,) * (len(values) - 1))


-@_logprob.register(DiscreteMarginalMarkovChainRV)
+@_logprob.register(MarginalDiscreteMarkovChainRV)
def marginal_hmm_logp(op, values, *inputs, **kwargs):
    marginalized_rvs_node = op.make_node(*inputs)
-    inner_rvs = clone_replace(
+    chain_rv, *dependent_rvs = clone_replace(
        op.inner_outputs,
        replace={u: v for u, v in zip(op.inner_inputs, marginalized_rvs_node.inputs)},
    )

-    chain_rv, *dependent_rvs = inner_rvs
    P, n_steps_, init_dist_, rng = chain_rv.owner.inputs
    domain = pt.arange(P.shape[-1], dtype="int32")

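The `logsumexp` renamed in this hunk closes the enumeration step of the marginalization: the joint logp is evaluated once per value in the marginalized RV's domain (stacked along a new leading axis) and then log-sum-exp-ed over that axis. A hedged, standalone sketch with a Bernoulli marginalized out of a Normal (the distributions and parameters are illustrative only, not the PR's own test case):

```python
import pytensor.tensor as pt
from pymc import logp
from pymc.distributions import Bernoulli, Normal

idx = Bernoulli.dist(p=0.3)
observed = pt.as_tensor(1.2)

# Joint logp at every point of the marginalized domain {0, 1}:
# logp(idx = k) + logp(dep | idx = k), with dep ~ Normal(mu=2 * k, sigma=1).
joint_logps = pt.stack(
    [logp(idx, k) + logp(Normal.dist(mu=2.0 * k), observed) for k in (0, 1)]
)

# Marginal logp of the dependent value: logsumexp over the enumerated axis.
marginal_logp = pt.logsumexp(joint_logps, axis=0)
print(marginal_logp.eval())
```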
@@ -149,9 +201,11 @@ def marginal_hmm_logp(op, values, *inputs, **kwargs):

    # Reduce and add the batch dims beyond the chain dimension
    reduced_logp_emissions = _add_reduce_batch_dependent_logps(
+        init_logp,
        chain_rv.type, logp_emissions_dict.values()
    )

+
    # Add a batch dimension for the domain of the chain
    chain_shape = constant_fold(tuple(chain_rv.shape))
    batch_chain_value = pt.moveaxis(pt.full((*chain_shape, domain.size), domain), -1, 0)
@@ -188,7 +242,9 @@ def step_alpha(logp_emission, log_alpha, log_P):
    # Final logp is just the sum of the last scan state
    joint_logp = pt.logsumexp(log_alpha_seq[-1], axis=0)

+    # TODO: Transpose into shape of first emission
+
    # If there are multiple emission streams, we have to add dummy logps for the remaining value variables. The first
-    # return is the joint probability of everything together, but PyMC still expects one logp for each one.
-    dummy_logps = (pt.constant(0),) * (len(values) - 1)
+    # return is the joint probability of everything together, but PyMC still expects one logp for each emission stream.
+    dummy_logps = (dummy_zero,) * (len(values) - 1)
    return joint_logp, *dummy_logps
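For reference, a hedged sketch of the forward-algorithm recursion that `step_alpha` (named in the hunk header above) implements, and that the final `logsumexp` closes over. The transition matrix, initial state logps, and emission logps below are made-up numbers; only the recursion structure mirrors the scan in `marginal_hmm_logp`, which additionally handles batch dimensions.

```python
import numpy as np
import pytensor.tensor as pt

# Two hidden states, three steps; all values are illustrative.
log_P = pt.log(pt.as_tensor(np.array([[0.9, 0.1], [0.2, 0.8]])))  # P[i, j] = p(s_t = j | s_{t-1} = i)
log_alpha = pt.as_tensor(np.log(np.array([0.6, 0.4])))            # initial state logps
logp_emissions = pt.as_tensor(np.log(np.array([[0.5, 0.1],
                                                [0.3, 0.4],
                                                [0.2, 0.2]])))     # (steps, states)


def step_alpha(logp_emission, log_alpha, log_P):
    # alpha_t(j) = emission_t(j) * sum_i alpha_{t-1}(i) * P[i, j], computed in log space
    return logp_emission + pt.logsumexp(log_alpha[:, None] + log_P, axis=0)


for t in range(3):
    log_alpha = step_alpha(logp_emissions[t], log_alpha, log_P)

# Marginal logp of the emissions: sum over the hidden state of the last alpha.
print(pt.logsumexp(log_alpha, axis=0).eval())
```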