Commit 361fd54

Support more kinds of marginalization via dim analysis
This commit lifts the restriction that only Elemwise operations may link marginalized RVs to dependent RVs. We map input dims to output dims to assess whether an operation mixes information from different dims. Graphs where information is not mixed can be marginalized efficiently.
1 parent 09dd147 commit 361fd54
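
For context, a minimal sketch of the kind of graph this targets (an editor's example, not part of the commit; model names, shapes, and probability values are illustrative, and it assumes the MarginalModel API exported by pymc_experimental): a dependent RV linked to the marginalized RV through a transpose, which reorders dims without mixing information across them — the property the new dim analysis checks for.

import numpy as np
import pymc as pm

from pymc_experimental import MarginalModel

with MarginalModel() as model:
    # Finite discrete RV to be marginalized out
    idx = pm.Bernoulli("idx", p=0.7, shape=(3, 2))
    # The link goes through a transpose (not an Elemwise op): dims are reordered,
    # not mixed, so the graph can still be marginalized efficiently
    y = pm.Normal("y", mu=idx.T * 2, sigma=1, observed=np.zeros((2, 3)))

    model.marginalize(["idx"])
    idata = pm.sample()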

File tree

9 files changed, +1275 -549 lines

pymc_experimental/__init__.py

Lines changed: 1 addition & 1 deletion

@@ -15,7 +15,7 @@
 
 from pymc_experimental import distributions, gp, statespace, utils
 from pymc_experimental.inference.fit import fit
-from pymc_experimental.model.marginal_model import MarginalModel
+from pymc_experimental.model.marginal.marginal_model import MarginalModel
 from pymc_experimental.model.model_api import as_model
 from pymc_experimental.version import __version__

pymc_experimental/model/marginal/__init__.py

Whitespace-only changes.

(new file)

Lines changed: 272 additions & 0 deletions

@@ -0,0 +1,272 @@
from collections.abc import Sequence

import numpy as np
import pytensor.tensor as pt

from pymc.distributions import Bernoulli, Categorical, DiscreteUniform
from pymc.logprob.abstract import MeasurableOp, _logprob
from pymc.logprob.basic import conditional_logp, logp
from pymc.pytensorf import constant_fold
from pytensor.compile.builders import OpFromGraph
from pytensor.compile.mode import Mode
from pytensor.graph import Op, vectorize_graph
from pytensor.graph.replace import clone_replace, graph_replace
from pytensor.scan import map as scan_map
from pytensor.scan import scan
from pytensor.tensor import TensorVariable

from pymc_experimental.distributions import DiscreteMarkovChain


class MarginalRV(OpFromGraph, MeasurableOp):
    """Base class for Marginalized RVs"""

    def __init__(self, *args, dims_connections: tuple[tuple[int | None]], **kwargs) -> None:
        self.dims_connections = dims_connections
        super().__init__(*args, **kwargs)

    @property
    def support_axes(self) -> tuple[tuple[int]]:
        """Dimensions of dependent RVs that belong to the core (non-batched) marginalized variable."""
        marginalized_ndim_supp = self.inner_outputs[0].owner.op.ndim_supp
        support_axes_vars = []
        for dims_connection in self.dims_connections:
            ndim = len(dims_connection)
            marginalized_supp_axes = ndim - marginalized_ndim_supp
            support_axes_vars.append(
                tuple(
                    -i
                    for i, dim in enumerate(reversed(dims_connection), start=1)
                    if (dim is None or dim > marginalized_supp_axes)
                )
            )
        return tuple(support_axes_vars)


class MarginalFiniteDiscreteRV(MarginalRV):
    """Base class for Marginalized Finite Discrete RVs"""


class MarginalDiscreteMarkovChainRV(MarginalRV):
    """Base class for Marginalized Discrete Markov Chain RVs"""


def get_domain_of_finite_discrete_rv(rv: TensorVariable) -> tuple[int, ...]:
    op = rv.owner.op
    dist_params = rv.owner.op.dist_params(rv.owner)
    if isinstance(op, Bernoulli):
        return (0, 1)
    elif isinstance(op, Categorical):
        [p_param] = dist_params
        [p_param_length] = constant_fold([p_param.shape[-1]])
        return tuple(range(p_param_length))
    elif isinstance(op, DiscreteUniform):
        lower, upper = constant_fold(dist_params)
        return tuple(np.arange(lower, upper + 1))
    elif isinstance(op, DiscreteMarkovChain):
        P, *_ = dist_params
        return tuple(range(pt.get_vector_length(P[-1])))

    raise NotImplementedError(f"Cannot compute domain for op {op}")


def reduce_batch_dependent_logps(
    dependent_dims_connections: Sequence[tuple[int | None, ...]],
    dependent_ops: Sequence[Op],
    dependent_logps: Sequence[TensorVariable],
) -> TensorVariable:
    """Combine the logps of dependent RVs and align them with the marginalized logp.

    This requires reducing extra batch dims and transposing when they are not aligned.

        idx = pm.Bernoulli("idx", shape=(2, 3))  # 0, 1
        pm.Normal("dep1", mu=idx.T[..., None] * 2, shape=(3, 2, 5))
        pm.Normal("dep2", mu=idx * 2, shape=(7, 2, 3))

        marginalize(idx)

    The marginalized op will have dims_connections = [(1, 0, None), (None, 0, 1)],
    which tells us we need to reduce the last axis of the dep1 logp and the first axis
    of the dep2 logp, as well as transpose the remaining axes of the dep1 logp,
    before adding the two elementwise.
    """
    from pymc_experimental.model.marginal.graph_analysis import get_support_axes

    reduced_logps = []
    for dependent_op, dependent_logp, dependent_dims_connection in zip(
        dependent_ops, dependent_logps, dependent_dims_connections
    ):
        if dependent_logp.type.ndim > 0:
            # Find which support axes implied by the MarginalRV need to be reduced
            # Some may have already been reduced by the logp expression of the dependent RV, for non-univariate RVs
            dep_supp_axes = get_support_axes(dependent_op)[0]

            # Dependent RV support axes are already collapsed in the logp, so we ignore them
            supp_axes = [
                -i
                for i, dim in enumerate(reversed(dependent_dims_connection), start=1)
                if (dim is None and -i not in dep_supp_axes)
            ]
            dependent_logp = dependent_logp.sum(supp_axes)

            # Finally, we need to align the dependent logp batch dimensions with the marginalized logp
            dims_alignment = [dim for dim in dependent_dims_connection if dim is not None]
            dependent_logp = dependent_logp.transpose(*dims_alignment)

        reduced_logps.append(dependent_logp)

    reduced_logp = pt.add(*reduced_logps)
    return reduced_logp


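# ---------------------------------------------------------------------------
# Editor's sketch, not part of this commit: the reduction/alignment described in
# the reduce_batch_dependent_logps docstring above, replayed with plain NumPy on
# dummy logp arrays (shapes follow the docstring example; `np` is the module-level
# numpy import). dims_connections = [(1, 0, None), (None, 0, 1)].
_rng = np.random.default_rng(0)
_logp_dep1 = _rng.normal(size=(3, 2, 5))  # dims_connection (1, 0, None)
_logp_dep2 = _rng.normal(size=(7, 2, 3))  # dims_connection (None, 0, 1)

# dep1: the axis mapped to None is summed out; the surviving axes map to idx
# dims (1, 0), so they are transposed to line up with idx's (2, 3) batch shape.
_aligned_dep1 = _logp_dep1.sum(axis=-1).transpose(1, 0)  # shape (2, 3)
# dep2: only the leading None axis is summed out; (0, 1) is already aligned.
_aligned_dep2 = _logp_dep2.sum(axis=0)  # shape (2, 3)

_reduced = _aligned_dep1 + _aligned_dep2  # added elementwise, shape (2, 3)
# ---------------------------------------------------------------------------

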
def align_logp_dims(dims: tuple[tuple[int, None]], logp: TensorVariable) -> TensorVariable:
    if logp.type.ndim > 0:
        # Transpose reduced logp into the direction of the first dependent RV
        dims_alignment = [dim for dim in dims if dim is not None]
        logp = logp.transpose(*dims_alignment)
    return logp


dummy_zero = pt.constant(0, name="dummy_zero")


@_logprob.register(MarginalFiniteDiscreteRV)
def finite_discrete_marginal_rv_logp(op: MarginalFiniteDiscreteRV, values, *inputs, **kwargs):
    # Clone the inner RV graph of the Marginalized RV
    marginalized_rvs_node = op.make_node(*inputs)
    marginalized_rv, *inner_rvs = clone_replace(
        op.inner_outputs,
        replace={u: v for u, v in zip(op.inner_inputs, marginalized_rvs_node.inputs)},
    )

    # Obtain the joint_logp graph of the inner RV graph
    inner_rv_values = dict(zip(inner_rvs, values))
    marginalized_vv = marginalized_rv.clone()
    rv_values = inner_rv_values | {marginalized_rv: marginalized_vv}
    logps_dict = conditional_logp(rv_values=rv_values, **kwargs)

    # Reduce logp dimensions corresponding to broadcasted variables
    marginalized_logp = logps_dict.pop(marginalized_vv)
    joint_logp = marginalized_logp + reduce_batch_dependent_logps(
        dependent_dims_connections=op.dims_connections,
        dependent_ops=[inner_rv.owner.op for inner_rv in inner_rvs],
        dependent_logps=[logps_dict[value] for value in values],
    )

    # Compute the joint_logp for all possible n values of the marginalized RV. We assume
    # each original dimension is independent so that it suffices to evaluate the graph
    # n times, once with each possible value of the marginalized RV replicated across
    # batched dimensions of the marginalized RV

    # PyMC does not allow RVs in the logp graph, even if we are just using the shape
    marginalized_rv_shape = constant_fold(tuple(marginalized_rv.shape), raise_not_constant=False)
    marginalized_rv_domain = get_domain_of_finite_discrete_rv(marginalized_rv)
    marginalized_rv_domain_tensor = pt.moveaxis(
        pt.full(
            (*marginalized_rv_shape, len(marginalized_rv_domain)),
            marginalized_rv_domain,
            dtype=marginalized_rv.dtype,
        ),
        -1,
        0,
    )

    try:
        joint_logps = vectorize_graph(
            joint_logp, replace={marginalized_vv: marginalized_rv_domain_tensor}
        )
    except Exception:
        # Fallback to Scan
        def logp_fn(marginalized_rv_const, *non_sequences):
            return graph_replace(joint_logp, replace={marginalized_vv: marginalized_rv_const})

        joint_logps, _ = scan_map(
            fn=logp_fn,
            sequences=marginalized_rv_domain_tensor,
            non_sequences=[*values, *inputs],
            mode=Mode().including("local_remove_check_parameter"),
        )

    joint_logp = pt.logsumexp(joint_logps, axis=0)

    # Align logp with non-collapsed batch dimensions of first RV
    joint_logp = align_logp_dims(dims=op.dims_connections[0], logp=joint_logp)

    # We have to add dummy logps for the remaining value variables, otherwise PyMC will raise
    dummy_logps = (dummy_zero,) * (len(values) - 1)
    return joint_logp, *dummy_logps


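# ---------------------------------------------------------------------------
# Editor's sketch, not part of this commit: the core marginalization step of
# finite_discrete_marginal_rv_logp, shown numerically for a scalar Bernoulli(p)
# marginalized RV with a Normal(mu=idx, sigma=1) dependent observation y.
# The probability, observation, and distributions are made-up assumptions.
_p, _y = 0.3, 0.8
_domain = np.array([0, 1])  # support of the marginalized Bernoulli
_logp_idx = np.log(np.where(_domain == 1, _p, 1 - _p))  # logp(idx = k)
_logp_y = -0.5 * (_y - _domain) ** 2 - 0.5 * np.log(2 * np.pi)  # logp(y | idx = k)
_joint = _logp_idx + _logp_y  # one joint logp per domain value (axis 0)
# logsumexp over the domain axis, as pt.logsumexp(joint_logps, axis=0) does above
_m = _joint.max()
_marginal_logp = _m + np.log(np.exp(_joint - _m).sum())
# ---------------------------------------------------------------------------

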
@_logprob.register(MarginalDiscreteMarkovChainRV)
def marginal_hmm_logp(op, values, *inputs, **kwargs):
    marginalized_rvs_node = op.make_node(*inputs)
    chain_rv, *dependent_rvs = clone_replace(
        op.inner_outputs,
        replace={u: v for u, v in zip(op.inner_inputs, marginalized_rvs_node.inputs)},
    )

    P, n_steps_, init_dist_, rng = chain_rv.owner.inputs
    domain = pt.arange(P.shape[-1], dtype="int32")

    # Construct logp in two steps
    # Step 1: Compute the probability of the data ("emissions") under every possible state (batch_logp_emissions)

    # First we need to vectorize the conditional logp graph of the data, in case there are batch dimensions floating
    # around. To do this, we need to break the dependency between chain and the init_dist_ random variable. Otherwise,
    # PyMC will detect a random variable in the logp graph (init_dist_) that isn't relevant at this step.
    chain_value = chain_rv.clone()
    dependent_rvs = clone_replace(dependent_rvs, {chain_rv: chain_value})
    logp_emissions_dict = conditional_logp(dict(zip(dependent_rvs, values)))

    # Reduce and add the batch dims beyond the chain dimension
    reduced_logp_emissions = reduce_batch_dependent_logps(
        dependent_dims_connections=op.dims_connections,
        dependent_ops=[dependent_rv.owner.op for dependent_rv in dependent_rvs],
        dependent_logps=[logp_emissions_dict[value] for value in values],
    )

    # Add a batch dimension for the domain of the chain
    chain_shape = constant_fold(tuple(chain_rv.shape))
    batch_chain_value = pt.moveaxis(pt.full((*chain_shape, domain.size), domain), -1, 0)
    batch_logp_emissions = vectorize_graph(reduced_logp_emissions, {chain_value: batch_chain_value})

    # Step 2: Compute the transition probabilities
    # This is the "forward algorithm", alpha_t = p(y | s_t) * sum_{s_{t-1}}(p(s_t | s_{t-1}) * alpha_{t-1})
    # We do it entirely in logs, though.

    # To compute the prior probabilities of each state, we evaluate the logp of the domain (all possible states)
    # under the initial distribution. This is robust to everything the user can throw at it.
    init_dist_value = init_dist_.type()
    logp_init_dist = logp(init_dist_, init_dist_value)
    # There is a degenerate batch dim for lags=1 (the only supported case),
    # that we have to work around, by expanding the batch value and then squeezing it out of the logp
    batch_logp_init_dist = vectorize_graph(
        logp_init_dist, {init_dist_value: batch_chain_value[:, None, ..., 0]}
    ).squeeze(1)
    log_alpha_init = batch_logp_init_dist + batch_logp_emissions[..., 0]

    def step_alpha(logp_emission, log_alpha, log_P):
        step_log_prob = pt.logsumexp(log_alpha[:, None] + log_P, axis=0)
        return logp_emission + step_log_prob

    P_bcast_dims = (len(chain_shape) - 1) - (P.type.ndim - 2)
    log_P = pt.shape_padright(pt.log(P), P_bcast_dims)
    log_alpha_seq, _ = scan(
        step_alpha,
        non_sequences=[log_P],
        outputs_info=[log_alpha_init],
        # Scan needs the time dimension first, and we already consumed the 1st logp computing the initial value
        sequences=pt.moveaxis(batch_logp_emissions[..., 1:], -1, 0),
    )
    # Final logp is just the sum of the last scan state
    joint_logp = pt.logsumexp(log_alpha_seq[-1], axis=0)

    # Align logp with non-collapsed batch dimensions of first RV
    remaining_dims_first_emission = list(op.dims_connections[0])
    # The last dim of chain_rv was removed when computing the logp
    remaining_dims_first_emission.remove(chain_rv.type.ndim - 1)
    joint_logp = align_logp_dims(remaining_dims_first_emission, joint_logp)

    # If there are multiple emission streams, we have to add dummy logps for the remaining value variables. The first
    # return is the joint probability of everything together, but PyMC still expects one logp for each emission stream.
    dummy_logps = (dummy_zero,) * (len(values) - 1)
    return joint_logp, *dummy_logps
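
As a numerical companion to the forward recursion in marginal_hmm_logp above (an editor's sketch, not part of the commit; the transition matrix, initial distribution, and emission logps are made-up values), the same log-space update can be written with NumPy:

import numpy as np

def logsumexp(x, axis):
    m = np.max(x, axis=axis, keepdims=True)
    return np.squeeze(m + np.log(np.sum(np.exp(x - m), axis=axis, keepdims=True)), axis=axis)

log_P = np.log(np.array([[0.9, 0.1], [0.2, 0.8]]))  # log p(s_t = j | s_{t-1} = i)
log_init = np.log(np.array([0.5, 0.5]))             # log p(s_0 = k)
rng = np.random.default_rng(0)
logp_emissions = rng.normal(size=(2, 10))           # log p(y_t | s_t = k), states x time

# alpha_t = p(y_t | s_t) * sum_{s_{t-1}} p(s_t | s_{t-1}) * alpha_{t-1}, in logs
log_alpha = log_init + logp_emissions[:, 0]
for t in range(1, logp_emissions.shape[1]):
    log_alpha = logp_emissions[:, t] + logsumexp(log_alpha[:, None] + log_P, axis=0)

marginal_logp = logsumexp(log_alpha, axis=0)  # log p(y_1..T), states summed out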
