@@ -54,12 +54,11 @@ def get_domain_of_finite_discrete_rv(rv: TensorVariable) -> tuple[int, ...]:
def _reduce_batch_dependent_logps(
-    marginalized_op: MarginalRV,
-    marginalized_logp: TensorVariable,
+    dependent_dims_connections: Sequence[tuple[int | None, ...]],
    dependent_ops: Sequence[Op],
    dependent_logps: Sequence[TensorVariable],
) -> TensorVariable:
-    """Combine the logps of dependent RVs with the marginalized logp.
+    """Combine the logps of dependent RVs and align them with the marginalized logp.

    This requires reducing extra batch dims and transposing when they are not aligned.

@@ -68,13 +67,14 @@ def _reduce_batch_dependent_logps(
        pm.Normal("dep2", mu=idx * 2, shape=(7, 2, 3))

        marginalize(idx)
-        dims_connections = [(1, 0, None), (None, 0, 1)]
-    """

-    dims_connections = marginalized_op.dims_connections
+        The marginalized op will have dims_connections = [(1, 0, None), (None, 0, 1)]
+        which tells us we need to reduce the last axis of dep1 logp and the first of dep2,
+        as well as transpose the remaining axis of dep1 logp before adding the two elemwise.
+    """

-    reduced_logps = [marginalized_logp]
-    for dependent_op, dependent_logp, dims_connection in zip(dependent_ops, dependent_logps, dims_connections):
+    reduced_logps = []
+    for dependent_op, dependent_logp, dependent_dims_connection in zip(dependent_ops, dependent_logps, dependent_dims_connections):
        if dependent_logp.type.ndim > 0:
            # Find which support axis implied by the MarginalRV need to be reduced
            # Some may have already been reduced by the logp expression of the dependent RV, for non-univariate RVs
@@ -88,27 +88,27 @@ def _reduce_batch_dependent_logps(
            # Dependent RV support axes are already collapsed in the logp, so we ignore them
            supp_axes = [
                -i
-                for i, dim in enumerate(reversed(dims_connection), start=1)
+                for i, dim in enumerate(reversed(dependent_dims_connection), start=1)
                if (dim == () and -i not in dep_supp_axes)
            ]

            dependent_logp = dependent_logp.sum(supp_axes)
-            assert dependent_logp.type.ndim == marginalized_logp.type.ndim

            # Finally, we need to align the dependent logp batch dimensions with the marginalized logp
-            dims_alignment = [dim[0] for dim in dims_connection if dim != ()]
+            dims_alignment = [dim[0] for dim in dependent_dims_connection if dim != ()]
            dependent_logp = dependent_logp.transpose(*dims_alignment)

        reduced_logps.append(dependent_logp)

    reduced_logp = pt.add(*reduced_logps)
+    return reduced_logp

-    if reduced_logp.type.ndim > 0:
+def _align_logp_with_dims(dims: tuple[tuple[int, None]], logp: TensorVariable) -> TensorVariable:
+    if logp.type.ndim > 0:
        # Transpose reduced logp into the direction of the first dependent RV
-        first_dims_alignment = [dim[0] for dim in dims_connections[0] if dim != ()]
-        reduced_logp = reduced_logp.transpose(*first_dims_alignment)
-
-    return reduced_logp
+        dims_alignment = [dim[0] for dim in dims if dim != ()]
+        logp = logp.transpose(*dims_alignment)
+    return logp


dummy_zero = pt.constant(0, name="dummy_zero")
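For intuition, here is a small standalone NumPy sketch (not part of the diff) of the pattern that _reduce_batch_dependent_logps and _align_logp_with_dims implement: each dependent logp first sums over the batch axes that are not connected to the marginalized RV, is then transposed so the surviving axes follow a common order, and only then are the terms added elemwise. The shapes and hard-coded axes below are illustrative assumptions, not values read from dims_connections.

import numpy as np

rng = np.random.default_rng(0)

marg_shape = (3, 2)                      # batch shape of the marginalized RV
dep1_logp = rng.normal(size=(2, 3, 5))   # axes: (marg dim 1, marg dim 0, extra)
dep2_logp = rng.normal(size=(7, 3, 2))   # axes: (extra, marg dim 0, marg dim 1)

# 1) Reduce the batch axes that are not connected to the marginalized RV
dep1_reduced = dep1_logp.sum(axis=-1)    # drop the trailing "extra" axis -> (2, 3)
dep2_reduced = dep2_logp.sum(axis=0)     # drop the leading "extra" axis  -> (3, 2)

# 2) Transpose so the surviving axes follow the marginalized RV's axis order
dep1_aligned = dep1_reduced.transpose(1, 0)   # (marg dim 0, marg dim 1) -> (3, 2)

# 3) Only now can the per-RV terms be added elemwise
joint = dep1_aligned + dep2_reduced
assert joint.shape == marg_shape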
@@ -131,9 +131,8 @@ def finite_discrete_marginal_rv_logp(op: MarginalFiniteDiscreteRV, values, *inpu
    # Reduce logp dimensions corresponding to broadcasted variables
    marginalized_logp = logps_dict.pop(marginalized_vv)
-    joint_logp = _reduce_batch_dependent_logps(
-        marginalized_op=op,
-        marginalized_logp=marginalized_logp,
+    joint_logp = marginalized_logp + _reduce_batch_dependent_logps(
+        dependent_dims_connections=op.dims_connections,
        dependent_ops=[inner_rv.owner.op for inner_rv in inner_rvs],
        dependent_logps=[logps_dict[value] for value in values],
    )
@@ -174,8 +173,12 @@ def logp_fn(marginalized_rv_const, *non_sequences):
    joint_logp = pt.logsumexp(joint_logps, axis=0)

+    # Align logp with non-collapsed batch dimensions of first RV
+    joint_logp = _align_logp_with_dims(dims=op.dims_connections[0], logp=joint_logp)
+
    # We have to add dummy logps for the remaining value variables, otherwise PyMC will raise
-    return joint_logp, *((dummy_zero,) * (len(values) - 1))
+    dummy_logps = ((dummy_zero,) * (len(values) - 1))
+    return joint_logp, *dummy_logps


@_logprob.register(MarginalDiscreteMarkovChainRV)
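Both logp implementations return placeholder zeros for every value variable beyond the first, since PyMC expects one logp term per value variable. The trailing comma matters: (dummy_zero,) * (len(values) - 1) builds a tuple of repeated placeholders, whereas (dummy_zero) * (len(values) - 1), the form fixed in the last hunk below, just multiplies the scalar. A quick illustration with plain Python stand-ins (dummy_zero and n_values here are invented for the example, not taken from the module):

dummy_zero = 0.0   # stand-in for pt.constant(0, name="dummy_zero")
n_values = 3       # pretend the op has three value variables

print((dummy_zero,) * (n_values - 1))  # (0.0, 0.0): one placeholder per extra value variable
print((dummy_zero) * (n_values - 1))   # 0.0: plain scalar multiplication, not a tuple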
@@ -200,8 +203,10 @@ def marginal_hmm_logp(op, values, *inputs, **kwargs):
    logp_emissions_dict = conditional_logp(dict(zip(dependent_rvs, values)))

    # Reduce and add the batch dims beyond the chain dimension
-    reduced_logp_emissions = _add_reduce_batch_dependent_logps(
-        chain_rv.type, logp_emissions_dict.values()
+    reduced_logp_emissions = _reduce_batch_dependent_logps(
+        dependent_dims_connections=op.dims_connections,
+        dependent_ops=[dependent_rv.owner.op for dependent_rv in dependent_rvs],
+        dependent_logps=[logp_emissions_dict[value] for value in values],
    )

    # Add a batch dimension for the domain of the chain
@@ -240,9 +245,13 @@ def step_alpha(logp_emission, log_alpha, log_P):
    # Final logp is just the sum of the last scan state
    joint_logp = pt.logsumexp(log_alpha_seq[-1], axis=0)

-    # TODO: Transpose into shape of first emission
+    # Align logp with non-collapsed batch dimensions of first RV
+    remaining_dims_first_emission = list(op.dims_connections[0])
+    # The last dim of chain_rv was removed when computing the logp
+    remaining_dims_first_emission.remove((chain_rv.type.ndim - 1,))
+    joint_logp = _align_logp_with_dims(remaining_dims_first_emission, joint_logp)

    # If there are multiple emission streams, we have to add dummy logps for the remaining value variables. The first
    # return is the joint probability of everything together, but PyMC still expects one logp for each emission stream.
-    dummy_logps = (dummy_zero) * (len(values) - 1)
+    dummy_logps = (dummy_zero,) * (len(values) - 1)
    return joint_logp, *dummy_logps
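To make the bookkeeping in the alignment step above concrete, here is a tiny standalone illustration; the chain_rv_ndim value and the dims_connections entry are invented for the example, not taken from a real model. The connection to the chain's last dimension is dropped because that dimension has already been marginalized out by the forward recursion, and the surviving entries give the axis order for the final transpose.

chain_rv_ndim = 3                                          # hypothetical: two batch dims plus the chain (time) dim
first_emission_dims = [(1,), (0,), (chain_rv_ndim - 1,)]   # hypothetical dims_connections[0] entry

remaining = list(first_emission_dims)
remaining.remove((chain_rv_ndim - 1,))  # the chain dim no longer exists in the reduced logp
print(remaining)                        # [(1,), (0,)] -> transpose the joint logp with axes (1, 0)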