Commit 737a38f

Commit message: .WIP
1 parent: da6e49d

File tree: 5 files changed (+148, -83 lines)

pymc_experimental/model/marginal/distributions.py

Lines changed: 87 additions & 31 deletions
@@ -4,11 +4,12 @@
 import pytensor.tensor as pt
 
 from pymc.distributions import Bernoulli, Categorical, DiscreteUniform, SymbolicRandomVariable
-from pymc.logprob.abstract import _logprob
+from pymc.logprob.abstract import _logprob, MeasurableOp
 from pymc.logprob.basic import conditional_logp, logp
 from pymc.pytensorf import constant_fold
+from pytensor.compile.builders import OpFromGraph
 from pytensor.compile.mode import Mode
-from pytensor.graph import vectorize_graph
+from pytensor.graph import vectorize_graph, Op
 from pytensor.graph.replace import clone_replace, graph_replace
 from pytensor.scan import map as scan_map
 from pytensor.scan import scan
@@ -17,16 +18,20 @@
 from pymc_experimental.distributions import DiscreteMarkovChain
 
 
-class MarginalRV(SymbolicRandomVariable):
+class MarginalRV(OpFromGraph, MeasurableOp):
     """Base class for Marginalized RVs"""
 
+    def __init__(self, *args, dims_connections: tuple[tuple[int | None]], **kwargs) -> None:
+        self.dims_connections = dims_connections
+        super().__init__(*args, **kwargs)
 
-class FiniteDiscreteMarginalRV(MarginalRV):
-    """Base class for Finite Discrete Marginalized RVs"""
 
+class MarginalFiniteDiscreteRV(MarginalRV):
+    """Base class for Marginalized Finite Discrete RVs"""
 
-class DiscreteMarginalMarkovChainRV(MarginalRV):
-    """Base class for Discrete Marginal Markov Chain RVs"""
+
+class MarginalDiscreteMarkovChainRV(MarginalRV):
+    """Base class for Marginalized Discrete Markov Chain RVs"""
 
 
 def get_domain_of_finite_discrete_rv(rv: TensorVariable) -> tuple[int, ...]:
@@ -48,24 +53,69 @@ def get_domain_of_finite_discrete_rv(rv: TensorVariable) -> tuple[int, ...]:
     raise NotImplementedError(f"Cannot compute domain for op {op}")
 
 
-def _add_reduce_batch_dependent_logps(
-    marginalized_type: TensorType, dependent_logps: Sequence[TensorVariable]
-):
-    """Add the logps of dependent RVs while reducing extra batch dims relative to `marginalized_type`."""
+def _reduce_batch_dependent_logps(
+    marginalized_op: MarginalRV,
+    marginalized_logp: TensorVariable,
+    dependent_ops: Sequence[Op],
+    dependent_logps: Sequence[TensorVariable],
+) -> TensorVariable:
+    """Combine the logps of dependent RVs with the marginalized logp.
+
+    This requires reducing extra batch dims and transposing when they are not aligned.
+
+        idx = pm.Bernoulli("idx", shape=(3, 2))  # 0, 1
+        pm.Normal("dep1", mu=idx.T[..., None] * 2, shape=(3, 2, 5))
+        pm.Normal("dep2", mu=idx * 2, shape=(7, 2, 3))
+
+        marginalize(idx)
+        dims_connections = [(1, 0, None), (None, 0, 1)]
+    """
+
+    dims_connections = marginalized_op.dims_connections
+
+    reduced_logps = [marginalized_logp]
+    for dependent_op, dependent_logp, dims_connection in zip(dependent_ops, dependent_logps, dims_connections):
+        if dependent_logp.type.ndim > 0:
+            # Find which support axes implied by the MarginalRV need to be reduced.
+            # Some may have already been reduced by the logp expression of the dependent RV (non-univariate RVs).
+            if isinstance(dependent_op, MarginalRV):
+                dep_dims_connection = dependent_op.dims_connections[0]
+                dep_supp_axes = {-i for i, dim in enumerate(reversed(dep_dims_connection), start=1) if dim == ()}
+            else:
+                # For vanilla RVs, the support axes are the last ndim_supp
+                dep_supp_axes = set(range(-dependent_op.ndim_supp, 0))
+
+            # Dependent RV support axes are already collapsed in the logp, so we ignore them
+            supp_axes = [
+                -i
+                for i, dim in enumerate(reversed(dims_connection), start=1)
+                if (dim == () and -i not in dep_supp_axes)
+            ]
+
+            dependent_logp = dependent_logp.sum(supp_axes)
+            assert dependent_logp.type.ndim == marginalized_logp.type.ndim
 
-    mbcast = marginalized_type.broadcastable
-    reduced_logps = []
-    for dependent_logp in dependent_logps:
-        dbcast = dependent_logp.type.broadcastable
-        dim_diff = len(dbcast) - len(mbcast)
-        mbcast_aligned = mbcast + (True,) * dim_diff
-        vbcast_axis = [i for i, (m, v) in enumerate(zip(mbcast_aligned, dbcast)) if m and not v]
-        reduced_logps.append(dependent_logp.sum(vbcast_axis))
-    return pt.add(*reduced_logps)
+        # Finally, we need to align the dependent logp batch dimensions with the marginalized logp
+        dims_alignment = [dim[0] for dim in dims_connection if dim != ()]
+        dependent_logp = dependent_logp.transpose(*dims_alignment)
 
+        reduced_logps.append(dependent_logp)
 
-@_logprob.register(FiniteDiscreteMarginalRV)
-def finite_discrete_marginal_rv_logp(op, values, *inputs, **kwargs):
+    reduced_logp = pt.add(*reduced_logps)
+
+    if reduced_logp.type.ndim > 0:
+        # Transpose the reduced logp into the dimension order of the first dependent RV
+        first_dims_alignment = [dim[0] for dim in dims_connections[0] if dim != ()]
+        reduced_logp = reduced_logp.transpose(*first_dims_alignment)
+
+    return reduced_logp
+
+
+dummy_zero = pt.constant(0, name="dummy_zero")
+
+
+@_logprob.register(MarginalFiniteDiscreteRV)
+def finite_discrete_marginal_rv_logp(op: MarginalFiniteDiscreteRV, values, *inputs, **kwargs):
     # Clone the inner RV graph of the Marginalized RV
     marginalized_rvs_node = op.make_node(*inputs)
     marginalized_rv, *inner_rvs = clone_replace(
@@ -81,8 +131,11 @@ def finite_discrete_marginal_rv_logp(op, values, *inputs, **kwargs):
 
     # Reduce logp dimensions corresponding to broadcasted variables
     marginalized_logp = logps_dict.pop(marginalized_vv)
-    joint_logp = marginalized_logp + _add_reduce_batch_dependent_logps(
-        marginalized_rv.type, logps_dict.values()
+    joint_logp = _reduce_batch_dependent_logps(
+        marginalized_op=op,
+        marginalized_logp=marginalized_logp,
+        dependent_ops=[inner_rv.owner.op for inner_rv in inner_rvs],
+        dependent_logps=[logps_dict[value] for value in values],
     )
 
     # Compute the joint_logp for all possible n values of the marginalized RV. We assume
@@ -119,21 +172,20 @@ def logp_fn(marginalized_rv_const, *non_sequences):
         mode=Mode().including("local_remove_check_parameter"),
     )
 
-    joint_logps = pt.logsumexp(joint_logps, axis=0)
+    joint_logp = pt.logsumexp(joint_logps, axis=0)
 
     # We have to add dummy logps for the remaining value variables, otherwise PyMC will raise
-    return joint_logps, *(pt.constant(0),) * (len(values) - 1)
+    return joint_logp, *((dummy_zero,) * (len(values) - 1))
 
 
-@_logprob.register(DiscreteMarginalMarkovChainRV)
+@_logprob.register(MarginalDiscreteMarkovChainRV)
 def marginal_hmm_logp(op, values, *inputs, **kwargs):
     marginalized_rvs_node = op.make_node(*inputs)
-    inner_rvs = clone_replace(
+    chain_rv, *dependent_rvs = clone_replace(
         op.inner_outputs,
         replace={u: v for u, v in zip(op.inner_inputs, marginalized_rvs_node.inputs)},
     )
 
-    chain_rv, *dependent_rvs = inner_rvs
     P, n_steps_, init_dist_, rng = chain_rv.owner.inputs
     domain = pt.arange(P.shape[-1], dtype="int32")
 
@@ -149,9 +201,11 @@ def marginal_hmm_logp(op, values, *inputs, **kwargs):
 
     # Reduce and add the batch dims beyond the chain dimension
     reduced_logp_emissions = _add_reduce_batch_dependent_logps(
+        init_logp,
         chain_rv.type, logp_emissions_dict.values()
     )
 
+
     # Add a batch dimension for the domain of the chain
     chain_shape = constant_fold(tuple(chain_rv.shape))
     batch_chain_value = pt.moveaxis(pt.full((*chain_shape, domain.size), domain), -1, 0)
@@ -188,7 +242,9 @@ def step_alpha(logp_emission, log_alpha, log_P):
     # Final logp is just the sum of the last scan state
     joint_logp = pt.logsumexp(log_alpha_seq[-1], axis=0)
 
+    # TODO: Transpose into the shape of the first emission
+
     # If there are multiple emission streams, we have to add dummy logps for the remaining value variables. The first
-    # return is the joint probability of everything together, but PyMC still expects one logp for each one.
-    dummy_logps = (pt.constant(0),) * (len(values) - 1)
+    # return is the joint probability of everything together, but PyMC still expects one logp for each emission stream.
+    dummy_logps = (dummy_zero,) * (len(values) - 1)
     return joint_logp, *dummy_logps
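
For orientation, a minimal PyTensor sketch of the reduce-then-transpose bookkeeping that `_reduce_batch_dependent_logps` performs. Nothing below is part of the commit; the shapes and the `dims_connection` value are invented, using the per-axis encoding the function consumes (a tuple of known marginalized dims per axis, `()` for none):

import pytensor.tensor as pt

# Assume a marginalized RV with batch shape (3, 2), so its logp has shape (3, 2),
# and one dependent logp of shape (2, 3, 5): its first two axes map to the
# marginalized dims in swapped order and the last axis is unconnected.
marginalized_logp = pt.matrix("marginalized_logp")  # shape (3, 2)
dependent_logp = pt.tensor3("dependent_logp")  # shape (2, 3, 5)
dims_connection = ((1,), (0,), ())  # () marks an axis to reduce

# Sum away the axes not connected to any marginalized dim (counted from the right)
supp_axes = [-i for i, dim in enumerate(reversed(dims_connection), start=1) if dim == ()]
reduced = dependent_logp.sum(supp_axes)  # shape (2, 3)

# Transpose the surviving batch axes back into the marginalized order; the
# permutation (1, 0) is its own inverse, matching the transpose in the diff
alignment = [dim[0] for dim in dims_connection if dim != ()]
aligned = reduced.transpose(*alignment)  # shape (3, 2)

joint_logp = marginalized_logp + aligned  # elementwise add is now well defined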

pymc_experimental/model/marginal/graph_analysis.py

Lines changed: 34 additions & 3 deletions
@@ -15,6 +15,8 @@
 from pytensor.tensor.subtensor import AdvancedSubtensor, Subtensor, get_idx_list
 from pytensor.tensor.type_other import NoneTypeT
 
+from pymc_experimental.model.marginal.distributions import MarginalRV
+
 
 def static_shape_ancestors(vars):
     """Identify ancestor Shape Ops of static shapes (therefore constant in a valid graph)."""
@@ -101,6 +103,12 @@ def _broadcast_dims(
     output_dims = tuple(
         tuple(sorted(set(chain.from_iterable(inputs_dim)))) for inputs_dim in zip(*inputs_dims)
     )
+    if any(len(output_dim) > 1 for output_dim in output_dims):
+        raise ValueError("Different known dimensions mixed via broadcasting")
+
+    if len(set(output_dim[0] for output_dim in output_dims if output_dim != ())) < len([output_dim for output_dim in output_dims if output_dim != ()]):
+        raise ValueError("Same dimension used in different axes after broadcasting")
+
     return output_dims
 
 
@@ -111,6 +119,9 @@ def _subgraph_dim_connection(var_dims: VAR_DIMS, input_vars, output_vars) -> VAR_DIMS:
     for node in io_toposort(input_vars, output_vars):
         inputs_dims = [var_dims.get(inp, ()) for inp in node.inputs]
 
+        # f(marginalized_rv, *other_junk) -> dep_rv1, dep_rv2
+        # g(marginalized_rv.ravel()[i], *other_junk) -> dep_rv1.ravel()[?], dep_rv2.ravel()[?]
+
         if not any(inputs_dims):
             # None of the inputs are related to the batch_axes of the marginalized_rv
             continue
@@ -122,10 +133,23 @@ def _subgraph_dim_connection(var_dims: VAR_DIMS, input_vars, output_vars) -> VAR_DIMS:
             )
             var_dims[node.outputs[0]] = output_dims
 
+        elif isinstance(node.op, MarginalRV):
+            inner_var_dims = {
+                inner_inp: input_dims
+                for inner_inp, input_dims in zip(node.op.inner_inputs, inputs_dims)
+            }
+            inner_var_dims = _subgraph_dim_connection(
+                inner_var_dims, node.op.inner_inputs, node.op.inner_outputs
+            )
+            for out, inner_out in zip(node.outputs, node.op.inner_outputs):
+                # FIXME: If the known output_dim belongs to the supp_axis of the inner MarginalizedRV, this should raise.
+                # Add a test in test_graph_analysis
+                if inner_out in inner_var_dims:
+                    var_dims[out] = inner_var_dims[inner_out]
+
         elif (
-            isinstance(node.op, CustomSymbolicDistRV)
-            or isinstance(node.op, SymbolicRandomVariable)
-            and node.op.extended_signature is None
+            isinstance(node.op, SymbolicRandomVariable) and node.op.extended_signature is None
         ):
             # SymbolicRandomVariables without signature are a wild-card, so we need to introspect the inner graph.
             # MarginalRVs are such a case!
@@ -274,6 +298,13 @@ def _subgraph_dim_connection(var_dims: VAR_DIMS, input_vars, output_vars) -> VAR_DIMS:
 
         var_dims[node.outputs[0]] = output_dims
 
+        # categorical(p=dimshuffle(matrix))
+        # (0, 1) -> (1, 0) -> (1,)
+        # ((0,), (1,)) -> (0, 1)
+        # (a,b),(c,d)->(a,b,c,d)
+        # (),()->() -> (a, b),(None, None),(a * None,b * None)
+        # (a, b),(c, d),(a * c, b * d)
+
         else:
             raise NotImplementedError(f"Marginalization through operation {node} not supported.")
 
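The scratch comments in the last hunk trace how known dims would propagate through a DimShuffle. A hypothetical, self-contained rendering of that bookkeeping (these names do not exist in the module; each axis carries the tuple of marginalized dims it is known to depend on, `()` meaning none):

# Hypothetical helper, not in the commit: propagate per-axis dim annotations
# through a transpose described by a permutation of the input axes.
def dims_through_transpose(input_dims, perm):
    # Output axis j inherits whatever dims input axis perm[j] carried
    return tuple(input_dims[p] for p in perm)

# A matrix whose axes map to marginalized dims 0 and 1, then transposed:
assert dims_through_transpose(((0,), (1,)), perm=(1, 0)) == ((1,), (0,))
# An axis with no known dependency stays unknown through the shuffle:
assert dims_through_transpose(((0,), ()), perm=(1, 0)) == ((), (0,))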
pymc_experimental/model/marginal/marginal_model.py

Lines changed: 11 additions & 37 deletions
@@ -24,9 +24,9 @@
 
 from pymc_experimental.distributions import DiscreteMarkovChain
 from pymc_experimental.model.marginal.distributions import (
-    DiscreteMarginalMarkovChainRV,
-    FiniteDiscreteMarginalRV,
-    _add_reduce_batch_dependent_logps,
+    MarginalDiscreteMarkovChainRV,
+    MarginalFiniteDiscreteRV,
+    _reduce_batch_dependent_logps,
     get_domain_of_finite_discrete_rv,
 )
 from pymc_experimental.model.marginal.graph_analysis import (
@@ -431,7 +431,7 @@ def transform_input(inputs):
 
     # Handle batch dims for marginalized value and its dependent RVs
     marginalized_logp, *dependent_logps = joint_logps
-    joint_logp = marginalized_logp + _add_reduce_batch_dependent_logps(
+    joint_logp = marginalized_logp + _reduce_batch_dependent_logps(
         marginalized_rv.type, dependent_logps
     )
 
@@ -556,12 +556,7 @@ def replace_finite_discrete_marginal_subgraph(fgraph, rv_to_marginalize, all_rvs):
         if rv is not rv_to_marginalize
     ]
 
-    dependent_rvs_ndim_supp = {dependent_rv.owner.op.ndim_supp for dependent_rv in dependent_rvs}
-    if len(dependent_rvs_ndim_supp) > 1:
-        raise NotImplementedError("All dependent RVs must have the same support dimensionality")
-
-    [dependent_rv_ndim_supp] = dependent_rvs_ndim_supp
-
+    # TODO: back to broadcastable
     if rv_to_marginalize.type.ndim > 0:
         # If the marginalized RV has multiple dimensions, check that the graph between
         # the marginalized RV and dependent RVs does not mix information from batch dimensions
@@ -573,33 +568,12 @@ def replace_finite_discrete_marginal_subgraph(fgraph, rv_to_marginalize, all_rvs):
         except ValueError as e:
             # From the perspective of the user this is a NotImplementedError
             raise NotImplementedError(
-                "The graph between the marginalized and dependent RVs cannot be marginalized"
+                "The graph between the marginalized and dependent RVs cannot be marginalized."
             ) from e
 
-        if any(
-            len(dim) > 1
-            for rv_dim_connections in dependent_rvs_dim_connections
-            for dim in rv_dim_connections
-        ):
-            raise NotImplementedError("Multiple dimensions are mixed")
-
-        # We further check that batch dimensions of the marginalized RVs are aligned with those of the dependent RV
-        marginal_ndim = rv_to_marginalize.type.ndim
-        marginal_batch_dims = tuple((i,) for i in range(marginal_ndim))
-        for dependent_rv, dependent_rv_batch_dims in zip(
-            dependent_rvs, dependent_rvs_dim_connections
-        ):
-            extra_batch_ndim = dependent_rv.type.ndim + dependent_rv_ndim_supp - marginal_ndim
-            valid_dependent_batch_dims = marginal_batch_dims + (((),) * extra_batch_ndim)
-            if dependent_rv_batch_dims != valid_dependent_batch_dims:
-                raise NotImplementedError(
-                    f"Link between dimensions of marginalized and dependent RVs not supported: {dependent_rv_batch_dims} != {valid_dependent_batch_dims}"
-                )
+    else:
+        dependent_rvs_dim_connections = tuple(((),) * dependent_rv.type.ndim for dependent_rv in dependent_rvs)
 
-    ndim_supp = max(
-        (dependent_rv.type.ndim + dependent_rv_ndim_supp - rv_to_marginalize.type.ndim)
-        for dependent_rv in dependent_rvs
-    )
 
     input_rvs = list(set((*marginalized_rv_input_rvs, *other_direct_rv_ancestors)))
     output_rvs = [rv_to_marginalize, *dependent_rvs]
@@ -608,14 +582,14 @@ def replace_finite_discrete_marginal_subgraph(fgraph, rv_to_marginalize, all_rvs):
     inputs = input_rvs + collect_shared_vars(output_rvs, blockers=input_rvs)
 
     if isinstance(rv_to_marginalize.owner.op, DiscreteMarkovChain):
-        marginalize_constructor = DiscreteMarginalMarkovChainRV
+        marginalize_constructor = MarginalDiscreteMarkovChainRV
     else:
-        marginalize_constructor = FiniteDiscreteMarginalRV
+        marginalize_constructor = MarginalFiniteDiscreteRV
 
     marginalization_op = marginalize_constructor(
         inputs=inputs,
         outputs=output_rvs,  # TODO: Add RNG updates to outputs
-        ndim_supp=ndim_supp,
+        dims_connections=dependent_rvs_dim_connections,
    )
    new_output_rvs = marginalization_op(*inputs)
    fgraph.replace_all(tuple(zip(output_rvs, new_output_rvs)))
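
The constructor now takes `dims_connections` instead of `ndim_supp`. To make the expected nesting concrete, a hypothetical checker (not part of the codebase) that validates the structure a caller passes, as exercised by the updated test below:

# Hypothetical validator, not in the commit: one outer entry per dependent RV,
# one inner tuple per axis of that RV, each inner tuple listing the marginalized
# dims the axis is connected to (an empty tuple means "not connected").
def check_dims_connections(dims_connections, dependent_ndims):
    assert len(dims_connections) == len(dependent_ndims)
    for connection, ndim in zip(dims_connections, dependent_ndims):
        assert len(connection) == ndim
        assert all(isinstance(dims, tuple) for dims in connection)

# The value used in the updated test: one dependent RV with a single axis
# connected to no known marginalized dim.
check_dims_connections((((),),), dependent_ndims=[1])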

tests/model/marginal/test_distributions.py

Lines changed: 3 additions & 6 deletions
@@ -8,7 +8,7 @@
 from pymc_experimental import MarginalModel
 from pymc_experimental.distributions import DiscreteMarkovChain
 
-from pymc_experimental.model.marginal.distributions import FiniteDiscreteMarginalRV
+from pymc_experimental.model.marginal.distributions import MarginalFiniteDiscreteRV
 
 
 def test_marginalized_bernoulli_logp():
@@ -17,13 +17,10 @@ def test_marginalized_bernoulli_logp():
 
     idx = pm.Bernoulli.dist(0.7, name="idx")
     y = pm.Normal.dist(mu=mu[idx], sigma=1.0, name="y")
-    marginal_rv_node = FiniteDiscreteMarginalRV(
+    marginal_rv_node = MarginalFiniteDiscreteRV(
         [mu],
         [idx, y],
-        ndim_supp=0,
-        n_updates=0,
-        # Ignore the fact we didn't specify shared RNG input/outputs for idx,y
-        strict=False,
+        dims_connections=(((),),),
     )(mu)[0].owner
 
     y_vv = y.clone()
