@@ -24,6 +24,7 @@
     ScalarFromTensor,
     TensorFromScalar,
     alloc,
+    arange,
     cast,
     concatenate,
     expand_dims,
@@ -34,9 +35,10 @@
     switch,
 )
 from pytensor.tensor.basic import constant as tensor_constant
-from pytensor.tensor.blockwise import Blockwise
+from pytensor.tensor.blockwise import Blockwise, _squeeze_left
 from pytensor.tensor.elemwise import Elemwise
 from pytensor.tensor.exceptions import NotScalarConstantError
+from pytensor.tensor.extra_ops import broadcast_to
 from pytensor.tensor.math import (
     add,
     and_,
@@ -58,6 +60,7 @@
 )
 from pytensor.tensor.shape import (
     shape_padleft,
+    shape_padright,
     shape_tuple,
 )
 from pytensor.tensor.sharedvar import TensorSharedVariable
@@ -1578,6 +1581,9 @@ def local_blockwise_of_subtensor(fgraph, node):
     """Rewrite Blockwise of Subtensor, where the only batch input is the indexed tensor.

     Blockwise(Subtensor{a: b})(x, a, b) -> x[:, a:b] when x has one batch dimension, and a/b none
+
+    TODO: Handle batched indices like we do with blockwise of inc_subtensor
+    TODO: Extend to AdvancedSubtensor
     """
     if not isinstance(node.op.core_op, Subtensor):
         return
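At the NumPy level, the pattern this docstring targets looks like the following minimal sketch (shapes and names are illustrative, not taken from the test suite):

import numpy as np

x = np.arange(24).reshape(2, 3, 4)  # one batch dimension of size 2
a, b = 1, 3                          # slice bounds with no batch dimensions

# Applying the core Subtensor to each batch entry of x ...
looped = np.stack([x_i[a:b] for x_i in x])
# ... is the same as indexing the batched tensor with a prepended full slice
assert (looped == x[:, a:b]).all()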
@@ -1598,64 +1604,151 @@ def local_blockwise_of_subtensor(fgraph, node):
 @register_stabilize("shape_unsafe")
 @register_specialize("shape_unsafe")
 @node_rewriter([Blockwise])
-def local_blockwise_advanced_inc_subtensor(fgraph, node):
-    """Rewrite blockwise advanced inc_subtensor whithout batched indexes as an inc_subtensor with prepended empty slices."""
-    if not isinstance(node.op.core_op, AdvancedIncSubtensor):
-        return None
+def local_blockwise_inc_subtensor(fgraph, node):
+    """Rewrite Blockwise of inc_subtensor as an inc_subtensor on the batched inputs.

-    x, y, *idxs = node.inputs
+    Note: the reason we don't apply this rewrite eagerly in the `vectorize_node` dispatch
+    is that we often have batch dimensions from alloc of shapes/reshape that can be removed by rewrites,

-    # It is currently not possible to Vectorize such AdvancedIncSubtensor, but we check again just in case
-    if any(
-        (
-            isinstance(idx, SliceType | NoneTypeT)
-            or (idx.type.dtype == "bool" and idx.type.ndim > 0)
-        )
-        for idx in idxs
-    ):
+    such as x[:vectorized(w.shape[0])].set(y), which will later be rewritten as x[:w.shape[1]].set(y)
+    and can then be safely rewritten without Blockwise.
+    """
+    core_op = node.op.core_op
+    if not isinstance(core_op, AdvancedIncSubtensor | IncSubtensor):
         return None

-    op: Blockwise = node.op  # type: ignore
-    batch_ndim = op.batch_ndim(node)
-
-    new_idxs = []
-    for idx in idxs:
-        if all(idx.type.broadcastable[:batch_ndim]):
-            new_idxs.append(idx.squeeze(tuple(range(batch_ndim))))
-        else:
-            # Rewrite does not apply
+    x, y, *idxs = node.inputs
+    [out] = node.outputs
+    if isinstance(node.op.core_op, AdvancedIncSubtensor):
+        if any(
+            (
+                # Blockwise requires all inputs to be tensors so it is not possible
+                # to wrap an AdvancedIncSubtensor with slice / newaxis inputs, but we check again just in case
+                # If this is ever supported we need to pay attention to special behavior of numpy when advanced indices
+                # are separated by basic indices
+                isinstance(idx, SliceType | NoneTypeT)
+                # Also get out if we have boolean indices as they cross dimension boundaries
+                # / can't be safely broadcasted depending on their runtime content
+                or (idx.type.dtype == "bool")
+            )
+            for idx in idxs
+        ):
             return None

-    x_batch_bcast = x.type.broadcastable[:batch_ndim]
-    y_batch_bcast = y.type.broadcastable[:batch_ndim]
-    if any(xb and not yb for xb, yb in zip(x_batch_bcast, y_batch_bcast, strict=True)):
-        # Need to broadcast batch x dims
-        batch_shape = tuple(
-            x_dim if (not xb or yb) else y_dim
-            for xb, x_dim, yb, y_dim in zip(
-                x_batch_bcast,
+    batch_ndim = node.op.batch_ndim(node)
+    idxs_core_ndim = [len(inp_sig) for inp_sig in node.op.inputs_sig[2:]]
+    max_idx_core_ndim = max(idxs_core_ndim, default=0)
+
+    # Step 1. Broadcast buffer to batch_shape
+    if x.type.broadcastable != out.type.broadcastable:
+        batch_shape = [1] * batch_ndim
+        for inp in node.inputs:
+            for i, (broadcastable, batch_dim) in enumerate(
+                zip(inp.type.broadcastable[:batch_ndim], tuple(inp.shape)[:batch_ndim])
+            ):
+                if broadcastable:
+                    # This dimension is broadcastable, it doesn't provide shape information
+                    continue
+                if batch_shape[i] != 1:
+                    # We already found a source of shape for this batch dimension
+                    continue
+                batch_shape[i] = batch_dim
+        x = broadcast_to(x, (*batch_shape, *x.shape[batch_ndim:]))
+        assert x.type.broadcastable == out.type.broadcastable
+
+    # Step 2. Massage indices so they respect blockwise semantics
+    if isinstance(core_op, IncSubtensor):
+        # For basic IncSubtensor there are two cases:
+        # 1. Slice entries -> We need to squeeze away dummy dimensions so we can convert back to slice
+        # 2. Integers -> Can be used as is, but we try to squeeze away dummy batch dimensions
+        #    in case we can end up with a basic IncSubtensor again
+        core_idxs = []
+        counter = 0
+        for idx in core_op.idx_list:
+            if isinstance(idx, slice):
+                # Squeeze away dummy dimensions so we can convert to slice
+                new_entries = [None, None, None]
+                for i, entry in enumerate((idx.start, idx.stop, idx.step)):
+                    if entry is None:
+                        continue
+                    else:
+                        new_entries[i] = new_entry = idxs[counter].squeeze()
+                        counter += 1
+                        if new_entry.ndim > 0:
+                            # If the slice entry has dimensions after the squeeze we can't convert it to a slice
+                            # We could try to convert to equivalent integer indices, but nothing guarantees
+                            # that the slice is "square".
+                            return None
+                core_idxs.append(slice(*new_entries))
+            else:
+                core_idxs.append(_squeeze_left(idxs[counter]))
+                counter += 1
+    else:
+        # For AdvancedIncSubtensor we have tensor integer indices.
+        # We need to expand batch indices on the right, so they don't interact with core index dimensions.
+        # We still squeeze on the left in case that allows us to use simpler indices
+        core_idxs = [
+            _squeeze_left(
+                shape_padright(idx, max_idx_core_ndim - idx_core_ndim),
+                stop_at_dim=batch_ndim,
+            )
+            for idx, idx_core_ndim in zip(idxs, idxs_core_ndim)
+        ]
+
+    # Step 3. Create new indices for the new batch dimension of x
+    if not all(
+        all(idx.type.broadcastable[:batch_ndim])
+        for idx in idxs
+        if not isinstance(idx, slice)
+    ):
+        # If the indices have batch dimensions, they will interact with the new batch dimensions of x
+        # We build vectorized indexing with new arange indices that do not interact with core indices or each other
+        # (i.e., they broadcast)
+
+        # Note: due to how numpy handles non-consecutive advanced indexing (transposing it to the front),
+        # we don't want to create a mix of slice(None), and arange() indices for the new batch dimension,
+        # even if not all batch dimensions have corresponding batch indices.
+        batch_slices = [
+            shape_padright(arange(x_batch_shape, dtype="int64"), n)
+            for (x_batch_shape, n) in zip(
                 tuple(x.shape)[:batch_ndim],
-                y_batch_bcast,
-                tuple(y.shape)[:batch_ndim],
-                strict=True,
+                reversed(range(max_idx_core_ndim, max_idx_core_ndim + batch_ndim)),
             )
-        )
-        core_shape = tuple(x.shape)[batch_ndim:]
-        x = alloc(x, *batch_shape, *core_shape)
-
-    new_idxs = [slice(None)] * batch_ndim + new_idxs
-    x_view = x[tuple(new_idxs)]
-
-    # We need to introduce any implicit expand_dims on core dimension of y
-    y_core_ndim = y.type.ndim - batch_ndim
-    if (missing_y_core_ndim := x_view.type.ndim - batch_ndim - y_core_ndim) > 0:
-        missing_axes = tuple(range(batch_ndim, batch_ndim + missing_y_core_ndim))
-        y = expand_dims(y, missing_axes)
-
-    symbolic_idxs = x_view.owner.inputs[1:]
-    new_out = op.core_op.make_node(x, y, *symbolic_idxs).outputs
-    copy_stack_trace(node.outputs, new_out)
-    return new_out
+        ]
+    else:
+        # In the case we don't have batch indices,
+        # we can use slice(None) to broadcast the core indices to each new batch dimension of x / y
+        batch_slices = [slice(None)] * batch_ndim
+
+    new_idxs = (*batch_slices, *core_idxs)
+    x_view = x[new_idxs]
+
+    # Step 4. Introduce any implicit expand_dims on core dimension of y
+    missing_y_core_ndim = x_view.type.ndim - y.type.ndim
+    implicit_axes = tuple(range(batch_ndim, batch_ndim + missing_y_core_ndim))
+    y = _squeeze_left(expand_dims(y, implicit_axes), stop_at_dim=batch_ndim)
+
+    if isinstance(core_op, IncSubtensor):
+        # Check if we can still use a basic IncSubtensor
+        if isinstance(x_view.owner.op, Subtensor):
+            new_props = core_op._props_dict()
+            new_props["idx_list"] = x_view.owner.op.idx_list
+            new_core_op = type(core_op)(**new_props)
+            symbolic_idxs = x_view.owner.inputs[1:]
+            new_out = new_core_op(x, y, *symbolic_idxs)
+        else:
+            # We need to use AdvancedSet/IncSubtensor
+            if core_op.set_instead_of_inc:
+                new_out = x[new_idxs].set(y)
+            else:
+                new_out = x[new_idxs].inc(y)
+    else:
+        # AdvancedIncSubtensor takes symbolic indices/slices directly, no need to create a new op
+        symbolic_idxs = x_view.owner.inputs[1:]
+        new_out = core_op(x, y, *symbolic_idxs)
+
+    copy_stack_trace(out, new_out)
+    return [new_out]


 @node_rewriter(tracks=[AdvancedSubtensor, AdvancedIncSubtensor])
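For the batched-indices path (Step 3 above), the padded arange indices reproduce the per-batch update without a Blockwise. A minimal NumPy sketch of that behavior, with illustrative shapes not taken from the PR's tests:

import numpy as np

B, N, K = 2, 5, 3
x = np.zeros((B, N))
idx = np.array([[0, 2, 4], [1, 2, 3]])            # batched integer indices, shape (B, K)
y = np.arange(B * K, dtype=float).reshape(B, K)   # batched update values

# Reference semantics: one set_subtensor per batch entry
looped = x.copy()
for i in range(B):
    looped[i, idx[i]] = y[i]

# The rewrite instead indexes the batch dimension with an arange padded on the right,
# so it broadcasts against the core index dimensions rather than mixing with them
batch_idx = np.arange(B)[:, None]                 # shape (B, 1)
direct = x.copy()
direct[batch_idx, idx] = y
assert (looped == direct).all()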
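Why boolean indices are rejected outright in the AdvancedIncSubtensor branch: how many elements a mask selects depends on its runtime content, so different batch entries need not share a core shape. A short NumPy illustration with hypothetical mask values:

import numpy as np

x = np.zeros((2, 4))
mask = np.array([[True, False, True, False],
                 [True, True, True, False]])

print(x[0][mask[0]].shape)  # (2,)
print(x[1][mask[1]].shape)  # (3,) -> no common core shape, so the batched update cannot be expressed safely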