Commit 02b4eb2 (1 parent: 97998af)

.WIP
6 files changed (+110 additions, -78 deletions)

pymc_experimental/model/marginal/distributions.py

Lines changed: 33 additions & 24 deletions
@@ -54,12 +54,11 @@ def get_domain_of_finite_discrete_rv(rv: TensorVariable) -> tuple[int, ...]:
 
 
 def _reduce_batch_dependent_logps(
-    marginalized_op: MarginalRV,
-    marginalized_logp: TensorVariable,
+    dependent_dims_connections: Sequence[tuple[int | None, ...]],
     dependent_ops: Sequence[Op],
     dependent_logps: Sequence[TensorVariable],
 ) -> TensorVariable:
-    """Combine the logps of dependent RVs with the marginalized logp.
+    """Combine the logps of dependent RVs and align them with the marginalized logp.
 
     This requires reducing extra batch dims and transposing when they are not aligned.
 
@@ -68,13 +67,14 @@ def _reduce_batch_dependent_logps(
         pm.Normal("dep2", mu=idx * 2, shape=(7, 2, 3))
 
         marginalize(idx)
-        dims_connections = [(1, 0, None), (None, 0, 1)]
-    """
 
-    dims_connections = marginalized_op.dims_connections
+    The marginalized op will have dims_connections = [(1, 0, None), (None, 0, 1)]
+    which tells us we need to reduce the last axis of dep1 logp and the first of dep2,
+    as well as transpose the remaining axis of dep1 logp before adding the two elemwise.
+    """
 
-    reduced_logps = [marginalized_logp]
-    for dependent_op, dependent_logp, dims_connection in zip(dependent_ops, dependent_logps, dims_connections):
+    reduced_logps = []
+    for dependent_op, dependent_logp, dependent_dims_connection in zip(dependent_ops, dependent_logps, dependent_dims_connections):
         if dependent_logp.type.ndim > 0:
             # Find which support axis implied by the MarginalRV need to be reduced
             # Some may have already been reduced by the logp expression of the dependent RV, for non-univariate RVs
@@ -88,27 +88,27 @@ def _reduce_batch_dependent_logps(
             # Dependent RV support axes are already collapsed in the logp, so we ignore them
             supp_axes = [
                 -i
-                for i, dim in enumerate(reversed(dims_connection), start=1)
+                for i, dim in enumerate(reversed(dependent_dims_connection), start=1)
                 if (dim == () and -i not in dep_supp_axes)
             ]
 
             dependent_logp = dependent_logp.sum(supp_axes)
-            assert dependent_logp.type.ndim == marginalized_logp.type.ndim
 
             # Finally, we need to align the dependent logp batch dimensions with the marginalized logp
-            dims_alignment = [dim[0] for dim in dims_connection if dim != ()]
+            dims_alignment = [dim[0] for dim in dependent_dims_connection if dim != ()]
             dependent_logp = dependent_logp.transpose(*dims_alignment)
 
         reduced_logps.append(dependent_logp)
 
     reduced_logp = pt.add(*reduced_logps)
+    return reduced_logp
 
-    if reduced_logp.type.ndim > 0:
+def _align_logp_with_dims(dims: tuple[tuple[int, None]], logp: TensorVariable) -> TensorVariable:
+    if logp.type.ndim > 0:
         # Transpose reduced logp into the direction of the first dependent RV
-        first_dims_alignment = [dim[0] for dim in dims_connections[0] if dim != ()]
-        reduced_logp = reduced_logp.transpose(*first_dims_alignment)
-
-    return reduced_logp
+        dims_alignment = [dim[0] for dim in dims if dim != ()]
+        logp = logp.transpose(*dims_alignment)
+    return logp
 
 
 dummy_zero = pt.constant(0, name="dummy_zero")
@@ -131,9 +131,8 @@ def finite_discrete_marginal_rv_logp(op: MarginalFiniteDiscreteRV, values, *inpu
 
     # Reduce logp dimensions corresponding to broadcasted variables
     marginalized_logp = logps_dict.pop(marginalized_vv)
-    joint_logp = _reduce_batch_dependent_logps(
-        marginalized_op=op,
-        marginalized_logp=marginalized_logp,
+    joint_logp = marginalized_logp + _reduce_batch_dependent_logps(
+        dependent_dims_connections=op.dims_connections,
         dependent_ops=[inner_rv.owner.op for inner_rv in inner_rvs],
         dependent_logps=[logps_dict[value] for value in values],
     )
@@ -174,8 +173,12 @@ def logp_fn(marginalized_rv_const, *non_sequences):
 
     joint_logp = pt.logsumexp(joint_logps, axis=0)
 
+    # Align logp with non-collapsed batch dimensions of first RV
+    joint_logp = _align_logp_with_dims(dims=op.dims_connections[0], logp=joint_logp)
+
     # We have to add dummy logps for the remaining value variables, otherwise PyMC will raise
-    return joint_logp, *((dummy_zero,) * (len(values) - 1))
+    dummy_logps = ((dummy_zero,) * (len(values) - 1))
+    return joint_logp, *dummy_logps
 
 
 @_logprob.register(MarginalDiscreteMarkovChainRV)
@@ -200,8 +203,10 @@ def marginal_hmm_logp(op, values, *inputs, **kwargs):
     logp_emissions_dict = conditional_logp(dict(zip(dependent_rvs, values)))
 
     # Reduce and add the batch dims beyond the chain dimension
-    reduced_logp_emissions = _add_reduce_batch_dependent_logps(
-        chain_rv.type, logp_emissions_dict.values()
+    reduced_logp_emissions = _reduce_batch_dependent_logps(
+        dependent_dims_connections=op.dims_connections,
+        dependent_ops=[dependent_rv.owner.op for dependent_rv in dependent_rvs],
+        dependent_logps=[logp_emissions_dict[value] for value in values],
     )
 
     # Add a batch dimension for the domain of the chain
@@ -240,9 +245,13 @@ def step_alpha(logp_emission, log_alpha, log_P):
     # Final logp is just the sum of the last scan state
     joint_logp = pt.logsumexp(log_alpha_seq[-1], axis=0)
 
-    # TODO: Transpose into shape of first emission
+    # Align logp with non-collapsed batch dimensions of first RV
+    remaining_dims_first_emission = list(op.dims_connections[0])
+    # The last dim of chain_rv was removed when computing the logp
+    remaining_dims_first_emission.remove((chain_rv.type.ndim - 1,))
+    joint_logp = _align_logp_with_dims(remaining_dims_first_emission, joint_logp)
 
     # If there are multiple emission streams, we have to add dummy logps for the remaining value variables. The first
     # return is the joint probability of everything together, but PyMC still expects one logp for each emission stream.
-    dummy_logps = (dummy_zero) * (len(values) - 1)
+    dummy_logps = (dummy_zero,) * (len(values) - 1)
    return joint_logp, *dummy_logps
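
For orientation, here is a minimal NumPy sketch (not the library code; the shapes and the helper name reduce_and_align are illustrative assumptions) of the reduce-and-transpose step described in the _reduce_batch_dependent_logps docstring above, using the same dims_connections example of (1, 0, None) and (None, 0, 1):

import numpy as np

rng = np.random.default_rng(0)
marg_shape = (2, 3)                     # assumed batch shape of the marginalized RV
dep1_logp = rng.normal(size=(3, 2, 5))  # axes map to marginalized dims (1, 0, None)
dep2_logp = rng.normal(size=(7, 2, 3))  # axes map to marginalized dims (None, 0, 1)

def reduce_and_align(logp, dims_connection):
    # Sum over axes that do not correspond to any dim of the marginalized RV
    extra_axes = tuple(i for i, d in enumerate(dims_connection) if d is None)
    logp = logp.sum(axis=extra_axes)
    # Reorder the remaining axes to follow the marginalized RV's dim order
    remaining = [d for d in dims_connection if d is not None]
    return np.transpose(logp, np.argsort(remaining))

aligned1 = reduce_and_align(dep1_logp, (1, 0, None))
aligned2 = reduce_and_align(dep2_logp, (None, 0, 1))
assert aligned1.shape == aligned2.shape == marg_shape
joint = aligned1 + aligned2  # now safe to add elementwise to the marginalized logp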

pymc_experimental/model/marginal/graph_analysis.py

Lines changed: 1 addition & 1 deletion
@@ -107,7 +107,7 @@ def _broadcast_dims(
             raise ValueError("Different known dimensions mixed via broadcasting")
 
     if len(set(output_dim[0] for output_dim in output_dims if output_dim != ())) < len([output_dim for output_dim in output_dims if output_dim != ()]):
-        raise ValueError("Same dimension used in different axis after broadcasting")
+        raise ValueError("Same known dimension used in different axis after broadcasting")
 
     return output_dims

pymc_experimental/model/marginal/marginal_model.py

Lines changed: 37 additions & 30 deletions
@@ -27,7 +27,7 @@
     MarginalDiscreteMarkovChainRV,
     MarginalFiniteDiscreteRV,
     _reduce_batch_dependent_logps,
-    get_domain_of_finite_discrete_rv,
+    get_domain_of_finite_discrete_rv, _align_logp_with_dims,
 )
 from pymc_experimental.model.marginal.graph_analysis import (
     collect_shared_vars,
@@ -423,18 +423,32 @@ def transform_input(inputs):
         m = self.clone()
         marginalized_rv = m.vars_to_clone[marginalized_rv]
         m.unmarginalize([marginalized_rv])
-        dependent_vars = find_conditional_dependent_rvs(marginalized_rv, m.basic_RVs)
-        joint_logps = m.logp(vars=[marginalized_rv, *dependent_vars], sum=False)
+        dependent_rvs = find_conditional_dependent_rvs(marginalized_rv, m.basic_RVs)
+        logps = m.logp(vars=[marginalized_rv, *dependent_rvs], sum=False)
+
+        if marginalized_rv.type.ndim > 0:
+            other_direct_rv_ancestors = [
+                rv
+                for rv in find_conditional_input_rvs(dependent_rvs, self.basic_RVs)
+                if rv is not marginalized_rv
+            ]
+            dependent_rvs_dim_connections = subgraph_batch_dim_connection(
+                marginalized_rv, other_direct_rv_ancestors, dependent_rvs
+            )
+            # Handle batch dims for marginalized value and its dependent RVs
+            marginalized_logp, *dependent_logps = logps
+            joint_logp = marginalized_logp + _reduce_batch_dependent_logps(
+                dependent_rvs_dim_connections,
+                [dependent_var.owner.op for dependent_var in dependent_rvs],
+                dependent_logps
+            )
+        else:
+            joint_logp = pt.sum([logp.sum() for logp in logps])
+
 
         marginalized_value = m.rvs_to_values[marginalized_rv]
         other_values = [v for v in m.value_vars if v is not marginalized_value]
 
-        # Handle batch dims for marginalized value and its dependent RVs
-        marginalized_logp, *dependent_logps = joint_logps
-        joint_logp = marginalized_logp + _reduce_batch_dependent_logps(
-            marginalized_rv.type, dependent_logps
-        )
-
         rv_shape = constant_fold(tuple(marginalized_rv.shape), raise_not_constant=False)
         rv_domain = get_domain_of_finite_discrete_rv(marginalized_rv)
         rv_domain_tensor = pt.moveaxis(
@@ -447,37 +461,30 @@ def transform_input(inputs):
             0,
         )
 
-        joint_logps = vectorize_graph(
+        batched_joint_logp = vectorize_graph(
             joint_logp,
             replace={marginalized_value: rv_domain_tensor},
         )
-        joint_logps = pt.moveaxis(joint_logps, 0, -1)
+        batched_joint_logp = pt.moveaxis(batched_joint_logp, 0, -1)
 
-        rv_loglike_fn = None
-        joint_logps_norm = log_softmax(joint_logps, axis=-1)
+        joint_logp_norm = log_softmax(batched_joint_logp, axis=-1)
         if return_samples:
-            sample_rv_outs = pymc.Categorical.dist(logit_p=joint_logps)
+            rv_draws = pymc.Categorical.dist(logit_p=batched_joint_logp)
             if isinstance(marginalized_rv.owner.op, DiscreteUniform):
-                sample_rv_outs += rv_domain[0]
-
-            rv_loglike_fn = compile_pymc(
-                inputs=other_values,
-                outputs=[joint_logps_norm, sample_rv_outs],
-                on_unused_input="ignore",
-                random_seed=seed,
-            )
+                rv_draws += rv_domain[0]
+            outputs = [joint_logp_norm, rv_draws]
         else:
-            rv_loglike_fn = compile_pymc(
-                inputs=other_values,
-                outputs=joint_logps_norm,
-                on_unused_input="ignore",
-                random_seed=seed,
-            )
+            outputs = joint_logp_norm
+
+        rv_loglike_fn = compile_pymc(
+            inputs=other_values,
+            outputs=outputs,
+            on_unused_input="ignore",
+            random_seed=seed,
+        )
 
         logvs = [rv_loglike_fn(**vs) for vs in posterior_pts]
 
-        logps = None
-        samples = None
         if return_samples:
             logps, samples = zip(*logvs)
             logps = np.array(logps)
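
The hunk above vectorizes the joint logp over the whole domain of the marginalized RV and normalizes it with log_softmax. Conceptually, in NumPy terms (a hedged sketch with an assumed toy Bernoulli model, not the compiled PyTensor graph):

import numpy as np
from scipy.special import log_softmax, logsumexp

def joint_logp(idx_value, y_obs):
    # Assumed toy model: idx ~ Bernoulli(0.7), y ~ Normal(idx, 1), evaluated pointwise
    logp_idx = np.log(0.7 if idx_value == 1 else 0.3)
    logp_y = -0.5 * (y_obs - idx_value) ** 2 - 0.5 * np.log(2 * np.pi)
    return logp_idx + logp_y

y_obs = np.array([0.1, 0.9, 1.2])
domain = np.array([0, 1])  # finite domain of the marginalized RV
# Evaluate the joint logp once per domain value and stack it on a new trailing axis,
# which is what vectorize_graph + pt.moveaxis do symbolically above
batched_joint_logp = np.stack([joint_logp(v, y_obs) for v in domain], axis=-1)
marginal_logp = logsumexp(batched_joint_logp, axis=-1)  # idx summed out
lp_idx = log_softmax(batched_joint_logp, axis=-1)       # normalized per-value logps ("lp_idx")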

tests/model/marginal/test_distributions.py

Lines changed: 26 additions & 11 deletions
@@ -76,7 +76,7 @@ def test_marginalized_hmm_categorical_emission(categorical_emission):
         chain = DiscreteMarkovChain("chain", P=P, init_dist=init_dist, steps=2)
         if categorical_emission:
             emission = pm.Categorical(
-                "emission", p=pt.where(pt.eq(chain, 0)[..., None], [0.8, 0.2], [0.4, 0.6])
+                "emission", p=pt.constant([[0.8, 0.2], [0.4, 0.6]])[chain]
             )
         else:
             emission = pm.Bernoulli("emission", p=pt.where(pt.eq(chain, 0), 0.2, 0.6))
@@ -88,29 +88,44 @@ def test_marginalized_hmm_categorical_emission(categorical_emission):
     np.testing.assert_allclose(logp_fn({"emission": test_value}), expected_logp)
 
 
+@pytest.mark.parametrize("batch_chain", (False, True))
 @pytest.mark.parametrize("batch_emission1", (False, True))
 @pytest.mark.parametrize("batch_emission2", (False, True))
-def test_marginalized_hmm_multiple_emissions(batch_emission1, batch_emission2):
-    emission1_shape = (2, 4) if batch_emission1 else (4,)
-    emission2_shape = (2, 4) if batch_emission2 else (4,)
+def test_marginalized_hmm_multiple_emissions(batch_chain, batch_emission1, batch_emission2):
+    chain_shape = (3, 1, 4) if batch_chain else (4,)
+    emission1_shape = (2, *reversed(chain_shape)) if batch_emission1 else tuple(reversed(chain_shape))
+    emission2_shape = (*chain_shape, 2) if batch_emission2 else chain_shape
     with MarginalModel() as m:
         P = [[0, 1], [1, 0]]
         init_dist = pm.Categorical.dist(p=[1, 0])
-        chain = DiscreteMarkovChain("chain", P=P, init_dist=init_dist, steps=3)
-        emission_1 = pm.Normal("emission_1", mu=chain * 2 - 1, sigma=1e-1, shape=emission1_shape)
+        chain = DiscreteMarkovChain("chain", P=P, init_dist=init_dist, shape=chain_shape)
+        emission_1 = pm.Normal("emission_1", mu=(chain * 2 - 1).T, sigma=1e-1, shape=emission1_shape)
+
+        emission2_mu = ((1 - chain) * 2 - 1)
+        if batch_emission2:
+            emission2_mu = emission2_mu[..., None]
         emission_2 = pm.Normal(
-            "emission_2", mu=(1 - chain) * 2 - 1, sigma=1e-1, shape=emission2_shape
+            "emission_2", mu=emission2_mu, sigma=1e-1, shape=emission2_shape
         )
 
     with pytest.warns(UserWarning, match="multiple dependent variables"):
         m.marginalize([chain])
 
-    logp_fn = m.compile_logp()
+    logp_fn = m.compile_logp(sum=False)
 
     test_value = np.array([-1, 1, -1, 1])
     multiplier = 2 + batch_emission1 + batch_emission2
+    if batch_chain:
+        multiplier *= 3
     expected_logp = norm.logpdf(np.zeros_like(test_value), 0, 1e-1).sum() * multiplier
-    test_value_emission1 = np.broadcast_to(test_value, emission1_shape)
-    test_value_emission2 = np.broadcast_to(-test_value, emission2_shape)
+
+    test_value = np.broadcast_to(test_value, chain_shape)
+    test_value_emission1 = np.broadcast_to(test_value.T, emission1_shape)
+    if batch_emission2:
+        test_value_emission2 = np.broadcast_to(-test_value[..., None], emission2_shape)
+    else:
+        test_value_emission2 = np.broadcast_to(-test_value, emission2_shape)
     test_point = {"emission_1": test_value_emission1, "emission_2": test_value_emission2}
-    np.testing.assert_allclose(logp_fn(test_point), expected_logp)
+    res_logp, dummy_logp = logp_fn(test_point)
+    assert res_logp.shape == ((1, 3) if batch_chain else ())
+    np.testing.assert_allclose(res_logp.sum(), expected_logp)
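
These tests rely on the marginalized HMM logp being the standard forward algorithm in log space, which is what the scan over step_alpha in marginal_hmm_logp computes. A minimal NumPy sketch with the tests' deterministic transition matrix (toy numbers and an assumed single emission stream, not the library implementation):

import numpy as np
from scipy.special import logsumexp
from scipy.stats import norm

log_P = np.log(np.array([[0.0, 1.0], [1.0, 0.0]]) + 1e-12)  # deterministic transitions, as in the tests
log_init = np.log(np.array([1.0, 0.0]) + 1e-12)             # chain starts in state 0; jitter avoids log(0)
emissions = np.array([-1.0, 1.0, -1.0, 1.0])                # observed values for one emission stream
# Emission logp under each hidden state: mu = state * 2 - 1, sigma = 0.1
log_emit = norm.logpdf(emissions[:, None], loc=np.array([0, 1]) * 2 - 1, scale=0.1)

log_alpha = log_init + log_emit[0]
for t in range(1, len(emissions)):
    # alpha_t(j) = p(y_t | j) * sum_i alpha_{t-1}(i) * P[i, j], carried in log space
    log_alpha = log_emit[t] + logsumexp(log_alpha[:, None] + log_P, axis=0)
joint_logp = logsumexp(log_alpha)  # ~ 4 * norm.logpdf(0, 0, 0.1) for this deterministic chain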

tests/model/marginal/test_graph_analysis.py

Lines changed: 9 additions & 7 deletions
@@ -105,8 +105,8 @@ def test_advanced_subtensor_key(self):
 
         # Mix keys dimensions
         out = base[:, inp, inp.T]
-        [dims] = subgraph_batch_dim_connection(inp, [], [out])
-        assert dims == ((), (0, 1), (0, 1))
+        with pytest.raises(ValueError, match="Different known dimensions mixed via broadcasting"):
+            subgraph_batch_dim_connection(inp, [], [out])
 
     def test_elemwise(self):
         inp = pt.tensor(shape=(5, 5))
@@ -116,11 +116,13 @@ def test_elemwise(self):
         assert dims == ((0,), (1,))
 
         out = inp + inp.T
-        [dims] = subgraph_batch_dim_connection(inp, [], [out])
-        assert dims == (
-            (0, 1),
-            (0, 1),
-        )
+        with pytest.raises(ValueError, match="Different known dimensions mixed via broadcasting"):
+            subgraph_batch_dim_connection(inp, [], [out])
+
+        out = inp[None, :, None, :] + inp[:, None, :, None]
+        with pytest.raises(ValueError, match="Same known dimension used in different axis after broadcasting"):
+            subgraph_batch_dim_connection(inp, [], [out])
+
 
     def test_blockwise(self):
         inp = pt.tensor(shape=(5, 4))
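
For context on the two errors exercised above, an assumed NumPy illustration (not library code): inp + inp.T mixes dims 0 and 1 of inp into the same output axis, while the double-broadcast expression puts the same input dim into two different output axes, so neither result can be realigned with inp's batch dimensions by a transpose and reduction.

import numpy as np

inp = np.arange(25.0).reshape(5, 5)
mixed = inp + inp.T                                      # each output element depends on both dims of inp
split = inp[None, :, None, :] + inp[:, None, :, None]    # shape (5, 5, 5, 5)
# In `split`, output axes 0 and 1 both vary with dim 0 of `inp`, and axes 2 and 3 with dim 1,
# so the same known dimension ends up in more than one output axis.
assert split.shape == (5, 5, 5, 5)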

tests/model/marginal/test_marginal_model.py

Lines changed: 4 additions & 5 deletions
@@ -802,7 +802,7 @@ def test_batched(self):
         with MarginalModel() as m:
             sigma = pm.HalfNormal("sigma")
             idx = pm.Bernoulli("idx", p=0.7, shape=(3, 2))
-            y = pm.Normal("y", mu=idx, sigma=sigma, shape=(3, 2))
+            y = pm.Normal("y", mu=idx.T, sigma=sigma, shape=(2, 3))
 
         m.marginalize([idx])
 
@@ -820,10 +820,9 @@ def test_batched(self):
 
         idata = m.recover_marginals(idata, return_samples=True)
         post = idata.posterior
-        assert "idx" in post
-        assert "lp_idx" in post
-        assert post.idx.shape == post.y.shape
-        assert post.lp_idx.shape == (*post.idx.shape, 2)
+        assert post["y"].shape == (1, 20, 2, 3)
+        assert post["idx"].shape == (1, 20, 3, 2)
+        assert post["lp_idx"].shape == (1, 20, 3, 2, 2)
 
     def test_nested(self):
         """Test that marginalization works when there are nested marginalized RVs"""
