Skip to content

Commit 55d39ed

Browse files
committed
Move subtensor blockwise rewrite
1 parent 5f8ddb3 commit 55d39ed

File tree

2 files changed

+23
-24
lines changed

2 files changed

+23
-24
lines changed

pytensor/tensor/rewriting/blockwise.py

Lines changed: 0 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
AdvancedIncSubtensor,
1818
AdvancedSubtensor,
1919
Subtensor,
20-
indices_from_subtensor,
2120
)
2221

2322

@@ -229,29 +228,6 @@ def local_blockwise_reshape(fgraph, node):
229228
return [new_out]
230229

231230

232-
@register_stabilize
233-
@register_specialize
234-
@node_rewriter([Blockwise])
235-
def local_blockwise_of_subtensor(fgraph, node):
236-
"""Rewrite Blockwise of Subtensor, where the only batch input is the indexed tensor.
237-
238-
Blockwise(Subtensor{a: b})(x, a, b) -> x[:, a:b] when x has one batch dimension, and a/b none
239-
"""
240-
if not isinstance(node.op.core_op, Subtensor):
241-
return
242-
243-
x, *idxs = node.inputs
244-
if not all(all(idx.type.broadcastable) for idx in idxs):
245-
return
246-
247-
core_idxs = indices_from_subtensor(
248-
[idx.squeeze() for idx in idxs], node.op.core_op.idx_list
249-
)
250-
# Add empty slices for the batch dims
251-
none_slices = (slice(None),) * node.op.batch_ndim(node)
252-
return [x[(*none_slices, *core_idxs)]]
253-
254-
255231
class InplaceBlockwiseOptimizer(InplaceGraphOptimizer):
256232
op = Blockwise
257233

pytensor/tensor/rewriting/subtensor.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1573,6 +1573,29 @@ def local_uint_constant_indices(fgraph, node):
15731573
)
15741574

15751575

1576+
@register_stabilize
1577+
@register_specialize
1578+
@node_rewriter([Blockwise])
1579+
def local_blockwise_of_subtensor(fgraph, node):
1580+
"""Rewrite Blockwise of Subtensor, where the only batch input is the indexed tensor.
1581+
1582+
Blockwise(Subtensor{a: b})(x, a, b) -> x[:, a:b] when x has one batch dimension, and a/b none
1583+
"""
1584+
if not isinstance(node.op.core_op, Subtensor):
1585+
return
1586+
1587+
x, *idxs = node.inputs
1588+
if not all(all(idx.type.broadcastable) for idx in idxs):
1589+
return
1590+
1591+
core_idxs = indices_from_subtensor(
1592+
[idx.squeeze() for idx in idxs], node.op.core_op.idx_list
1593+
)
1594+
# Add empty slices for the batch dims
1595+
none_slices = (slice(None),) * node.op.batch_ndim(node)
1596+
return [x[(*none_slices, *core_idxs)]]
1597+
1598+
15761599
@register_canonicalize("shape_unsafe")
15771600
@register_stabilize("shape_unsafe")
15781601
@register_specialize("shape_unsafe")

0 commit comments

Comments
 (0)