from collections.abc import Sequence

import numpy as np
+ import pytensor.tensor as pt

- from pymc import Bernoulli, Categorical, DiscreteUniform, SymbolicRandomVariable, logp
- from pymc.logprob import conditional_logp
- from pymc.logprob.abstract import _logprob
+ from pymc.distributions import Bernoulli, Categorical, DiscreteUniform
+ from pymc.logprob.abstract import MeasurableOp, _logprob
+ from pymc.logprob.basic import conditional_logp, logp
from pymc.pytensorf import constant_fold
- from pytensor import Mode, clone_replace, graph_replace, scan
- from pytensor import map as scan_map
- from pytensor import tensor as pt
- from pytensor.graph import vectorize_graph
- from pytensor.tensor import TensorType, TensorVariable
+ from pytensor import Variable
+ from pytensor.compile.builders import OpFromGraph
+ from pytensor.compile.mode import Mode
+ from pytensor.graph import Op, vectorize_graph
+ from pytensor.graph.replace import clone_replace, graph_replace
+ from pytensor.scan import map as scan_map
+ from pytensor.scan import scan
+ from pytensor.tensor import TensorVariable

from pymc_experimental.distributions import DiscreteMarkovChain


- class MarginalRV(SymbolicRandomVariable):
+ class MarginalRV(OpFromGraph, MeasurableOp):
    """Base class for Marginalized RVs"""

+     def __init__(self, *args, dims_connections: tuple[tuple[int | None]], **kwargs) -> None:
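+         # One tuple per dependent RV: entry j gives the marginalized RV dim that dim j of the
+         # dependent RV maps to, or None when that dim has no connection to the marginalized RV.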
+         self.dims_connections = dims_connections
+         super().__init__(*args, **kwargs)

- class FiniteDiscreteMarginalRV(MarginalRV):
-     """Base class for Finite Discrete Marginalized RVs"""
+     @property
+     def support_axes(self) -> tuple[tuple[int]]:
+         """Dimensions of dependent RVs that belong to the core (non-batched) marginalized variable."""
+         marginalized_ndim_supp = self.inner_outputs[0].owner.op.ndim_supp
+         support_axes_vars = []
+         for dims_connection in self.dims_connections:
+             ndim = len(dims_connection)
+             marginalized_supp_axes = ndim - marginalized_ndim_supp
+             support_axes_vars.append(
+                 tuple(
+                     -i
+                     for i, dim in enumerate(reversed(dims_connection), start=1)
+                     if (dim is None or dim > marginalized_supp_axes)
+                 )
+             )
+         return tuple(support_axes_vars)


- class DiscreteMarginalMarkovChainRV(MarginalRV):
-     """Base class for Discrete Marginal Markov Chain RVs"""
+ class MarginalFiniteDiscreteRV(MarginalRV):
+     """Base class for Marginalized Finite Discrete RVs"""
+
+
+ class MarginalDiscreteMarkovChainRV(MarginalRV):
+     """Base class for Marginalized Discrete Markov Chain RVs"""


def get_domain_of_finite_discrete_rv(rv: TensorVariable) -> tuple[int, ...]:
@@ -34,7 +59,8 @@ def get_domain_of_finite_discrete_rv(rv: TensorVariable) -> tuple[int, ...]:
        return (0, 1)
    elif isinstance(op, Categorical):
        [p_param] = dist_params
-         return tuple(range(pt.get_vector_length(p_param)))
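+         # Fold the symbolic length of `p` into a concrete integer, since the finite domain
+         # has to be enumerated as Python ints.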
+         [p_param_length] = constant_fold([p_param.shape[-1]])
+         return tuple(range(p_param_length))
    elif isinstance(op, DiscreteUniform):
        lower, upper = constant_fold(dist_params)
        return tuple(np.arange(lower, upper + 1))
@@ -45,31 +71,81 @@ def get_domain_of_finite_discrete_rv(rv: TensorVariable) -> tuple[int, ...]:
    raise NotImplementedError(f"Cannot compute domain for op {op}")


- def _add_reduce_batch_dependent_logps(
-     marginalized_type: TensorType, dependent_logps: Sequence[TensorVariable]
- ):
-     """Add the logps of dependent RVs while reducing extra batch dims relative to `marginalized_type`."""
+ def reduce_batch_dependent_logps(
+     dependent_dims_connections: Sequence[tuple[int | None, ...]],
+     dependent_ops: Sequence[Op],
+     dependent_logps: Sequence[TensorVariable],
+ ) -> TensorVariable:
+     """Combine the logps of dependent RVs and align them with the marginalized logp.
+
+     This requires reducing extra batch dims and transposing when they are not aligned.
+
+         idx = pm.Bernoulli("idx", shape=(3, 2))  # 0, 1
+         pm.Normal("dep1", mu=idx.T[..., None] * 2, shape=(3, 2, 5))
+         pm.Normal("dep2", mu=idx * 2, shape=(7, 2, 3))
+
+         marginalize(idx)
+
+         The marginalized op will have dims_connections = [(1, 0, None), (None, 0, 1)]
+         which tells us we need to reduce the last axis of dep1 logp and the first of dep2 logp,
+         as well as transpose the remaining axis of dep1 logp before adding the two element-wise.
+
+     """
+     from pymc_experimental.model.marginal.graph_analysis import get_support_axes

-     mbcast = marginalized_type.broadcastable
    reduced_logps = []
-     for dependent_logp in dependent_logps:
-         dbcast = dependent_logp.type.broadcastable
-         dim_diff = len(dbcast) - len(mbcast)
-         mbcast_aligned = (True,) * dim_diff + mbcast
-         vbcast_axis = [i for i, (m, v) in enumerate(zip(mbcast_aligned, dbcast)) if m and not v]
-         reduced_logps.append(dependent_logp.sum(vbcast_axis))
-     return pt.add(*reduced_logps)
+     for dependent_op, dependent_logp, dependent_dims_connection in zip(
+         dependent_ops, dependent_logps, dependent_dims_connections
+     ):
+         if dependent_logp.type.ndim > 0:
+             # Find which support axes implied by the MarginalRV need to be reduced
+             # Some may have already been reduced by the logp expression of the dependent RV (e.g., multivariate RVs)
+             dep_supp_axes = get_support_axes(dependent_op)[0]

+             # Dependent RV support axes are already collapsed in the logp, so we ignore them
+             supp_axes = [
+                 -i
+                 for i, dim in enumerate(reversed(dependent_dims_connection), start=1)
+                 if (dim is None and -i not in dep_supp_axes)
+             ]
+             dependent_logp = dependent_logp.sum(supp_axes)

- @_logprob.register(FiniteDiscreteMarginalRV)
- def finite_discrete_marginal_rv_logp(op, values, *inputs, **kwargs):
-     # Clone the inner RV graph of the Marginalized RV
-     marginalized_rvs_node = op.make_node(*inputs)
-     marginalized_rv, *inner_rvs = clone_replace(
+             # Finally, we need to align the dependent logp batch dimensions with the marginalized logp
+             dims_alignment = [dim for dim in dependent_dims_connection if dim is not None]
+             dependent_logp = dependent_logp.transpose(*dims_alignment)
+
+         reduced_logps.append(dependent_logp)
+
+     reduced_logp = pt.add(*reduced_logps)
+     return reduced_logp
+
+
+ def align_logp_dims(dims: tuple[tuple[int, None]], logp: TensorVariable) -> TensorVariable:
+     """Align the logp with the order specified in dims."""
+     dims_alignment = [dim for dim in dims if dim is not None]
+     return logp.transpose(*dims_alignment)
+
+
+ def inline_ofg_outputs(op: OpFromGraph, inputs: Sequence[Variable]) -> tuple[Variable]:
+     """Inline the inner graph (outputs) of an OpFromGraph Op.
+
+     Whereas `OpFromGraph` "wraps" a graph inside a single Op, this function "unwraps"
+     the inner graph.
+     """
+     return clone_replace(
        op.inner_outputs,
-         replace={u: v for u, v in zip(op.inner_inputs, marginalized_rvs_node.inputs)},
+         replace=tuple(zip(op.inner_inputs, inputs)),
    )

+
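+ # Placeholder logp (a shared zero constant) returned for value variables whose contribution
+ # is already accounted for in the joint logp term below.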
+ DUMMY_ZERO = pt.constant(0, name="dummy_zero")
+
+
+ @_logprob.register(MarginalFiniteDiscreteRV)
+ def finite_discrete_marginal_rv_logp(op: MarginalFiniteDiscreteRV, values, *inputs, **kwargs):
+     # Clone the inner RV graph of the Marginalized RV
+     marginalized_rv, *inner_rvs = inline_ofg_outputs(op, inputs)
+
    # Obtain the joint_logp graph of the inner RV graph
    inner_rv_values = dict(zip(inner_rvs, values))
    marginalized_vv = marginalized_rv.clone()
@@ -78,8 +154,10 @@ def finite_discrete_marginal_rv_logp(op, values, *inputs, **kwargs):

    # Reduce logp dimensions corresponding to broadcasted variables
    marginalized_logp = logps_dict.pop(marginalized_vv)
-     joint_logp = marginalized_logp + _add_reduce_batch_dependent_logps(
-         marginalized_rv.type, logps_dict.values()
+     joint_logp = marginalized_logp + reduce_batch_dependent_logps(
+         dependent_dims_connections=op.dims_connections,
+         dependent_ops=[inner_rv.owner.op for inner_rv in inner_rvs],
+         dependent_logps=[logps_dict[value] for value in values],
    )

    # Compute the joint_logp for all possible n values of the marginalized RV. We assume
@@ -116,21 +194,20 @@ def logp_fn(marginalized_rv_const, *non_sequences):
        mode=Mode().including("local_remove_check_parameter"),
    )

-     joint_logps = pt.logsumexp(joint_logps, axis=0)
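+     # Marginalize: the leading axis of `joint_logps` enumerates the domain values, so
+     # logsumexp over it gives log p(values) = log sum_k p(marginalized = k, values).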
+     joint_logp = pt.logsumexp(joint_logps, axis=0)
+
+     # Align logp with non-collapsed batch dimensions of first RV
+     joint_logp = align_logp_dims(dims=op.dims_connections[0], logp=joint_logp)

    # We have to add dummy logps for the remaining value variables, otherwise PyMC will raise
-     return joint_logps, *(pt.constant(0),) * (len(values) - 1)
+     dummy_logps = (DUMMY_ZERO,) * (len(values) - 1)
+     return joint_logp, *dummy_logps


- @_logprob.register(DiscreteMarginalMarkovChainRV)
+ @_logprob.register(MarginalDiscreteMarkovChainRV)
def marginal_hmm_logp(op, values, *inputs, **kwargs):
-     marginalized_rvs_node = op.make_node(*inputs)
-     inner_rvs = clone_replace(
-         op.inner_outputs,
-         replace={u: v for u, v in zip(op.inner_inputs, marginalized_rvs_node.inputs)},
-     )
+     chain_rv, *dependent_rvs = inline_ofg_outputs(op, inputs)

-     chain_rv, *dependent_rvs = inner_rvs
    P, n_steps_, init_dist_, rng = chain_rv.owner.inputs
    domain = pt.arange(P.shape[-1], dtype="int32")

@@ -145,8 +222,10 @@ def marginal_hmm_logp(op, values, *inputs, **kwargs):
    logp_emissions_dict = conditional_logp(dict(zip(dependent_rvs, values)))

    # Reduce and add the batch dims beyond the chain dimension
-     reduced_logp_emissions = _add_reduce_batch_dependent_logps(
-         chain_rv.type, logp_emissions_dict.values()
+     reduced_logp_emissions = reduce_batch_dependent_logps(
+         dependent_dims_connections=op.dims_connections,
+         dependent_ops=[dependent_rv.owner.op for dependent_rv in dependent_rvs],
+         dependent_logps=[logp_emissions_dict[value] for value in values],
    )

    # Add a batch dimension for the domain of the chain
@@ -185,7 +264,13 @@ def step_alpha(logp_emission, log_alpha, log_P):
    # Final logp is just the sum of the last scan state
    joint_logp = pt.logsumexp(log_alpha_seq[-1], axis=0)

+     # Align logp with non-collapsed batch dimensions of first RV
+     remaining_dims_first_emission = list(op.dims_connections[0])
+     # The last dim of chain_rv was removed when computing the logp
+     remaining_dims_first_emission.remove(chain_rv.type.ndim - 1)
+     joint_logp = align_logp_dims(remaining_dims_first_emission, joint_logp)
+
    # If there are multiple emission streams, we have to add dummy logps for the remaining value variables. The first
-     # return is the joint probability of everything together, but PyMC still expects one logp for each one.
-     dummy_logps = (pt.constant(0),) * (len(values) - 1)
+     # return is the joint probability of everything together, but PyMC still expects one logp for each emission stream.
+     dummy_logps = (DUMMY_ZERO,) * (len(values) - 1)
    return joint_logp, *dummy_logps