from collections.abc import Sequence

import numpy as np
import pytensor.tensor as pt

from pymc.distributions import Bernoulli, Categorical, DiscreteUniform
from pymc.logprob.abstract import MeasurableOp, _logprob
from pymc.logprob.basic import conditional_logp, logp
from pymc.pytensorf import constant_fold
from pytensor.compile.builders import OpFromGraph
from pytensor.compile.mode import Mode
from pytensor.graph import Op, vectorize_graph
from pytensor.graph.replace import clone_replace, graph_replace
from pytensor.scan import map as scan_map
from pytensor.scan import scan
from pytensor.tensor import TensorVariable

from pymc_experimental.distributions import DiscreteMarkovChain


class MarginalRV(OpFromGraph, MeasurableOp):
    """Base class for Marginalized RVs"""

    def __init__(
        self, *args, dims_connections: tuple[tuple[int | None, ...], ...], **kwargs
    ) -> None:
        self.dims_connections = dims_connections
        super().__init__(*args, **kwargs)

    @property
    def support_axes(self) -> tuple[tuple[int, ...], ...]:
        """Dimensions of dependent RVs that belong to the core (non-batched) marginalized variable."""
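        # For example, with a scalar marginalized RV (ndim_supp == 0) and
        # dims_connections == ((1, 0, None),), this returns ((-1,),): only the
        # dependent axis not mapped to a marginalized batch dim is a support axis.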
        marginalized_ndim_supp = self.inner_outputs[0].owner.op.ndim_supp
        support_axes_vars = []
        for dims_connection in self.dims_connections:
            ndim = len(dims_connection)
            marginalized_supp_axes = ndim - marginalized_ndim_supp
            support_axes_vars.append(
                tuple(
                    -i
                    for i, dim in enumerate(reversed(dims_connection), start=1)
                    if (dim is None or dim > marginalized_supp_axes)
                )
            )
        return tuple(support_axes_vars)


class MarginalFiniteDiscreteRV(MarginalRV):
    """Base class for Marginalized Finite Discrete RVs"""


class MarginalDiscreteMarkovChainRV(MarginalRV):
    """Base class for Marginalized Discrete Markov Chain RVs"""


def get_domain_of_finite_discrete_rv(rv: TensorVariable) -> tuple[int, ...]:
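    """Return the finite set of values a discrete RV can take, as a tuple of integers."""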
    op = rv.owner.op
    dist_params = op.dist_params(rv.owner)
    if isinstance(op, Bernoulli):
        return (0, 1)
    elif isinstance(op, Categorical):
        [p_param] = dist_params
        [p_param_length] = constant_fold([p_param.shape[-1]])
        return tuple(range(p_param_length))
    elif isinstance(op, DiscreteUniform):
        lower, upper = constant_fold(dist_params)
        return tuple(np.arange(lower, upper + 1))
    elif isinstance(op, DiscreteMarkovChain):
        P, *_ = dist_params
        return tuple(range(pt.get_vector_length(P[-1])))

    raise NotImplementedError(f"Cannot compute domain for op {op}")


def reduce_batch_dependent_logps(
    dependent_dims_connections: Sequence[tuple[int | None, ...]],
    dependent_ops: Sequence[Op],
    dependent_logps: Sequence[TensorVariable],
) -> TensorVariable:
    """Combine the logps of dependent RVs and align them with the marginalized logp.

    This requires reducing extra batch dims and transposing when they are not aligned.

        idx = pm.Bernoulli("idx", p=0.5, shape=(3, 2))  # takes values 0 or 1
        pm.Normal("dep1", mu=idx.T[..., None] * 2, shape=(2, 3, 5))
        pm.Normal("dep2", mu=idx * 2, shape=(7, 3, 2))

        marginalize(idx)

    The marginalized op will have dims_connections = [(1, 0, None), (None, 0, 1)],
    which tells us we need to reduce the last axis of the dep1 logp and the first axis
    of the dep2 logp, as well as transpose the remaining axes of the dep1 logp, before
    adding the two logps elementwise.
    """
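    # Following the docstring example: the dep1 logp (2, 3, 5) is summed over its last
    # axis to (2, 3) and transposed to (3, 2); the dep2 logp (7, 3, 2) is summed over
    # its first axis to (3, 2); both then align with the (3, 2) marginalized logp.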
    from pymc_experimental.model.marginal.graph_analysis import get_support_axes

    reduced_logps = []
    for dependent_op, dependent_logp, dependent_dims_connection in zip(
        dependent_ops, dependent_logps, dependent_dims_connections
    ):
        if dependent_logp.type.ndim > 0:
            # Find which support axes implied by the MarginalRV need to be reduced
            # Some may have already been reduced by the logp expression of the dependent RV (for non-univariate RVs)
            dep_supp_axes = get_support_axes(dependent_op)[0]

            # Dependent RV support axes are already collapsed in the logp, so we ignore them
            supp_axes = [
                -i
                for i, dim in enumerate(reversed(dependent_dims_connection), start=1)
                if (dim is None and -i not in dep_supp_axes)
            ]
            dependent_logp = dependent_logp.sum(supp_axes)

            # Finally, we need to align the dependent logp batch dimensions with the marginalized logp
            dims_alignment = [dim for dim in dependent_dims_connection if dim is not None]
            dependent_logp = dependent_logp.transpose(*dims_alignment)

        reduced_logps.append(dependent_logp)

    reduced_logp = pt.add(*reduced_logps)
    return reduced_logp


def align_logp_dims(dims: tuple[int | None, ...], logp: TensorVariable) -> TensorVariable:
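    """Align the logp batch dims with the order given by dims, skipping None (collapsed) entries."""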
    if logp.type.ndim > 0:
        # Transpose reduced logp into the direction of the first dependent RV
        dims_alignment = [dim for dim in dims if dim is not None]
        logp = logp.transpose(*dims_alignment)
    return logp


dummy_zero = pt.constant(0, name="dummy_zero")


@_logprob.register(MarginalFiniteDiscreteRV)
def finite_discrete_marginal_rv_logp(op: MarginalFiniteDiscreteRV, values, *inputs, **kwargs):
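    """Marginal logp of a finite discrete MarginalRV: evaluate the joint logp at every
    value in the marginalized RV's domain and combine the results with logsumexp."""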
    # Clone the inner RV graph of the Marginalized RV
    marginalized_rvs_node = op.make_node(*inputs)
    marginalized_rv, *inner_rvs = clone_replace(
        op.inner_outputs,
        replace={u: v for u, v in zip(op.inner_inputs, marginalized_rvs_node.inputs)},
    )

    # Obtain the joint_logp graph of the inner RV graph
    inner_rv_values = dict(zip(inner_rvs, values))
    marginalized_vv = marginalized_rv.clone()
    rv_values = inner_rv_values | {marginalized_rv: marginalized_vv}
    logps_dict = conditional_logp(rv_values=rv_values, **kwargs)

    # Reduce logp dimensions corresponding to broadcasted variables
    marginalized_logp = logps_dict.pop(marginalized_vv)
    joint_logp = marginalized_logp + reduce_batch_dependent_logps(
        dependent_dims_connections=op.dims_connections,
        dependent_ops=[inner_rv.owner.op for inner_rv in inner_rvs],
        dependent_logps=[logps_dict[value] for value in values],
    )

    # Compute the joint_logp for all n possible values of the marginalized RV. We assume
    # each original dimension is independent, so that it suffices to evaluate the graph
    # n times, once with each possible value of the marginalized RV replicated across
    # the batched dimensions of the marginalized RV.

    # PyMC does not allow RVs in the logp graph, even if we are just using the shape
    marginalized_rv_shape = constant_fold(tuple(marginalized_rv.shape), raise_not_constant=False)
    marginalized_rv_domain = get_domain_of_finite_discrete_rv(marginalized_rv)
    marginalized_rv_domain_tensor = pt.moveaxis(
        pt.full(
            (*marginalized_rv_shape, len(marginalized_rv_domain)),
            marginalized_rv_domain,
            dtype=marginalized_rv.dtype,
        ),
        -1,
        0,
    )
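    # marginalized_rv_domain_tensor has shape (len(domain), *marginalized_rv_shape):
    # each leading slice is the full marginalized RV filled with a single domain value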

    try:
        joint_logps = vectorize_graph(
            joint_logp, replace={marginalized_vv: marginalized_rv_domain_tensor}
        )
    except Exception:
        # Fallback to Scan
        def logp_fn(marginalized_rv_const, *non_sequences):
            return graph_replace(joint_logp, replace={marginalized_vv: marginalized_rv_const})

        joint_logps, _ = scan_map(
            fn=logp_fn,
            sequences=marginalized_rv_domain_tensor,
            non_sequences=[*values, *inputs],
            mode=Mode().including("local_remove_check_parameter"),
        )

    joint_logp = pt.logsumexp(joint_logps, axis=0)

    # Align logp with non-collapsed batch dimensions of first RV
    joint_logp = align_logp_dims(dims=op.dims_connections[0], logp=joint_logp)

    # We have to add dummy logps for the remaining value variables, otherwise PyMC will raise
    dummy_logps = (dummy_zero,) * (len(values) - 1)
    return joint_logp, *dummy_logps


@_logprob.register(MarginalDiscreteMarkovChainRV)
def marginal_hmm_logp(op, values, *inputs, **kwargs):
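    """Marginal logp of a marginalized DiscreteMarkovChain, computed with the
    forward algorithm in log space."""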
    marginalized_rvs_node = op.make_node(*inputs)
    chain_rv, *dependent_rvs = clone_replace(
        op.inner_outputs,
        replace={u: v for u, v in zip(op.inner_inputs, marginalized_rvs_node.inputs)},
    )

    P, n_steps_, init_dist_, rng = chain_rv.owner.inputs
    domain = pt.arange(P.shape[-1], dtype="int32")

    # Construct the logp in two steps
    # Step 1: Compute the probability of the data ("emissions") under every possible state (batch_logp_emissions)

    # First we need to vectorize the conditional logp graph of the data, in case there are batch dimensions floating
    # around. To do this, we need to break the dependency between the chain and the init_dist_ random variable.
    # Otherwise, PyMC will detect a random variable in the logp graph (init_dist_) that isn't relevant at this step.
    chain_value = chain_rv.clone()
    dependent_rvs = clone_replace(dependent_rvs, {chain_rv: chain_value})
    logp_emissions_dict = conditional_logp(dict(zip(dependent_rvs, values)))

    # Reduce and add the batch dims beyond the chain dimension
    reduced_logp_emissions = reduce_batch_dependent_logps(
        dependent_dims_connections=op.dims_connections,
        dependent_ops=[dependent_rv.owner.op for dependent_rv in dependent_rvs],
        dependent_logps=[logp_emissions_dict[value] for value in values],
    )

    # Add a batch dimension for the domain of the chain
    chain_shape = constant_fold(tuple(chain_rv.shape))
    batch_chain_value = pt.moveaxis(pt.full((*chain_shape, domain.size), domain), -1, 0)
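    # batch_chain_value has shape (domain.size, *chain_shape):
    # one full copy of the chain filled with each possible state value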
    batch_logp_emissions = vectorize_graph(reduced_logp_emissions, {chain_value: batch_chain_value})

    # Step 2: Compute the transition probabilities
    # This is the "forward algorithm", alpha_t = p(y | s_t) * sum_{s_{t-1}}(p(s_t | s_{t-1}) * alpha_{t-1})
    # We do it entirely in logs, though.

    # To compute the prior probabilities of each state, we evaluate the logp of the domain (all possible states)
    # under the initial distribution. This is robust to everything the user can throw at it.
    init_dist_value = init_dist_.type()
    logp_init_dist = logp(init_dist_, init_dist_value)
    # There is a degenerate batch dim for lags=1 (the only supported case), which we have to work around
    # by expanding the batch value and then squeezing it out of the logp
    batch_logp_init_dist = vectorize_graph(
        logp_init_dist, {init_dist_value: batch_chain_value[:, None, ..., 0]}
    ).squeeze(1)
    log_alpha_init = batch_logp_init_dist + batch_logp_emissions[..., 0]

    def step_alpha(logp_emission, log_alpha, log_P):
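        # With a 2D transition matrix, log_alpha[:, None] + log_P has shape
        # (prev_state, current_state, *batch); the logsumexp over axis 0
        # marginalizes out the previous state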
        step_log_prob = pt.logsumexp(log_alpha[:, None] + log_P, axis=0)
        return logp_emission + step_log_prob

    P_bcast_dims = (len(chain_shape) - 1) - (P.type.ndim - 2)
    log_P = pt.shape_padright(pt.log(P), P_bcast_dims)
    log_alpha_seq, _ = scan(
        step_alpha,
        non_sequences=[log_P],
        outputs_info=[log_alpha_init],
        # Scan needs the time dimension first, and we already consumed the 1st logp computing the initial value
        sequences=pt.moveaxis(batch_logp_emissions[..., 1:], -1, 0),
    )
    # The final logp marginalizes (logsumexp) over the states of the last scan state
    joint_logp = pt.logsumexp(log_alpha_seq[-1], axis=0)

    # Align logp with non-collapsed batch dimensions of first RV
    remaining_dims_first_emission = list(op.dims_connections[0])
    # The last dim of chain_rv was removed when computing the logp
    remaining_dims_first_emission.remove(chain_rv.type.ndim - 1)
    joint_logp = align_logp_dims(remaining_dims_first_emission, joint_logp)

    # If there are multiple emission streams, we have to add dummy logps for the remaining value variables. The first
    # return is the joint probability of everything together, but PyMC still expects one logp for each emission stream.
    dummy_logps = (dummy_zero,) * (len(values) - 1)
    return joint_logp, *dummy_logps