Skip to content

Commit b8e939e

Browse files
committed
Fix local_subtensor_rv_lift rewrite bug with vector parameters
Also allow rewrite to work with multivariate variables, when indexing does not act on support dims.
1 parent bfeabc8 commit b8e939e

File tree

2 files changed

+159
-103
lines changed

2 files changed

+159
-103
lines changed

pytensor/tensor/random/rewriting/basic.py

Lines changed: 62 additions & 95 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
1+
from itertools import zip_longest
2+
13
from pytensor.compile import optdb
24
from pytensor.configdefaults import config
35
from pytensor.graph.op import compute_test_value
46
from pytensor.graph.rewriting.basic import in2out, node_rewriter
7+
from pytensor.tensor import NoneConst
58
from pytensor.tensor.basic import constant, get_vector_length
69
from pytensor.tensor.elemwise import DimShuffle
710
from pytensor.tensor.extra_ops import broadcast_to
@@ -17,6 +20,7 @@
1720
get_idx_list,
1821
indexed_result_shape,
1922
)
23+
from pytensor.tensor.type_other import SliceType
2024

2125

2226
def is_rv_used_in_graph(base_rv, node, fgraph):
@@ -196,141 +200,104 @@ def local_dimshuffle_rv_lift(fgraph, node):
196200
def local_subtensor_rv_lift(fgraph, node):
197201
"""Lift a ``*Subtensor`` through ``RandomVariable`` inputs.
198202
199-
In a fashion similar to ``local_dimshuffle_rv_lift``, the indexed dimensions
200-
need to be separated into distinct replication-space and (independent)
201-
parameter-space ``*Subtensor``s.
202-
203-
The replication-space ``*Subtensor`` can be used to determine a
204-
sub/super-set of the replication-space and, thus, a "smaller"/"larger"
205-
``size`` tuple. The parameter-space ``*Subtensor`` is simply lifted and
206-
applied to the distribution parameters.
207-
208-
Consider the following example graph:
209-
``normal(mu, std, size=(d1, d2, d3))[idx1, idx2, idx3]``. The
210-
``*Subtensor`` ``Op`` requests indices ``idx1``, ``idx2``, and ``idx3``,
211-
which correspond to all three ``size`` dimensions. Now, depending on the
212-
broadcasted dimensions of ``mu`` and ``std``, this ``*Subtensor`` ``Op``
213-
could be reducing the ``size`` parameter and/or sub-setting the independent
214-
``mu`` and ``std`` parameters. Only once the dimensions are properly
215-
separated into the two replication/parameter subspaces can we determine how
216-
the ``*Subtensor`` indices are distributed.
217-
For instance, ``normal(mu, std, size=(d1, d2, d3))[idx1, idx2, idx3]``
218-
could become
219-
``normal(mu[idx1], std[idx2], size=np.shape(idx1) + np.shape(idx2) + np.shape(idx3))``
220-
if ``mu.shape == std.shape == ()``
221-
222-
``normal`` is a rather simple case, because it's univariate. Multivariate
223-
cases require a mapping between the parameter space and the image of the
224-
random variable. This may not always be possible, but for many common
225-
distributions it is. For example, the dimensions of the multivariate
226-
normal's image can be mapped directly to each dimension of its parameters.
227-
We use these mappings to change a graph like ``multivariate_normal(mu, Sigma)[idx1]``
228-
into ``multivariate_normal(mu[idx1], Sigma[idx1, idx1])``.
203+
For example, ``normal(mu, std)[0] == normal(mu[0], std[0])``.
229204
205+
This rewrite also applies to multivariate distributions as long
206+
as indexing does not happen within core dimensions, such as in
207+
``mvnormal(mu, cov, size=(2,))[0, 0]``.
230208
"""
231209

232210
st_op = node.op
233211

234212
if not isinstance(st_op, (AdvancedSubtensor, AdvancedSubtensor1, Subtensor)):
235213
return False
236214

237-
base_rv = node.inputs[0]
215+
rv = node.inputs[0]
216+
rv_node = rv.owner
238217

239-
rv_node = base_rv.owner
240218
if not (rv_node and isinstance(rv_node.op, RandomVariable)):
241219
return False
242220

243-
# If no one else is using the underlying `RandomVariable`, then we can
244-
# do this; otherwise, the graph would be internally inconsistent.
245-
if is_rv_used_in_graph(base_rv, node, fgraph):
246-
return False
247-
248221
rv_op = rv_node.op
249222
rng, size, dtype, *dist_params = rv_node.inputs
250223

251-
# TODO: Remove this once the multi-dimensional changes described below are
252-
# in place.
253-
if rv_op.ndim_supp > 0:
254-
return False
255-
256-
rv_op = base_rv.owner.op
257-
rng, size, dtype, *dist_params = base_rv.owner.inputs
258-
224+
# Parse indices
259225
idx_list = getattr(st_op, "idx_list", None)
260226
if idx_list:
261227
cdata = get_idx_list(node.inputs, idx_list)
262228
else:
263229
cdata = node.inputs[1:]
264-
265230
st_indices, st_is_bool = zip(
266231
*tuple(
267232
(as_index_variable(i), getattr(i, "dtype", None) == "bool") for i in cdata
268233
)
269234
)
270235

271-
# We need to separate dimensions into replications and independents
272-
num_ind_dims = None
273-
if len(dist_params) == 1:
274-
num_ind_dims = dist_params[0].ndim
275-
else:
276-
# When there is more than one distribution parameter, assume that all
277-
# of them will broadcast to the maximum number of dimensions
278-
num_ind_dims = max(d.ndim for d in dist_params)
279-
280-
reps_ind_split_idx = base_rv.ndim - (num_ind_dims + rv_op.ndim_supp)
281-
282-
if len(st_indices) > reps_ind_split_idx:
283-
# These are the indices that need to be applied to the parameters
284-
ind_indices = tuple(st_indices[reps_ind_split_idx:])
285-
286-
# We need to broadcast the parameters before applying the `*Subtensor*`
287-
# with these indices, because the indices could be referencing broadcast
288-
# dimensions that don't exist (yet)
289-
bcast_dist_params = broadcast_params(dist_params, rv_op.ndims_params)
290-
291-
# TODO: For multidimensional distributions, we need a map that tells us
292-
# which dimensions of the parameters need to be indexed.
293-
#
294-
# For example, `multivariate_normal` would have the following:
295-
# `RandomVariable.param_to_image_dims = ((0,), (0, 1))`
296-
#
297-
# I.e. the first parameter's (i.e. mean's) first dimension maps directly to
298-
# the dimension of the RV's image, and its second parameter's
299-
# (i.e. covariance's) first and second dimensions map directly to the
300-
# dimension of the RV's image.
301-
302-
args_lifted = tuple(p[ind_indices] for p in bcast_dist_params)
303-
else:
304-
# In this case, no indexing is applied to the parameters; only the
305-
# `size` parameter is affected.
306-
args_lifted = dist_params
236+
# Check that indexing does not act on support dims
237+
batched_ndims = rv.ndim - rv_op.ndim_supp
238+
if len(st_indices) > batched_ndims:
239+
# If the trailing indices are just dummy `slice(None)`, we discard them
240+
st_is_bool = st_is_bool[:batched_ndims]
241+
st_indices, supp_indices = (
242+
st_indices[:batched_ndims],
243+
st_indices[batched_ndims:],
244+
)
245+
for index in supp_indices:
246+
if not (
247+
isinstance(index.type, SliceType)
248+
and all(NoneConst.equals(i) for i in index.owner.inputs)
249+
):
250+
return False
251+
252+
# If no one else is using the underlying `RandomVariable`, then we can
253+
# do this; otherwise, the graph would be internally inconsistent.
254+
if is_rv_used_in_graph(rv, node, fgraph):
255+
return False
307256

257+
# Update the size to reflect the indexed dimensions
308258
# TODO: Could use `ShapeFeature` info. We would need to be sure that
309259
# `node` isn't in the results, though.
310260
# if hasattr(fgraph, "shape_feature"):
311261
# output_shape = fgraph.shape_feature.shape_of(node.outputs[0])
312262
# else:
313-
output_shape = indexed_result_shape(base_rv.shape, st_indices)
314-
315-
size_lifted = (
316-
output_shape if rv_op.ndim_supp == 0 else output_shape[: -rv_op.ndim_supp]
263+
output_shape_ignoring_bool = indexed_result_shape(rv.shape, st_indices)
264+
new_size_ignoring_boolean = (
265+
output_shape_ignoring_bool
266+
if rv_op.ndim_supp == 0
267+
else output_shape_ignoring_bool[: -rv_op.ndim_supp]
317268
)
318269

319-
# Boolean indices can actually change the `size` value (compared to just
320-
# *which* dimensions of `size` are used).
270+
# Boolean indices can actually change the `size` value (compared to just *which* dimensions of `size` are used).
271+
# The `indexed_result_shape` helper does not consider this
321272
if any(st_is_bool):
322-
size_lifted = tuple(
273+
new_size = tuple(
323274
at_sum(idx) if is_bool else s
324-
for s, is_bool, idx in zip(
325-
size_lifted, st_is_bool, st_indices[: (reps_ind_split_idx + 1)]
275+
for s, is_bool, idx in zip_longest(
276+
new_size_ignoring_boolean, st_is_bool, st_indices, fillvalue=False
326277
)
327278
)
279+
else:
280+
new_size = new_size_ignoring_boolean
328281

329-
new_node = rv_op.make_node(rng, size_lifted, dtype, *args_lifted)
330-
_, new_rv = new_node.outputs
282+
# Update the parameters to reflect the indexed dimensions
283+
new_dist_params = []
284+
for param, param_ndim_supp in zip(dist_params, rv_op.ndims_params):
285+
# Apply indexing on the batched dimensions of the parameter
286+
batched_param_dims_missing = batched_ndims - (param.ndim - param_ndim_supp)
287+
batched_param = shape_padleft(param, batched_param_dims_missing)
288+
batched_st_indices = []
289+
for st_index, batched_param_shape in zip(st_indices, batched_param.type.shape):
290+
# If the parameter has a degenerate (length-1) dimension, indexing it with 0 always does the job
291+
if batched_param_shape == 1:
292+
batched_st_indices.append(0)
293+
else:
294+
batched_st_indices.append(st_index)
295+
new_dist_params.append(batched_param[tuple(batched_st_indices)])
296+
297+
# Create new RV
298+
new_node = rv_op.make_node(rng, new_size, dtype, *new_dist_params)
299+
new_rv = new_node.default_output()
331300

332-
# Calling `Op.make_node` directly circumvents test value computations, so
333-
# we need to compute the test values manually
334301
if config.compute_test_value != "off":
335302
compute_test_value(new_node)
336303

tests/tensor/random/rewriting/test_basic.py

Lines changed: 97 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from pytensor.tensor import constant
1313
from pytensor.tensor.elemwise import DimShuffle
1414
from pytensor.tensor.random.basic import (
15+
categorical,
1516
dirichlet,
1617
multinomial,
1718
multivariate_normal,
@@ -36,8 +37,8 @@ def apply_local_rewrite_to_rv(
3637
rewrite, op_fn, dist_op, dist_params, size, rng, name=None
3738
):
3839
dist_params_at = []
39-
for p in dist_params:
40-
p_at = at.as_tensor(p).type()
40+
for i, p in enumerate(dist_params):
41+
p_at = at.as_tensor(p).type(f"p_{i}")
4142
p_at.tag.test_value = p
4243
dist_params_at.append(p_at)
4344

@@ -495,8 +496,79 @@ def test_DimShuffle_lift(ds_order, lifted, dist_op, dist_params, size, rtol):
495496
),
496497
(3, 2, 2),
497498
),
498-
# A multi-dimensional case
499+
# Only one distribution parameter
500+
(
501+
(0,),
502+
True,
503+
poisson,
504+
(np.array([[1, 2], [3, 4]], dtype=config.floatX),),
505+
(3, 2, 2),
506+
),
507+
# Univariate distribution with vector parameters
508+
(
509+
(np.array([0, 2]),),
510+
True,
511+
categorical,
512+
(np.array([0.0, 0.0, 1.0], dtype=config.floatX),),
513+
(4,),
514+
),
515+
(
516+
(np.array([True, False, True, True]),),
517+
True,
518+
categorical,
519+
(np.array([0.0, 0.0, 1.0], dtype=config.floatX),),
520+
(4,),
521+
),
522+
(
523+
(np.array([True, False, True]),),
524+
True,
525+
categorical,
526+
(
527+
np.array(
528+
[[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]],
529+
dtype=config.floatX,
530+
),
531+
),
532+
(),
533+
),
534+
(
535+
(
536+
slice(None),
537+
np.array([True, False, True]),
538+
),
539+
True,
540+
categorical,
541+
(
542+
np.array(
543+
[[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]],
544+
dtype=config.floatX,
545+
),
546+
),
547+
(4, 3),
548+
),
549+
# Boolean indexing where output is empty
550+
(
551+
(np.array([False, False]),),
552+
True,
553+
normal,
554+
(np.array([[1.0, 0.0, 0.0]], dtype=config.floatX),),
555+
(2, 3),
556+
),
499557
(
558+
(np.array([False, False]),),
559+
True,
560+
categorical,
561+
(
562+
np.array(
563+
[[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]],
564+
dtype=config.floatX,
565+
),
566+
),
567+
(2, 3),
568+
),
569+
# Multivariate cases, indexing only supported if it does not affect core dimensions
570+
(
571+
# Indexing dips into core dimension
500572
(np.array([1]), 0),
501573
False,
502574
multivariate_normal,
@@ -506,13 +578,30 @@ def test_DimShuffle_lift(ds_order, lifted, dist_op, dist_params, size, rtol):
506578
),
507579
(),
508580
),
509-
# Only one distribution parameter
510581
(
511-
(0,),
582+
(np.array([0, 2]),),
512583
True,
513-
poisson,
514-
(np.array([[1, 2], [3, 4]], dtype=config.floatX),),
515-
(3, 2, 2),
584+
multivariate_normal,
585+
(
586+
np.array(
587+
[[-100, -125, -150], [0, 0, 0], [200, 225, 250]],
588+
dtype=config.floatX,
589+
),
590+
np.eye(3, dtype=config.floatX) * 1e-6,
591+
),
592+
(),
593+
),
594+
(
595+
(np.array([True, False, True]), slice(None)),
596+
True,
597+
multivariate_normal,
598+
(
599+
np.array([200, 250], dtype=config.floatX),
600+
# Second covariance is invalid, to test it is not chosen
601+
np.dstack([np.eye(2), np.eye(2) * 0, np.eye(2)]).T.astype(config.floatX)
602+
* 1e-6,
603+
),
604+
(3,),
516605
),
517606
],
518607
)

0 commit comments

Comments
 (0)