
Commit 1b95c8d

Support more kinds of marginalization via dim analysis
This commit lifts the restriction that only Elemwise operations may link marginalized RVs to dependent RVs. We map input dims to output dims to assess whether an operation mixes information from different dims. Graphs where information is not mixed across dims can be marginalized efficiently.
1 parent 69cb216 commit 1b95c8d
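
To illustrate the kind of graph this enables: a marginalized discrete RV may now reach its dependent RVs through non-Elemwise operations (for example an expand_dims/broadcast), as long as the operation does not mix information across the marginalized dims. A minimal sketch of such a model (names, shapes, and data are made up for illustration; whether a particular graph is accepted ultimately depends on the dim analysis):

import numpy as np
import pymc as pm
import pytensor.tensor as pt

from pymc_experimental import MarginalModel

with MarginalModel() as model:
    # One Bernoulli indicator per group; this is the RV we want to marginalize out.
    idx = pm.Bernoulli("idx", p=0.7, shape=(3,))

    # The dependent RV is linked through an Elemwise (where) followed by an
    # expand_dims/broadcast, which is not Elemwise. The dim analysis should still
    # allow marginalization here, because no marginalized dim is mixed with another.
    mu = pt.where(idx, -1.0, 1.0)[None, :]
    y = pm.Normal("y", mu=mu, sigma=1.0, shape=(2, 3), observed=np.zeros((2, 3)))

    model.marginalize([idx])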

9 files changed: +976 -484 lines

pymc_experimental/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -15,7 +15,7 @@
 from pymc_experimental import distributions, gp, statespace, utils
 from pymc_experimental.inference.fit import fit
-from pymc_experimental.model.marginal_model import MarginalModel
+from pymc_experimental.model.marginal.marginal_model import MarginalModel
 from pymc_experimental.model.model_api import as_model
 from pymc_experimental.version import __version__

pymc_experimental/model/marginal/__init__.py

Whitespace-only changes.
Lines changed: 194 additions & 0 deletions
@@ -0,0 +1,194 @@
from collections.abc import Sequence

import numpy as np
import pytensor.tensor as pt

from pymc.distributions import Bernoulli, Categorical, DiscreteUniform, SymbolicRandomVariable
from pymc.logprob.abstract import _logprob
from pymc.logprob.basic import conditional_logp, logp
from pymc.pytensorf import constant_fold
from pytensor.compile.mode import Mode
from pytensor.graph import vectorize_graph
from pytensor.graph.replace import clone_replace, graph_replace
from pytensor.scan import map as scan_map
from pytensor.scan import scan
from pytensor.tensor import TensorType, TensorVariable

from pymc_experimental.distributions import DiscreteMarkovChain


class MarginalRV(SymbolicRandomVariable):
    """Base class for Marginalized RVs"""


class FiniteDiscreteMarginalRV(MarginalRV):
    """Base class for Finite Discrete Marginalized RVs"""


class DiscreteMarginalMarkovChainRV(MarginalRV):
    """Base class for Discrete Marginal Markov Chain RVs"""


def get_domain_of_finite_discrete_rv(rv: TensorVariable) -> tuple[int, ...]:
    op = rv.owner.op
    dist_params = rv.owner.op.dist_params(rv.owner)
    if isinstance(op, Bernoulli):
        return (0, 1)
    elif isinstance(op, Categorical):
        [p_param] = dist_params
        [p_param_length] = constant_fold([p_param.shape[-1]])
        return tuple(range(p_param_length))
    elif isinstance(op, DiscreteUniform):
        lower, upper = constant_fold(dist_params)
        return tuple(np.arange(lower, upper + 1))
    elif isinstance(op, DiscreteMarkovChain):
        P, *_ = dist_params
        return tuple(range(pt.get_vector_length(P[-1])))

    raise NotImplementedError(f"Cannot compute domain for op {op}")


def _add_reduce_batch_dependent_logps(
    marginalized_type: TensorType, dependent_logps: Sequence[TensorVariable]
):
    """Add the logps of dependent RVs while reducing extra batch dims relative to `marginalized_type`."""

    mbcast = marginalized_type.broadcastable
    reduced_logps = []
    for dependent_logp in dependent_logps:
        dbcast = dependent_logp.type.broadcastable
        dim_diff = len(dbcast) - len(mbcast)
        mbcast_aligned = mbcast + (True,) * dim_diff
        vbcast_axis = [i for i, (m, v) in enumerate(zip(mbcast_aligned, dbcast)) if m and not v]
        reduced_logps.append(dependent_logp.sum(vbcast_axis))
    return pt.add(*reduced_logps)


@_logprob.register(FiniteDiscreteMarginalRV)
def finite_discrete_marginal_rv_logp(op, values, *inputs, **kwargs):
    # Clone the inner RV graph of the Marginalized RV
    marginalized_rvs_node = op.make_node(*inputs)
    marginalized_rv, *inner_rvs = clone_replace(
        op.inner_outputs,
        replace={u: v for u, v in zip(op.inner_inputs, marginalized_rvs_node.inputs)},
    )

    # Obtain the joint_logp graph of the inner RV graph
    inner_rv_values = dict(zip(inner_rvs, values))
    marginalized_vv = marginalized_rv.clone()
    rv_values = inner_rv_values | {marginalized_rv: marginalized_vv}
    logps_dict = conditional_logp(rv_values=rv_values, **kwargs)

    # Reduce logp dimensions corresponding to broadcasted variables
    marginalized_logp = logps_dict.pop(marginalized_vv)
    joint_logp = marginalized_logp + _add_reduce_batch_dependent_logps(
        marginalized_rv.type, logps_dict.values()
    )

    # Compute the joint_logp for all possible n values of the marginalized RV. We assume
    # each original dimension is independent so that it suffices to evaluate the graph
    # n times, once with each possible value of the marginalized RV replicated across
    # batched dimensions of the marginalized RV

    # PyMC does not allow RVs in the logp graph, even if we are just using the shape
    marginalized_rv_shape = constant_fold(tuple(marginalized_rv.shape), raise_not_constant=False)
    marginalized_rv_domain = get_domain_of_finite_discrete_rv(marginalized_rv)
    marginalized_rv_domain_tensor = pt.moveaxis(
        pt.full(
            (*marginalized_rv_shape, len(marginalized_rv_domain)),
            marginalized_rv_domain,
            dtype=marginalized_rv.dtype,
        ),
        -1,
        0,
    )

    try:
        joint_logps = vectorize_graph(
            joint_logp, replace={marginalized_vv: marginalized_rv_domain_tensor}
        )
    except Exception:
        # Fallback to Scan
        def logp_fn(marginalized_rv_const, *non_sequences):
            return graph_replace(joint_logp, replace={marginalized_vv: marginalized_rv_const})

        joint_logps, _ = scan_map(
            fn=logp_fn,
            sequences=marginalized_rv_domain_tensor,
            non_sequences=[*values, *inputs],
            mode=Mode().including("local_remove_check_parameter"),
        )

    joint_logps = pt.logsumexp(joint_logps, axis=0)

    # We have to add dummy logps for the remaining value variables, otherwise PyMC will raise
    return joint_logps, *(pt.constant(0),) * (len(values) - 1)


@_logprob.register(DiscreteMarginalMarkovChainRV)
def marginal_hmm_logp(op, values, *inputs, **kwargs):
    marginalized_rvs_node = op.make_node(*inputs)
    inner_rvs = clone_replace(
        op.inner_outputs,
        replace={u: v for u, v in zip(op.inner_inputs, marginalized_rvs_node.inputs)},
    )

    chain_rv, *dependent_rvs = inner_rvs
    P, n_steps_, init_dist_, rng = chain_rv.owner.inputs
    domain = pt.arange(P.shape[-1], dtype="int32")

    # Construct logp in two steps
    # Step 1: Compute the probability of the data ("emissions") under every possible state (vec_logp_emission)

    # First we need to vectorize the conditional logp graph of the data, in case there are batch dimensions floating
    # around. To do this, we need to break the dependency between chain and the init_dist_ random variable. Otherwise,
    # PyMC will detect a random variable in the logp graph (init_dist_), that isn't relevant at this step.
    chain_value = chain_rv.clone()
    dependent_rvs = clone_replace(dependent_rvs, {chain_rv: chain_value})
    logp_emissions_dict = conditional_logp(dict(zip(dependent_rvs, values)))

    # Reduce and add the batch dims beyond the chain dimension
    reduced_logp_emissions = _add_reduce_batch_dependent_logps(
        chain_rv.type, logp_emissions_dict.values()
    )

    # Add a batch dimension for the domain of the chain
    chain_shape = constant_fold(tuple(chain_rv.shape))
    batch_chain_value = pt.moveaxis(pt.full((*chain_shape, domain.size), domain), -1, 0)
    batch_logp_emissions = vectorize_graph(reduced_logp_emissions, {chain_value: batch_chain_value})

    # Step 2: Compute the transition probabilities
    # This is the "forward algorithm", alpha_t = p(y | s_t) * sum_{s_{t-1}}(p(s_t | s_{t-1}) * alpha_{t-1})
    # We do it entirely in logs, though.

    # To compute the prior probabilities of each state, we evaluate the logp of the domain (all possible states)
    # under the initial distribution. This is robust to everything the user can throw at it.
    init_dist_value = init_dist_.type()
    logp_init_dist = logp(init_dist_, init_dist_value)
    # There is a degenerate batch dim for lags=1 (the only supported case),
    # that we have to work around, by expanding the batch value and then squeezing it out of the logp
    batch_logp_init_dist = vectorize_graph(
        logp_init_dist, {init_dist_value: batch_chain_value[:, None, ..., 0]}
    ).squeeze(1)
    log_alpha_init = batch_logp_init_dist + batch_logp_emissions[..., 0]

    def step_alpha(logp_emission, log_alpha, log_P):
        step_log_prob = pt.logsumexp(log_alpha[:, None] + log_P, axis=0)
        return logp_emission + step_log_prob

    P_bcast_dims = (len(chain_shape) - 1) - (P.type.ndim - 2)
    log_P = pt.shape_padright(pt.log(P), P_bcast_dims)
    log_alpha_seq, _ = scan(
        step_alpha,
        non_sequences=[log_P],
        outputs_info=[log_alpha_init],
        # Scan needs the time dimension first, and we already consumed the 1st logp computing the initial value
        sequences=pt.moveaxis(batch_logp_emissions[..., 1:], -1, 0),
    )
    # Final logp is just the sum of the last scan state
    joint_logp = pt.logsumexp(log_alpha_seq[-1], axis=0)

    # If there are multiple emission streams, we have to add dummy logps for the remaining value variables. The first
    # return is the joint probability of everything together, but PyMC still expects one logp for each one.
    dummy_logps = (pt.constant(0),) * (len(values) - 1)
    return joint_logp, *dummy_logps
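
The scan above implements the forward algorithm in log space: log alpha_t(s_t) = log p(y_t | s_t) + logsumexp over s_{t-1} of [log P(s_{t-1} -> s_t) + log alpha_{t-1}(s_{t-1})]. A small NumPy sketch of the same recursion for a single, unbatched chain (toy numbers, not part of the commit) makes the indexing in step_alpha concrete:

import numpy as np
from scipy.special import logsumexp

# Toy quantities: S = 2 hidden states, T = 4 time steps (all values made up).
log_P = np.log(np.array([[0.9, 0.1],
                         [0.2, 0.8]]))  # log p(s_t = j | s_{t-1} = i)
log_init = np.log(np.array([0.5, 0.5]))  # log p(s_0)
log_emission = np.log(np.array([[0.7, 0.3],
                                [0.4, 0.6],
                                [0.1, 0.9],
                                [0.5, 0.5]]))  # log p(y_t | s_t), shape (T, S)

# log alpha_0(s) = log p(s_0 = s) + log p(y_0 | s_0 = s)
log_alpha = log_init + log_emission[0]

# Recursion mirroring step_alpha: the previous state lives on axis 0 and is summed out.
for t in range(1, log_emission.shape[0]):
    log_alpha = log_emission[t] + logsumexp(log_alpha[:, None] + log_P, axis=0)

# Marginal likelihood of the whole emission sequence, with the final state summed out.
# This corresponds to pt.logsumexp(log_alpha_seq[-1], axis=0) above.
log_marginal = logsumexp(log_alpha)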
