@@ -24,6 +24,7 @@
     ScalarFromTensor,
     TensorFromScalar,
     alloc,
+    arange,
     cast,
     concatenate,
     expand_dims,
@@ -34,9 +35,10 @@
     switch,
 )
 from pytensor.tensor.basic import constant as tensor_constant
-from pytensor.tensor.blockwise import Blockwise
+from pytensor.tensor.blockwise import Blockwise, _squeeze_left
 from pytensor.tensor.elemwise import Elemwise
 from pytensor.tensor.exceptions import NotScalarConstantError
+from pytensor.tensor.extra_ops import broadcast_to
 from pytensor.tensor.math import (
     add,
     and_,
@@ -58,6 +60,7 @@
 )
 from pytensor.tensor.shape import (
     shape_padleft,
+    shape_padright,
     shape_tuple,
 )
 from pytensor.tensor.sharedvar import TensorSharedVariable
@@ -1580,6 +1583,8 @@ def local_blockwise_of_subtensor(fgraph, node):
     """Rewrite Blockwise of Subtensor, where the only batch input is the indexed tensor.

     Blockwise(Subtensor{a: b})(x, a, b) -> x[:, a:b] when x has one batch dimension, and a/b none
+
+    TODO: Handle batched indices like we do with blockwise of inc_subtensor
     """
     if not isinstance(node.op.core_op, Subtensor):
         return
@@ -1600,64 +1605,150 @@ def local_blockwise_of_subtensor(fgraph, node):
 @register_stabilize("shape_unsafe")
 @register_specialize("shape_unsafe")
 @node_rewriter([Blockwise])
-def local_blockwise_advanced_inc_subtensor(fgraph, node):
-    """Rewrite blockwise advanced inc_subtensor whithout batched indexes as an inc_subtensor with prepended empty slices."""
-    if not isinstance(node.op.core_op, AdvancedIncSubtensor):
-        return None
+def local_blockwise_inc_subtensor(fgraph, node):
+    """Rewrite Blockwise of inc_subtensor operations.

-    x, y, *idxs = node.inputs
+    Note: The reason we don't apply this rewrite eagerly in the `vectorize_node` dispatch
+    is that we often have batch dimensions coming from allocs of shapes/reshapes that can be removed by rewrites,

-    # It is currently not possible to Vectorize such AdvancedIncSubtensor, but we check again just in case
-    if any(
-        (
-            isinstance(idx, SliceType | NoneTypeT)
-            or (idx.type.dtype == "bool" and idx.type.ndim > 0)
-        )
-        for idx in idxs
-    ):
+    such as x[:vectorized(w.shape[0])].set(y), which will later be rewritten as x[:w.shape[1]].set(y)
+    and can then be safely rewritten without Blockwise.
+    """
+    core_op = node.op.core_op
+    if not isinstance(core_op, AdvancedIncSubtensor | IncSubtensor):
         return None

-    op: Blockwise = node.op  # type: ignore
-    batch_ndim = op.batch_ndim(node)
-
-    new_idxs = []
-    for idx in idxs:
-        if all(idx.type.broadcastable[:batch_ndim]):
-            new_idxs.append(idx.squeeze(tuple(range(batch_ndim))))
-        else:
-            # Rewrite does not apply
+    x, y, *idxs = node.inputs
+    [out] = node.outputs
+    if isinstance(node.op.core_op, AdvancedIncSubtensor):
+        if any(
+            (
+                # Blockwise requires all inputs to be tensors so it is not possible
+                # to wrap an AdvancedIncSubtensor with slice / newaxis inputs, but we check again just in case
+                # If this is ever supported we need to pay attention to special behavior of numpy when advanced indices
+                # are separated by basic indices
+                isinstance(idx, SliceType | NoneTypeT)
+                # Also get out if we have boolean indices as they cross dimension boundaries
+                # / can't be safely broadcasted depending on their runtime content
+                or (idx.type.dtype == "bool")
+            )
+            for idx in idxs
+        ):
             return None

-    x_batch_bcast = x.type.broadcastable[:batch_ndim]
-    y_batch_bcast = y.type.broadcastable[:batch_ndim]
-    if any(xb and not yb for xb, yb in zip(x_batch_bcast, y_batch_bcast, strict=True)):
-        # Need to broadcast batch x dims
-        batch_shape = tuple(
-            x_dim if (not xb or yb) else y_dim
-            for xb, x_dim, yb, y_dim in zip(
-                x_batch_bcast,
+    batch_ndim = node.op.batch_ndim(node)
+    idxs_core_ndim = [len(inp_sig) for inp_sig in node.op.inputs_sig[2:]]
+    max_idx_core_ndim = max(idxs_core_ndim, default=0)
+
+    # Step 1. Broadcast buffer to batch_shape
+    if x.type.broadcastable != out.type.broadcastable:
+        batch_shape = [1] * batch_ndim
+        for inp in node.inputs:
+            for i, (broadcastable, batch_dim) in enumerate(
+                zip(inp.type.broadcastable[:batch_ndim], tuple(inp.shape)[:batch_ndim])
+            ):
+                if broadcastable:
+                    # This dimension is broadcastable, it doesn't provide shape information
+                    continue
+                if batch_shape[i] != 1:
+                    # We already found a source of shape for this batch dimension
+                    continue
+                batch_shape[i] = batch_dim
+        x = broadcast_to(x, (*batch_shape, *x.shape[batch_ndim:]))
+        assert x.type.broadcastable == out.type.broadcastable
+
+    # Step 2. Massage indices so they respect blockwise semantics
+    if isinstance(core_op, IncSubtensor):
+        # For basic IncSubtensor there are two cases:
+        # 1. Slice entries -> We need to squeeze away dummy dimensions so we can convert back to slice
+        # 2. Integers -> Can be used as is, but we try to squeeze away dummy batch dimensions
+        #    in case we can end up with a basic IncSubtensor again
+        core_idxs = []
+        counter = 0
+        for idx in core_op.idx_list:
+            if isinstance(idx, slice):
+                # Squeeze away dummy dimensions so we can convert to slice
+                new_entries = [None, None, None]
+                for i, entry in enumerate((idx.start, idx.stop, idx.step)):
+                    if entry is None:
+                        continue
+                    else:
+                        new_entries[i] = new_entry = idxs[counter].squeeze()
+                        counter += 1
+                        if new_entry.ndim > 0:
+                            # If the slice entry has dimensions after the squeeze we can't convert it to a slice
+                            # We could try to convert to equivalent integer indices, but nothing guarantees
+                            # that the slice is "square".
+                            return None
+                core_idxs.append(slice(*new_entries))
+            else:
+                core_idxs.append(_squeeze_left(idxs[counter]))
+                counter += 1
+    else:
+        # For AdvancedIncSubtensor we have tensor integer indices,
+        # We need to expand batch indexes on the right, so they don't interact with core index dimensions
+        # We still squeeze on the left in case that allows us to use simpler indices
+        core_idxs = [
+            _squeeze_left(
+                shape_padright(idx, max_idx_core_ndim - idx_core_ndim),
+                stop_at_dim=batch_ndim,
+            )
+            for idx, idx_core_ndim in zip(idxs, idxs_core_ndim)
+        ]
+
+    # Step 3. Create new indices for the new batch dimension of x
+    if not all(
+        all(idx.type.broadcastable[:batch_ndim])
+        for idx in idxs
+        if not isinstance(idx, slice)
+    ):
+        # If the indices have batch dimensions, they will interact with the new batch dimensions of x
+        # We build vectorized indexing with new arange indices that do not interact with core indices or each other
+        # (i.e., they broadcast)
+
+        # Note: due to how numpy handles non-consecutive advanced indexing (transposing it to the front),
+        # we don't want to create a mix of slice(None), and arange() indices for the new batch dimension,
+        # even if not all batch dimensions have corresponding batch indices.
+        batch_slices = [
+            shape_padright(arange(x_batch_shape, dtype="int64"), n)
+            for (x_batch_shape, n) in zip(
                 tuple(x.shape)[:batch_ndim],
-                y_batch_bcast,
-                tuple(y.shape)[:batch_ndim],
-                strict=True,
+                reversed(range(max_idx_core_ndim, max_idx_core_ndim + batch_ndim)),
             )
-        )
-        core_shape = tuple(x.shape)[batch_ndim:]
-        x = alloc(x, *batch_shape, *core_shape)
-
-    new_idxs = [slice(None)] * batch_ndim + new_idxs
-    x_view = x[tuple(new_idxs)]
-
-    # We need to introduce any implicit expand_dims on core dimension of y
-    y_core_ndim = y.type.ndim - batch_ndim
-    if (missing_y_core_ndim := x_view.type.ndim - batch_ndim - y_core_ndim) > 0:
-        missing_axes = tuple(range(batch_ndim, batch_ndim + missing_y_core_ndim))
-        y = expand_dims(y, missing_axes)
-
-    symbolic_idxs = x_view.owner.inputs[1:]
-    new_out = op.core_op.make_node(x, y, *symbolic_idxs).outputs
-    copy_stack_trace(node.outputs, new_out)
-    return new_out
+        ]
+    else:
+        # In the case we don't have batched indices, plain slice(None) entries are enough for the new batch dimensions
+        batch_slices = [slice(None)] * batch_ndim
+
+    new_idxs = (*batch_slices, *core_idxs)
+    x_view = x[new_idxs]
+
+    # Step 4. Introduce any implicit expand_dims on core dimension of y
+    missing_y_core_ndim = x_view.type.ndim - y.type.ndim
+    implicit_axes = tuple(range(batch_ndim, batch_ndim + missing_y_core_ndim))
+    y = _squeeze_left(expand_dims(y, implicit_axes), stop_at_dim=batch_ndim)
+
+    if isinstance(core_op, IncSubtensor):
+        # Check if we can still use a basic IncSubtensor
+        if isinstance(x_view.owner.op, Subtensor):
+            new_props = core_op._props_dict()
+            new_props["idx_list"] = x_view.owner.op.idx_list
+            new_core_op = type(core_op)(**new_props)
+            symbolic_idxs = x_view.owner.inputs[1:]
+            new_out = new_core_op(x, y, *symbolic_idxs)
+        else:
+            # We need to use AdvancedSet/IncSubtensor
+            if core_op.set_instead_of_inc:
+                new_out = x[new_idxs].set(y)
+            else:
+                new_out = x[new_idxs].inc(y)
+    else:
+        # AdvancedIncSubtensor takes symbolic indices/slices directly, no need to create a new op
+        symbolic_idxs = x_view.owner.inputs[1:]
+        new_out = core_op(x, y, *symbolic_idxs)
+
+    copy_stack_trace(out, new_out)
+    return [new_out]


 @node_rewriter(tracks=[AdvancedSubtensor, AdvancedIncSubtensor])
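
Below, for illustration only (not part of the diff), is a minimal sketch of the behaviour the new rewrite targets: vectorizing a set_subtensor over a batched buffer goes through Blockwise in the `vectorize_node` dispatch, and the rewrite is expected to turn it back into a plain (inc_)subtensor on the batched buffer. Variable names and shapes are made up, and it assumes the rewrite is registered under the default stabilize/specialize tags, as the decorators above indicate.

import numpy as np
import pytensor
import pytensor.tensor as pt
from pytensor.graph.replace import vectorize_graph
from pytensor.tensor.blockwise import Blockwise

# Core (unbatched) graph: write y into the first two entries of x
x = pt.vector("x")
y = pt.vector("y", shape=(2,))
out = x[:2].set(y)

# Vectorize over a batch of buffers; y stays unbatched
batch_x = pt.matrix("batch_x")
batch_out = vectorize_graph(out, replace={x: batch_x})

fn = pytensor.function([batch_x, y], batch_out)
# With the rewrite applied, no Blockwise should remain in the compiled graph
print(any(isinstance(node.op, Blockwise) for node in fn.maker.fgraph.toposort()))
print(fn(np.zeros((3, 5)), np.ones(2)))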
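
A NumPy-only sketch (not part of the diff) of why Step 1 broadcasts the buffer to the batch shape before indexing: when only the update value is batched, each batch element needs its own copy of the buffer to write into. Names and shapes are hypothetical.

import numpy as np

x = np.zeros(4)      # unbatched buffer
y = np.arange(3.0)   # batched update, batch size 3

# Step 1 analogue: give the buffer an explicit batch dimension first
# (np.broadcast_to returns a read-only view, so materialize it with copy())
xb = np.broadcast_to(x, (3, 4)).copy()
xb[:, 0] = y         # batched version of x[0].set(y)
print(xb)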
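
A NumPy-only sketch (not part of the diff) of the arange trick used in Step 3 when the integer indices themselves are batched: a right-padded arange batch index broadcasts against the batched core index instead of interacting with it, reproducing the per-batch update. Names and shapes are hypothetical.

import numpy as np

rng = np.random.default_rng(0)
batch, n, k = 3, 5, 2
x = np.zeros((batch, n))
idx = rng.integers(0, n, size=(batch, k))  # batched integer index (core ndim 1)
y = rng.normal(size=(batch, k))

# Per-batch reference: x[b][idx[b]] += y[b], accumulating duplicate indices
expected = x.copy()
for b in range(batch):
    np.add.at(expected[b], idx[b], y[b])

# Vectorized form: the batch index is padded on the right (shape (batch, 1))
# so it broadcasts against idx (shape (batch, k)) rather than adding a dimension
batch_idx = np.arange(batch)[:, None]
result = x.copy()
np.add.at(result, (batch_idx, idx), y)

np.testing.assert_allclose(result, expected)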