
Commit d5ef9bd

Rewrite Blockwise IncSubtensor
Also cover cases of AdvancedIncSubtensor with batch indices that were not supported before
1 parent 55d39ed commit d5ef9bd
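
A minimal sketch of the headline case, using pytensor's `vectorize_graph` helper and the `.set()` indexing method (names and shapes here are illustrative, not from the commit):

```python
import pytensor.tensor as pt
from pytensor.graph.replace import vectorize_graph

# Core graph: a basic IncSubtensor (set into a slice of x)
x = pt.tensor("x", shape=(7,))
stop = pt.scalar("stop", dtype="int64")
out = x[:stop].set(pt.zeros((stop,)))

# Batching the buffer wraps the op in Blockwise(IncSubtensor); the rewritten
# pass can lower this back to a plain xb[:, :stop].set(...) without Blockwise.
xb = pt.tensor("xb", shape=(5, 7))
vec_out = vectorize_graph(out, replace={x: xb})
```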

File tree

2 files changed: +329 -137 lines changed

pytensor/tensor/rewriting/subtensor.py

Lines changed: 142 additions & 51 deletions
```diff
@@ -24,6 +24,7 @@
     ScalarFromTensor,
     TensorFromScalar,
     alloc,
+    arange,
     cast,
     concatenate,
     expand_dims,
@@ -34,9 +35,10 @@
     switch,
 )
 from pytensor.tensor.basic import constant as tensor_constant
-from pytensor.tensor.blockwise import Blockwise
+from pytensor.tensor.blockwise import Blockwise, _squeeze_left
 from pytensor.tensor.elemwise import Elemwise
 from pytensor.tensor.exceptions import NotScalarConstantError
+from pytensor.tensor.extra_ops import broadcast_to
 from pytensor.tensor.math import (
     add,
     and_,
@@ -58,6 +60,7 @@
 )
 from pytensor.tensor.shape import (
     shape_padleft,
+    shape_padright,
     shape_tuple,
 )
 from pytensor.tensor.sharedvar import TensorSharedVariable
```
```diff
@@ -1580,6 +1583,8 @@ def local_blockwise_of_subtensor(fgraph, node):
     """Rewrite Blockwise of Subtensor, where the only batch input is the indexed tensor.
 
     Blockwise(Subtensor{a: b})(x, a, b) -> x[:, a:b] when x has one batch dimension, and a/b none
+
+    TODO: Handle batched indices like we do with blockwise of inc_subtensor
     """
     if not isinstance(node.op.core_op, Subtensor):
         return
```
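
For intuition, a small numpy analogue of the pattern this rewrite targets (purely illustrative):

```python
import numpy as np

# One batch dimension on x; scalar (non-batched) slice bounds a and b.
x = np.arange(12).reshape(3, 4)
a, b = 1, 3

looped = np.stack([row[a:b] for row in x])  # what Blockwise(Subtensor) computes
direct = x[:, a:b]                          # what the rewrite produces instead
assert (looped == direct).all()
```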
```diff
@@ -1600,64 +1605,150 @@
 @register_stabilize("shape_unsafe")
 @register_specialize("shape_unsafe")
 @node_rewriter([Blockwise])
-def local_blockwise_advanced_inc_subtensor(fgraph, node):
-    """Rewrite blockwise advanced inc_subtensor without batched indexes as an inc_subtensor with prepended empty slices."""
-    if not isinstance(node.op.core_op, AdvancedIncSubtensor):
-        return None
+def local_blockwise_inc_subtensor(fgraph, node):
+    """Rewrite blockwise inc_subtensors.
 
-    x, y, *idxs = node.inputs
+    Note: The reason we don't apply this rewrite eagerly in the `vectorize_node` dispatch
+    is that we often have batch dimensions from alloc of shapes/reshape that can be removed by rewrites
 
-    # It is currently not possible to Vectorize such AdvancedIncSubtensor, but we check again just in case
-    if any(
-        (
-            isinstance(idx, SliceType | NoneTypeT)
-            or (idx.type.dtype == "bool" and idx.type.ndim > 0)
-        )
-        for idx in idxs
-    ):
+    such as x[:vectorized(w.shape[0])].set(y), that will later be rewritten as x[:w.shape[1]].set(y),
+    and can be safely rewritten without Blockwise.
+    """
+    core_op = node.op.core_op
+    if not isinstance(core_op, AdvancedIncSubtensor | IncSubtensor):
         return None
 
-    op: Blockwise = node.op  # type: ignore
-    batch_ndim = op.batch_ndim(node)
-
-    new_idxs = []
-    for idx in idxs:
-        if all(idx.type.broadcastable[:batch_ndim]):
-            new_idxs.append(idx.squeeze(tuple(range(batch_ndim))))
-        else:
-            # Rewrite does not apply
+    x, y, *idxs = node.inputs
+    [out] = node.outputs
+    if isinstance(node.op.core_op, AdvancedIncSubtensor):
+        if any(
+            (
+                # Blockwise requires all inputs to be tensors, so it is not possible
+                # to wrap an AdvancedIncSubtensor with slice / newaxis inputs, but we check again just in case.
+                # If this is ever supported we need to pay attention to the special behavior of numpy
+                # when advanced indices are separated by basic indices
+                isinstance(idx, SliceType | NoneTypeT)
+                # Also get out if we have boolean indices, as they cross dimension boundaries
+                # / can't be safely broadcasted depending on their runtime content
+                or (idx.type.dtype == "bool")
+            )
+            for idx in idxs
+        ):
             return None
 
-    x_batch_bcast = x.type.broadcastable[:batch_ndim]
-    y_batch_bcast = y.type.broadcastable[:batch_ndim]
-    if any(xb and not yb for xb, yb in zip(x_batch_bcast, y_batch_bcast, strict=True)):
-        # Need to broadcast batch x dims
-        batch_shape = tuple(
-            x_dim if (not xb or yb) else y_dim
-            for xb, x_dim, yb, y_dim in zip(
-                x_batch_bcast,
+    batch_ndim = node.op.batch_ndim(node)
+    idxs_core_ndim = [len(inp_sig) for inp_sig in node.op.inputs_sig[2:]]
+    max_idx_core_ndim = max(idxs_core_ndim, default=0)
+
+    # Step 1. Broadcast buffer to batch_shape
+    if x.type.broadcastable != out.type.broadcastable:
+        batch_shape = [1] * batch_ndim
+        for inp in node.inputs:
+            for i, (broadcastable, batch_dim) in enumerate(
+                zip(inp.type.broadcastable[:batch_ndim], tuple(inp.shape)[:batch_ndim])
+            ):
+                if broadcastable:
+                    # This dimension is broadcastable, it doesn't provide shape information
+                    continue
+                if batch_shape[i] != 1:
+                    # We already found a source of shape for this batch dimension
+                    continue
+                batch_shape[i] = batch_dim
+        x = broadcast_to(x, (*batch_shape, *x.shape[batch_ndim:]))
+        assert x.type.broadcastable == out.type.broadcastable
+
+    # Step 2. Massage indices so they respect blockwise semantics
+    if isinstance(core_op, IncSubtensor):
+        # For basic IncSubtensor there are two cases:
+        # 1. Slice entries -> We need to squeeze away dummy dimensions so we can convert back to slice
+        # 2. Integers -> Can be used as is, but we try to squeeze away dummy batch dimensions
+        #    in case we can end up with a basic IncSubtensor again
+        core_idxs = []
+        counter = 0
+        for idx in core_op.idx_list:
+            if isinstance(idx, slice):
+                # Squeeze away dummy dimensions so we can convert to slice
+                new_entries = [None, None, None]
+                for i, entry in enumerate((idx.start, idx.stop, idx.step)):
+                    if entry is None:
+                        continue
+                    else:
+                        new_entries[i] = new_entry = idxs[counter].squeeze()
+                        counter += 1
+                        if new_entry.ndim > 0:
+                            # If the slice entry has dimensions after the squeeze we can't convert it to a slice
+                            # We could try to convert to equivalent integer indices, but nothing guarantees
+                            # that the slice is "square".
+                            return None
+                core_idxs.append(slice(*new_entries))
+            else:
+                core_idxs.append(_squeeze_left(idxs[counter]))
+                counter += 1
+    else:
+        # For AdvancedIncSubtensor we have tensor integer indices.
+        # We need to expand batch indices on the right, so they don't interact with core index dimensions.
+        # We still squeeze on the left in case that allows us to use simpler indices
+        core_idxs = [
+            _squeeze_left(
+                shape_padright(idx, max_idx_core_ndim - idx_core_ndim),
+                stop_at_dim=batch_ndim,
+            )
+            for idx, idx_core_ndim in zip(idxs, idxs_core_ndim)
+        ]
+
+    # Step 3. Create new indices for the new batch dimension of x
+    if not all(
+        all(idx.type.broadcastable[:batch_ndim])
+        for idx in idxs
+        if not isinstance(idx, slice)
+    ):
+        # If the indices have batch dimensions, they will interact with the new dimensions of x.
+        # We build vectorized indexing with new arange indices that do not interact with core indices or each other
+        # (i.e., they broadcast)
+
+        # Note: due to how numpy handles non-consecutive advanced indexing (transposing it to the front),
+        # we don't want to create a mix of slice(None) and arange() indices for the new batch dimensions,
+        # even if not all batch dimensions have corresponding batch indices.
+        batch_slices = [
+            shape_padright(arange(x_batch_shape, dtype="int64"), n)
+            for (x_batch_shape, n) in zip(
                 tuple(x.shape)[:batch_ndim],
-                y_batch_bcast,
-                tuple(y.shape)[:batch_ndim],
-                strict=True,
+                reversed(range(max_idx_core_ndim, max_idx_core_ndim + batch_ndim)),
             )
-        )
-        core_shape = tuple(x.shape)[batch_ndim:]
-        x = alloc(x, *batch_shape, *core_shape)
-
-    new_idxs = [slice(None)] * batch_ndim + new_idxs
-    x_view = x[tuple(new_idxs)]
-
-    # We need to introduce any implicit expand_dims on core dimension of y
-    y_core_ndim = y.type.ndim - batch_ndim
-    if (missing_y_core_ndim := x_view.type.ndim - batch_ndim - y_core_ndim) > 0:
-        missing_axes = tuple(range(batch_ndim, batch_ndim + missing_y_core_ndim))
-        y = expand_dims(y, missing_axes)
-
-    symbolic_idxs = x_view.owner.inputs[1:]
-    new_out = op.core_op.make_node(x, y, *symbolic_idxs).outputs
-    copy_stack_trace(node.outputs, new_out)
-    return new_out
+        ]
+    else:
+        # In the case we don't have batched indices, plain slices select the batch dimensions
+        batch_slices = [slice(None)] * batch_ndim
+
+    new_idxs = (*batch_slices, *core_idxs)
+    x_view = x[new_idxs]
+
+    # Step 4. Introduce any implicit expand_dims on core dimension of y
+    missing_y_core_ndim = x_view.type.ndim - y.type.ndim
+    implicit_axes = tuple(range(batch_ndim, batch_ndim + missing_y_core_ndim))
+    y = _squeeze_left(expand_dims(y, implicit_axes), stop_at_dim=batch_ndim)
+
+    if isinstance(core_op, IncSubtensor):
+        # Check if we can still use a basic IncSubtensor
+        if isinstance(x_view.owner.op, Subtensor):
+            new_props = core_op._props_dict()
+            new_props["idx_list"] = x_view.owner.op.idx_list
+            new_core_op = type(core_op)(**new_props)
+            symbolic_idxs = x_view.owner.inputs[1:]
+            new_out = new_core_op(x, y, *symbolic_idxs)
+        else:
+            # We need to use AdvancedSet/IncSubtensor
+            if core_op.set_instead_of_inc:
+                new_out = x[new_idxs].set(y)
+            else:
+                new_out = x[new_idxs].inc(y)
+    else:
+        # AdvancedIncSubtensor takes symbolic indices/slices directly, no need to create a new op
+        symbolic_idxs = x_view.owner.inputs[1:]
+        new_out = core_op(x, y, *symbolic_idxs)
+
+    copy_stack_trace(out, new_out)
+    return [new_out]
 
 
 @node_rewriter(tracks=[AdvancedSubtensor, AdvancedIncSubtensor])
```
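
The core trick in Step 3 can be seen with plain numpy: batched integer indices are paired with right-padded `arange` indices over the batch dimensions, so every index broadcasts cleanly and no basic/advanced index mixing occurs. A minimal sketch (shapes are hypothetical):

```python
import numpy as np

batch, n, k = 4, 7, 3
rng = np.random.default_rng(0)
x = np.zeros((batch, n))
# Batched core index: one set of k positions per batch entry (no duplicates,
# so set semantics are unambiguous)
idx = np.stack([rng.choice(n, size=k, replace=False) for _ in range(batch)])
y = rng.normal(size=(batch, k))

# arange over the batch dimension, padded on the right by the core ndim (1),
# mirroring shape_padright(arange(...), n) in the rewrite
batch_idx = np.arange(batch)[:, None]  # shape (batch, 1), broadcasts with idx
x[batch_idx, idx] = y                  # single advanced set, no Blockwise loop

# Reference: explicit loop over the batch dimension
x_ref = np.zeros((batch, n))
for i in range(batch):
    x_ref[i, idx[i]] = y[i]
assert np.allclose(x, x_ref)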

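The newly covered AdvancedIncSubtensor case from the commit message can be reproduced at the pytensor level with a short sketch, again assuming `vectorize_graph` and the `.inc()` indexing method (names and shapes are illustrative):

```python
import pytensor.tensor as pt
from pytensor.graph.replace import vectorize_graph

x = pt.tensor("x", shape=(7,))
idx = pt.tensor("idx", shape=(3,), dtype="int64")
y = pt.tensor("y", shape=(3,))
out = x[idx].inc(y)  # AdvancedIncSubtensor

# Batching the *indices* used to leave a Blockwise(AdvancedIncSubtensor) behind;
# the rewrite now lowers it to a single advanced inc with arange batch indices.
batched_idx = pt.tensor("batched_idx", shape=(5, 3), dtype="int64")
batched_out = vectorize_graph(out, replace={idx: batched_idx})
```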