
Commit 2c81057

.WIP backtrack
1 parent 707aca8 commit 2c81057

3 files changed: +98 -64 lines changed

pymc_experimental/model/marginal/graph_analysis.py

Lines changed: 24 additions & 19 deletions
@@ -2,6 +2,7 @@
 from itertools import chain, zip_longest
 
 from pymc import SymbolicRandomVariable
+from pymc.distributions.custom import CustomSymbolicDistRV
 from pytensor.compile import SharedVariable
 from pytensor.graph import Constant, Variable, ancestors, graph_inputs
 from pytensor.graph.basic import io_toposort
@@ -112,7 +113,6 @@ def _subgraph_dim_connection(var_dims: VAR_DIMS, input_vars, output_vars) -> VAR
 
         if not any(inputs_dims):
             # None of the inputs are related to the batch_axes of the marginalized_rv
-            # We could set `()` for everything, but for now that doesn't seem needed
             continue
 
         elif isinstance(node.op, DimShuffle):
@@ -122,7 +122,11 @@ def _subgraph_dim_connection(var_dims: VAR_DIMS, input_vars, output_vars) -> VAR
             )
             var_dims[node.outputs[0]] = output_dims
 
-        elif isinstance(node.op, SymbolicRandomVariable) and node.op.extended_signature is None:
+        elif (
+            isinstance(node.op, CustomSymbolicDistRV)
+            or isinstance(node.op, SymbolicRandomVariable)
+            and node.op.extended_signature is None
+        ):
             # SymbolicRandomVariables without signature are a wild-card, so we need to introspect the inner graph.
             # MarginalRVs are such a case!
             inner_var_dims = {
@@ -154,23 +158,18 @@ def _subgraph_dim_connection(var_dims: VAR_DIMS, input_vars, output_vars) -> VAR
                 missing_ndim = op_batch_ndim - (len(param_dims) - param_core_ndim)
                 inputs_dims[param_idx] = ((),) * missing_ndim + param_dims
 
-            # Collapse all core_dims
-            core_dims = tuple(
-                sorted(
-                    chain.from_iterable(
-                        [i for input_dim in inputs_dims for i in input_dim[op_batch_ndim:]]
-                    )
+            if any(any(input_dim[op_batch_ndim:]) for input_dim in inputs_dims):
+                raise ValueError(
+                    f"Use of known dimensions as core dimensions of op {node.op} not supported."
                 )
-            )
+
             batch_dims = _broadcast_dims(
                 tuple(input_dims[:op_batch_ndim] for input_dims in inputs_dims)
             )
-            # Add batch dims to each output_dims
-            batch_dims = tuple(batch_dim + core_dims for batch_dim in batch_dims)
             for out in node.outputs:
                 if isinstance(out.type, TensorType):
                     core_ndim = out.type.ndim - op_batch_ndim
-                    output_dims = batch_dims + (core_dims,) * core_ndim
+                    output_dims = batch_dims + ((),) * core_ndim
                     var_dims[out] = output_dims
 
         elif isinstance(node.op, CAReduce):
@@ -182,8 +181,12 @@ def _subgraph_dim_connection(var_dims: VAR_DIMS, input_vars, output_vars) -> VAR
             elif axes is None:
                 axes = tuple(range(node.inputs[0].type.ndim))
 
-            # Output dims contain the collapsed dims
-            output_dims = [dims + axes for i, dims in enumerate(input_dims) if i not in axes]
+            if any(input_dims[axis] for axis in axes):
+                raise ValueError(
+                    f"Use of known dimensions as reduced dimensions of op {node.op} not supported."
+                )
+
+            output_dims = [dims for i, dims in enumerate(input_dims) if i not in axes]
             var_dims[node.outputs[0]] = tuple(output_dims)
 
         elif isinstance(node.op, Subtensor):
@@ -198,7 +201,7 @@ def _subgraph_dim_connection(var_dims: VAR_DIMS, input_vars, output_vars) -> VAR
                     # Dim is kept
                     output_dims.append(value_dims)
                 elif value_dims:
-                    raise NotImplementedError(
+                    raise ValueError(
                         "Partial slicing or indexing of known dimensions not supported."
                     )
                 elif isinstance(idx, slice):
@@ -242,7 +245,7 @@ def _subgraph_dim_connection(var_dims: VAR_DIMS, input_vars, output_vars) -> VAR
                     non_adv_dims.append(value_dim)
                 elif value_dim:
                     # We are trying to partially slice or index a known dimension
-                    raise NotImplementedError(
+                    raise ValueError(
                         "Partial slicing or advanced integer indexing of known dimensions not supported."
                     )
                 elif isinstance(idx, slice):
@@ -277,16 +280,18 @@ def _subgraph_dim_connection(var_dims: VAR_DIMS, input_vars, output_vars) -> VAR
     return var_dims
 
 
-def subgraph_dim_connection(
+def subgraph_batch_dim_connection(
     input_var, other_inputs, output_vars
 ) -> list[tuple[tuple[int, ...], ...]]:
-    """Identify how the dims of rv_to_marginalize are consumed by the dims of the output_rvs.
+    """Identify how the batch dims of rv_to_marginalize map to the batch dimensions of the output_rvs.
 
     Raises
     ------
+    ValueError
+        If input batch dimensions are mixed in the graph leading to output_vars.
+
     NotImplementedError
         If variable related to marginalized batch_dims is used in an operation that is not yet supported
-
     """
     var_dims = {input_var: tuple((i,) for i in range(input_var.type.ndim))}
     var_dims = _subgraph_dim_connection(var_dims, [input_var, *other_inputs], output_vars)
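
Taken together, these changes tighten the contract of subgraph_batch_dim_connection: known batch dimensions may only flow through batch positions, and any use of them as core or reduced dimensions now raises ValueError instead of being collapsed into the output dims. A minimal sketch of the new behavior, mirroring the updated tests (assumes pytensor.tensor imported as pt; not run against this WIP commit):

import pytensor.tensor as pt

from pymc_experimental.model.marginal.graph_analysis import subgraph_batch_dim_connection

inp = pt.tensor(shape=(4, 3, 2))

# Elemwise ops keep batch dims apart: each output dim maps back to one input batch dim.
[dims] = subgraph_batch_dim_connection(inp, [], [pt.exp(inp)])
assert dims == ((0,), (1,), (2,))

# Reducing over a known batch dim mixes information across it and now fails eagerly.
try:
    subgraph_batch_dim_connection(inp, [], [pt.sum(inp, axis=(1,))])
except ValueError as err:
    print(err)  # "Use of known dimensions as reduced dimensions of op ... not supported."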

pymc_experimental/model/marginal/marginal_model.py

Lines changed: 10 additions & 4 deletions
@@ -34,7 +34,7 @@
     find_conditional_dependent_rvs,
     find_conditional_input_rvs,
     is_conditional_dependent,
-    subgraph_dim_connection,
+    subgraph_batch_dim_connection,
 )
 
 ModelRVs = TensorVariable | Sequence[TensorVariable] | str | Sequence[str]
@@ -566,9 +566,15 @@ def replace_finite_discrete_marginal_subgraph(fgraph, rv_to_marginalize, all_rvs
     # If the marginalized RV has multiple dimensions, check that the graph between
     # the marginalized RV and dependent RVs does not mix information from batch dimensions
     # (otherwise logp would require enumerating over all combinations of batch dimension values)
-    dependent_rvs_dim_connections = subgraph_dim_connection(
-        rv_to_marginalize, other_direct_rv_ancestors, dependent_rvs
-    )
+    try:
+        dependent_rvs_dim_connections = subgraph_batch_dim_connection(
+            rv_to_marginalize, other_direct_rv_ancestors, dependent_rvs
+        )
+    except ValueError as e:
+        # From the perspective of the user this is a NotImplementedError
+        raise NotImplementedError(
+            "The graph between the marginalized and dependent RVs cannot be marginalized"
+        ) from e
 
     if any(
         len(dim) > 1
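
On the user-facing side, the ValueError from the dim-connection analysis is re-raised as NotImplementedError. A hypothetical model that would hit this path (a sketch only: the model and the exact point at which the error surfaces are illustrative, not taken from this commit):

import pymc as pm

from pymc_experimental import MarginalModel

with MarginalModel() as m:
    # A batch of discrete RVs we would like to marginalize out
    idx = pm.Bernoulli("idx", p=0.7, shape=(3,))
    # Summing mixes information across the batch dims of `idx`
    y = pm.Normal("y", mu=idx.sum(), sigma=1.0)

try:
    m.marginalize([idx])
except NotImplementedError as err:
    print(err)  # "The graph between the marginalized and dependent RVs cannot be marginalized"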

tests/model/marginal/test_graph_analysis.py

Lines changed: 64 additions & 41 deletions
@@ -2,40 +2,46 @@
 import pytest
 
 from pymc.distributions import CustomDist
+from pytensor.tensor.type_other import NoneTypeT
 
-from pymc_experimental.model.marginal.graph_analysis import subgraph_dim_connection
+from pymc_experimental.model.marginal.graph_analysis import subgraph_batch_dim_connection
 
 
-class TestSubgraphDimConnection:
+class TestSubgraphBatchDimConnection:
     def test_dimshuffle(self):
         inp = pt.tensor(shape=(5, 1, 4, 3))
         out1 = pt.matrix_transpose(inp)
         out2 = pt.expand_dims(inp, 1)
         out3 = pt.squeeze(inp)
-        [dims1, dims2, dims3] = subgraph_dim_connection(inp, [], [out1, out2, out3])
+        [dims1, dims2, dims3] = subgraph_batch_dim_connection(inp, [], [out1, out2, out3])
         assert dims1 == ((0,), (1,), (3,), (2,))
         assert dims2 == ((0,), (), (1,), (2,), (3,))
         assert dims3 == ((0,), (2,), (3,))
 
     def test_careduce(self):
         inp = pt.tensor(shape=(4, 3, 2))
-        out = pt.sum(inp, axis=(1,))
-        [dims] = subgraph_dim_connection(inp, [], [out])
-        assert dims == ((0, 1), (2, 1))
+
+        out = pt.sum(inp[:, None], axis=(1,))
+        [dims] = subgraph_batch_dim_connection(inp, [], [out])
+        assert dims == ((0,), (1,), (2,))
+
+        invalid_out = pt.sum(inp, axis=(1,))
+        with pytest.raises(ValueError, match="Use of known dimensions"):
+            subgraph_batch_dim_connection(inp, [], [invalid_out])
 
     def test_subtensor(self):
         inp = pt.tensor(shape=(4, 3, 2))
 
         invalid_out = inp[0, :1]
         with pytest.raises(
-            NotImplementedError,
+            ValueError,
             match="Partial slicing or indexing of known dimensions not supported",
         ):
-            subgraph_dim_connection(inp, [], [invalid_out])
+            subgraph_batch_dim_connection(inp, [], [invalid_out])
 
         # If we are selecting dummy / unknown dimensions that's fine
         valid_out = pt.expand_dims(inp, (0, 1))[0, :1]
-        [dims] = subgraph_dim_connection(inp, [], [valid_out])
+        [dims] = subgraph_batch_dim_connection(inp, [], [valid_out])
         assert dims == ((), (0,), (1,), (2,))
 
     def test_advanced_subtensor_value(self):
@@ -44,99 +50,116 @@ def test_advanced_subtensor_value(self):
 
         # Index on an unlabeled dim introduced by broadcasting with zeros
         out = intermediate_out[:, [0, 0, 1, 2]]
-        [dims] = subgraph_dim_connection(inp, [], [out])
+        [dims] = subgraph_batch_dim_connection(inp, [], [out])
         assert dims == ((0,), (), (1,), ())
 
         # Indexing that introduces more dimensions
         out = intermediate_out[:, [[0, 0], [1, 2]], :]
-        [dims] = subgraph_dim_connection(inp, [], [out])
+        [dims] = subgraph_batch_dim_connection(inp, [], [out])
         assert dims == ((0,), (), (), (1,), ())
 
         # Special case where advanced dims are moved to the front of the output
         out = intermediate_out[:, [0, 0, 1, 2], :, 0]
-        [dims] = subgraph_dim_connection(inp, [], [out])
+        [dims] = subgraph_batch_dim_connection(inp, [], [out])
         assert dims == ((), (0,), (1,))
 
         # Indexing on a labeled dim fails
         out = intermediate_out[:, :, [0, 0, 1, 2]]
-        with pytest.raises(
-            NotImplementedError, match="Partial slicing or advanced integer indexing"
-        ):
-            subgraph_dim_connection(inp, [], [out])
+        with pytest.raises(ValueError, match="Partial slicing or advanced integer indexing"):
+            subgraph_batch_dim_connection(inp, [], [out])
 
     def test_advanced_subtensor_key(self):
         inp = pt.tensor(shape=(5, 5), dtype=int)
         base = pt.zeros((2, 3, 4))
 
         out = base[inp]
-        [dims] = subgraph_dim_connection(inp, [], [out])
+        [dims] = subgraph_batch_dim_connection(inp, [], [out])
         assert dims == ((0,), (1,), (), ())
 
         out = base[:, :, inp]
-        [dims] = subgraph_dim_connection(inp, [], [out])
+        [dims] = subgraph_batch_dim_connection(inp, [], [out])
         assert dims == ((), (), (0,), (1,))
 
         out = base[1:, 0, inp]
-        [dims] = subgraph_dim_connection(inp, [], [out])
+        [dims] = subgraph_batch_dim_connection(inp, [], [out])
         assert dims == ((), (0,), (1,))
 
         # Special case where advanced dims are moved to the front of the output
         out = base[0, :, inp]
-        [dims] = subgraph_dim_connection(inp, [], [out])
+        [dims] = subgraph_batch_dim_connection(inp, [], [out])
         assert dims == ((0,), (1,), ())
 
         # Mix key dimensions
         out = base[:, inp, inp.T]
-        [dims] = subgraph_dim_connection(inp, [], [out])
+        [dims] = subgraph_batch_dim_connection(inp, [], [out])
         assert dims == ((), (0, 1), (0, 1))
 
     def test_elemwise(self):
         inp = pt.tensor(shape=(5, 5))
 
         out = inp + inp
-        [dims] = subgraph_dim_connection(inp, [], [out])
+        [dims] = subgraph_batch_dim_connection(inp, [], [out])
         assert dims == ((0,), (1,))
 
         out = inp + inp.T
-        [dims] = subgraph_dim_connection(inp, [], [out])
+        [dims] = subgraph_batch_dim_connection(inp, [], [out])
         assert dims == (
             (0, 1),
-            (
-                0,
-                1,
-            ),
+            (0, 1),
         )
 
     def test_blockwise(self):
-        inp = pt.tensor(shape=(5, 4, 3, 2))
-        out = inp @ pt.ones((2, 3))
-        [dims] = subgraph_dim_connection(inp, [], [out])
-        # Every dimension contains information from the core dimensions
-        assert dims == ((0, 2, 3), (1, 2, 3), (2, 3), (2, 3))
+        inp = pt.tensor(shape=(5, 4))
+
+        invalid_out = inp @ pt.ones((4, 3))
+        with pytest.raises(ValueError, match="Use of known dimensions"):
+            subgraph_batch_dim_connection(inp, [], [invalid_out])
+
+        out = (inp[:, :, None, None] + pt.zeros((2, 3))) @ pt.ones((2, 3))
+        [dims] = subgraph_batch_dim_connection(inp, [], [out])
+        assert dims == ((0,), (1,), (), ())
 
     def test_random_variable(self):
         inp = pt.tensor(shape=(5, 4, 3))
+
         out1 = pt.random.normal(loc=inp)
-        out2 = pt.random.categorical(p=inp)
-        out3 = pt.random.multivariate_normal(mean=inp, cov=pt.eye(3))
-        [dims1, dims2, dims3] = subgraph_dim_connection(inp, [], [out1, out2, out3])
+        out2 = pt.random.categorical(p=inp[..., None])
+        out3 = pt.random.multivariate_normal(mean=inp[..., None], cov=pt.eye(1))
+        [dims1, dims2, dims3] = subgraph_batch_dim_connection(inp, [], [out1, out2, out3])
         assert dims1 == ((0,), (1,), (2,))
-        assert dims2 == ((0, 2), (1, 2))
-        assert dims3 == ((0, 2), (1, 2), (2,))
+        assert dims2 == ((0,), (1,), (2,))
+        assert dims3 == ((0,), (1,), (2,), ())
+
+        invalid_out = pt.random.categorical(p=inp)
+        with pytest.raises(ValueError, match="Use of known dimensions"):
+            subgraph_batch_dim_connection(inp, [], [invalid_out])
+
+        invalid_out = pt.random.multivariate_normal(mean=inp, cov=pt.eye(3))
+        with pytest.raises(ValueError, match="Use of known dimensions"):
+            subgraph_batch_dim_connection(inp, [], [invalid_out])
 
     def test_symbolic_random_variable(self):
         inp = pt.tensor(shape=(4, 3, 2))
+
+        # Test univariate
         out = CustomDist.dist(
             inp,
             dist=lambda mu, size: pt.random.normal(loc=mu, size=size),
         )
-        [dims] = subgraph_dim_connection(inp, [], [out])
+        [dims] = subgraph_batch_dim_connection(inp, [], [out])
         assert dims == ((0,), (1,), (2,))
 
         # Test multivariate
+        def dist(mu, size):
+            if isinstance(size.type, NoneTypeT):
+                size = mu.shape
+            return pt.random.normal(loc=mu[..., None], size=(*size, 2)).sum(-1)
+
         out = CustomDist.dist(
             inp,
-            dist=lambda mu, size: pt.random.normal(loc=mu, size=size).sum(-1),
+            dist=dist,
+            size=(4, 3, 2),
+            ndim_supp=1,
        )
-        [dims] = subgraph_dim_connection(inp, [], [out])
-        assert dims == ((0, 2), (1, 2))
+        [dims] = subgraph_batch_dim_connection(inp, [], [out])
+        assert dims == ((0,), (1,), (2,))
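
A note on reading the dims tuples asserted above: subgraph_batch_dim_connection returns one tuple per output dimension, listing the input batch dims that the output dimension depends on, with () marking a dimension that carries no information about the input. A small sketch under the same assumptions as the tests (pytensor.tensor imported as pt):

import pytensor.tensor as pt

from pymc_experimental.model.marginal.graph_analysis import subgraph_batch_dim_connection

inp = pt.tensor(shape=(5, 4))
out = pt.expand_dims(inp, 0)  # shape (1, 5, 4)

[dims] = subgraph_batch_dim_connection(inp, [], [out])
assert dims == ((), (0,), (1,))  # the new leading dim carries no input information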
