
Commit f9a3234

Rewrite Blockwise IncSubtensor
Also cover cases of AdvancedIncSubtensor with batch indices that were not supported before
1 parent 5046519 commit f9a3234
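A sketch of the newly covered case (illustrative, not part of the commit; it assumes PyTensor's `vectorize_graph` helper and the `.inc()` method that the diff below also uses):

import pytensor.tensor as pt
from pytensor.graph.replace import vectorize_graph

x = pt.vector("x")         # core buffer
idx = pt.lvector("idx")    # core integer index
y = pt.vector("y")         # core update values
out = x[idx].inc(y)        # AdvancedIncSubtensor: x[idx] += y

# Giving the index a batch dimension was previously unsupported by the rewrite;
# the Blockwise node this creates is now turned back into plain advanced indexing.
batch_idx = pt.lmatrix("batch_idx")
vectorized_out = vectorize_graph(out, replace={idx: batch_idx})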

3 files changed: +333 -138 lines changed

pytensor/tensor/rewriting/subtensor.py

Lines changed: 144 additions & 51 deletions
@@ -24,6 +24,7 @@
     ScalarFromTensor,
     TensorFromScalar,
     alloc,
+    arange,
     cast,
     concatenate,
     expand_dims,
@@ -34,9 +35,10 @@
     switch,
 )
 from pytensor.tensor.basic import constant as tensor_constant
-from pytensor.tensor.blockwise import Blockwise
+from pytensor.tensor.blockwise import Blockwise, _squeeze_left
 from pytensor.tensor.elemwise import Elemwise
 from pytensor.tensor.exceptions import NotScalarConstantError
+from pytensor.tensor.extra_ops import broadcast_to
 from pytensor.tensor.math import (
     add,
     and_,
@@ -58,6 +60,7 @@
 )
 from pytensor.tensor.shape import (
     shape_padleft,
+    shape_padright,
     shape_tuple,
 )
 from pytensor.tensor.sharedvar import TensorSharedVariable
@@ -1578,6 +1581,9 @@ def local_blockwise_of_subtensor(fgraph, node):
     """Rewrite Blockwise of Subtensor, where the only batch input is the indexed tensor.
 
     Blockwise(Subtensor{a: b})(x, a, b) -> x[:, a:b] when x has one batch dimension, and a/b none
+
+    TODO: Handle batched indices like we do with blockwise of inc_subtensor
+    TODO: Extend to AdvancedSubtensor
     """
     if not isinstance(node.op.core_op, Subtensor):
         return
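A minimal NumPy sketch (illustrative, not from the codebase) of the transformation this docstring describes:

import numpy as np

X = np.arange(24).reshape(2, 3, 4)  # one batch dimension over (3, 4) cores
a, b = 1, 3

looped = np.stack([core_x[a:b] for core_x in X])  # per-core Subtensor{a:b}
rewritten = X[:, a:b]                             # prepended empty slice

assert np.array_equal(looped, rewritten)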
@@ -1598,64 +1604,151 @@
 @register_stabilize("shape_unsafe")
 @register_specialize("shape_unsafe")
 @node_rewriter([Blockwise])
-def local_blockwise_advanced_inc_subtensor(fgraph, node):
-    """Rewrite blockwise advanced inc_subtensor whithout batched indexes as an inc_subtensor with prepended empty slices."""
-    if not isinstance(node.op.core_op, AdvancedIncSubtensor):
-        return None
+def local_blockwise_inc_subtensor(fgraph, node):
+    """Rewrite blockwised inc_subtensors.
 
-    x, y, *idxs = node.inputs
+    Note: The reason we don't apply this rewrite eagerly in the `vectorize_node` dispatch
+    is that we often have batch dimensions from alloc of shapes/reshape that can be removed by rewrites,
 
-    # It is currently not possible to Vectorize such AdvancedIncSubtensor, but we check again just in case
-    if any(
-        (
-            isinstance(idx, SliceType | NoneTypeT)
-            or (idx.type.dtype == "bool" and idx.type.ndim > 0)
-        )
-        for idx in idxs
-    ):
+    such as x[:vectorized(w.shape[0])].set(y), that will later be rewritten as x[:w.shape[1]].set(y),
+    and can be safely rewritten without Blockwise.
+    """
+    core_op = node.op.core_op
+    if not isinstance(core_op, AdvancedIncSubtensor | IncSubtensor):
         return None
 
-    op: Blockwise = node.op  # type: ignore
-    batch_ndim = op.batch_ndim(node)
-
-    new_idxs = []
-    for idx in idxs:
-        if all(idx.type.broadcastable[:batch_ndim]):
-            new_idxs.append(idx.squeeze(tuple(range(batch_ndim))))
-        else:
-            # Rewrite does not apply
+    x, y, *idxs = node.inputs
+    [out] = node.outputs
+    if isinstance(node.op.core_op, AdvancedIncSubtensor):
+        if any(
+            (
+                # Blockwise requires all inputs to be tensors so it is not possible
+                # to wrap an AdvancedIncSubtensor with slice / newaxis inputs, but we check again just in case.
+                # If this is ever supported we need to pay attention to special behavior of numpy when advanced indices
+                # are separated by basic indices
                isinstance(idx, SliceType | NoneTypeT)
+                # Also get out if we have boolean indices, as they cross dimension boundaries
+                # / can't be safely broadcasted depending on their runtime content
+                or (idx.type.dtype == "bool")
+            )
+            for idx in idxs
+        ):
             return None
 
-    x_batch_bcast = x.type.broadcastable[:batch_ndim]
-    y_batch_bcast = y.type.broadcastable[:batch_ndim]
-    if any(xb and not yb for xb, yb in zip(x_batch_bcast, y_batch_bcast, strict=True)):
-        # Need to broadcast batch x dims
-        batch_shape = tuple(
-            x_dim if (not xb or yb) else y_dim
-            for xb, x_dim, yb, y_dim in zip(
-                x_batch_bcast,
+    batch_ndim = node.op.batch_ndim(node)
+    idxs_core_ndim = [len(inp_sig) for inp_sig in node.op.inputs_sig[2:]]
+    max_idx_core_ndim = max(idxs_core_ndim, default=0)
+
+    # Step 1. Broadcast buffer to batch_shape
+    if x.type.broadcastable != out.type.broadcastable:
+        batch_shape = [1] * batch_ndim
+        for inp in node.inputs:
+            for i, (broadcastable, batch_dim) in enumerate(
+                zip(inp.type.broadcastable[:batch_ndim], tuple(inp.shape)[:batch_ndim])
+            ):
+                if broadcastable:
+                    # This dimension is broadcastable, it doesn't provide shape information
+                    continue
+                if batch_shape[i] != 1:
+                    # We already found a source of shape for this batch dimension
+                    continue
+                batch_shape[i] = batch_dim
+        x = broadcast_to(x, (*batch_shape, *x.shape[batch_ndim:]))
+        assert x.type.broadcastable == out.type.broadcastable
+
+    # Step 2. Massage indices so they respect blockwise semantics
+    if isinstance(core_op, IncSubtensor):
+        # For basic IncSubtensor there are two cases:
+        # 1. Slice entries -> We need to squeeze away dummy dimensions so we can convert back to slice
+        # 2. Integers -> Can be used as is, but we try to squeeze away dummy batch dimensions
+        #    in case we can end up with a basic IncSubtensor again
+        core_idxs = []
+        counter = 0
+        for idx in core_op.idx_list:
+            if isinstance(idx, slice):
+                # Squeeze away dummy dimensions so we can convert to slice
+                new_entries = [None, None, None]
+                for i, entry in enumerate((idx.start, idx.stop, idx.step)):
+                    if entry is None:
+                        continue
+                    else:
+                        new_entries[i] = new_entry = idxs[counter].squeeze()
+                        counter += 1
+                        if new_entry.ndim > 0:
+                            # If the slice entry has dimensions after the squeeze we can't convert it to a slice.
+                            # We could try to convert to equivalent integer indices, but nothing guarantees
+                            # that the slice is "square".
+                            return None
+                core_idxs.append(slice(*new_entries))
+            else:
+                core_idxs.append(_squeeze_left(idxs[counter]))
+                counter += 1
+    else:
+        # For AdvancedIncSubtensor we have tensor integer indices.
+        # We need to expand batch indexes on the right, so they don't interact with core index dimensions.
+        # We still squeeze on the left in case that allows us to use simpler indices
+        core_idxs = [
+            _squeeze_left(
+                shape_padright(idx, max_idx_core_ndim - idx_core_ndim),
+                stop_at_dim=batch_ndim,
+            )
+            for idx, idx_core_ndim in zip(idxs, idxs_core_ndim)
+        ]
+
+    # Step 3. Create new indices for the new batch dimension of x
+    if not all(
+        all(idx.type.broadcastable[:batch_ndim])
+        for idx in idxs
+        if not isinstance(idx, slice)
+    ):
+        # If the indices have batch dimensions, they will interact with the new dimensions of x.
+        # We build vectorized indexing with new arange indices that do not interact with core indices or each other
+        # (i.e., they broadcast)
+
+        # Note: due to how numpy handles non-consecutive advanced indexing (transposing it to the front),
+        # we don't want to create a mix of slice(None) and arange() indices for the new batch dimension,
+        # even if not all batch dimensions have corresponding batch indices.
+        batch_slices = [
+            shape_padright(arange(x_batch_shape, dtype="int64"), n)
+            for (x_batch_shape, n) in zip(
                 tuple(x.shape)[:batch_ndim],
-                y_batch_bcast,
-                tuple(y.shape)[:batch_ndim],
-                strict=True,
+                reversed(range(max_idx_core_ndim, max_idx_core_ndim + batch_ndim)),
             )
-        )
-        core_shape = tuple(x.shape)[batch_ndim:]
-        x = alloc(x, *batch_shape, *core_shape)
-
-    new_idxs = [slice(None)] * batch_ndim + new_idxs
-    x_view = x[tuple(new_idxs)]
-
-    # We need to introduce any implicit expand_dims on core dimension of y
-    y_core_ndim = y.type.ndim - batch_ndim
-    if (missing_y_core_ndim := x_view.type.ndim - batch_ndim - y_core_ndim) > 0:
-        missing_axes = tuple(range(batch_ndim, batch_ndim + missing_y_core_ndim))
-        y = expand_dims(y, missing_axes)
-
-    symbolic_idxs = x_view.owner.inputs[1:]
-    new_out = op.core_op.make_node(x, y, *symbolic_idxs).outputs
-    copy_stack_trace(node.outputs, new_out)
-    return new_out
+        ]
+    else:
+        # In the case we don't have batch indices,
+        # we can use slice(None) to broadcast the core indices to each new batch dimension of x / y
+        batch_slices = [slice(None)] * batch_ndim
+
+    new_idxs = (*batch_slices, *core_idxs)
+    x_view = x[new_idxs]
+
+    # Step 4. Introduce any implicit expand_dims on core dimension of y
+    missing_y_core_ndim = x_view.type.ndim - y.type.ndim
+    implicit_axes = tuple(range(batch_ndim, batch_ndim + missing_y_core_ndim))
+    y = _squeeze_left(expand_dims(y, implicit_axes), stop_at_dim=batch_ndim)
+
+    if isinstance(core_op, IncSubtensor):
+        # Check if we can still use a basic IncSubtensor
+        if isinstance(x_view.owner.op, Subtensor):
+            new_props = core_op._props_dict()
+            new_props["idx_list"] = x_view.owner.op.idx_list
+            new_core_op = type(core_op)(**new_props)
+            symbolic_idxs = x_view.owner.inputs[1:]
+            new_out = new_core_op(x, y, *symbolic_idxs)
+        else:
+            # We need to use AdvancedSet/IncSubtensor
+            if core_op.set_instead_of_inc:
+                new_out = x[new_idxs].set(y)
+            else:
+                new_out = x[new_idxs].inc(y)
+    else:
+        # AdvancedIncSubtensor takes symbolic indices/slices directly, no need to create a new op
+        symbolic_idxs = x_view.owner.inputs[1:]
+        new_out = core_op(x, y, *symbolic_idxs)
+
+    copy_stack_trace(out, new_out)
+    return [new_out]
 
 
 @node_rewriter(tracks=[AdvancedSubtensor, AdvancedIncSubtensor])
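To make Step 3 above concrete, here is a NumPy sketch (illustrative names, not from the commit) of why right-padded arange indices over the batch dimension reproduce the per-batch loop semantics of Blockwise:

import numpy as np

batch, n, k = 4, 6, 3
x = np.zeros((batch, n))
idx = np.random.randint(0, n, size=(batch, k))  # batched integer index
y = np.ones((batch, k))

# Loop semantics, i.e. what Blockwise(AdvancedIncSubtensor) computes:
expected = x.copy()
for b in range(batch):
    np.add.at(expected[b], idx[b], y[b])

# Rewrite semantics: an arange over the batch dimension, padded on the right
# so it broadcasts against the core index instead of interacting with it
# (cf. shape_padright in the rewrite).
batch_idx = np.arange(batch)[:, None]  # shape (batch, 1)
result = x.copy()
np.add.at(result, (batch_idx, idx), y)

assert np.allclose(result, expected)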

pytensor/tensor/subtensor.py

Lines changed: 0 additions & 1 deletion
@@ -1417,7 +1417,6 @@ def _process(self, idxs, op_inputs, pstate):
 pprint.assign(Subtensor, SubtensorPrinter())
 
 
-# TODO: Implement similar vectorize for Inc/SetSubtensor
 @_vectorize_node.register(Subtensor)
 def vectorize_subtensor(op: Subtensor, node, batch_x, *batch_idxs):
     """Rewrite subtensor with non-batched indexes as another Subtensor with prepended empty slices."""
