@@ -73,7 +73,6 @@
     IncSubtensor,
     Subtensor,
     advanced_inc_subtensor1,
-    advanced_subtensor,
     advanced_subtensor1,
     as_index_constant,
     get_canonical_form_slice,
@@ -83,7 +82,7 @@
     inc_subtensor,
     indices_from_subtensor,
 )
-from pytensor.tensor.type import TensorType
+from pytensor.tensor.type import TensorType, integer_dtypes
 from pytensor.tensor.type_other import NoneTypeT, SliceType
 from pytensor.tensor.variable import TensorConstant, TensorVariable
 
@@ -265,6 +264,7 @@ def local_AdvancedIncSubtensor_to_AdvancedIncSubtensor1(fgraph, node):
     """
 
     if type(node.op) is not AdvancedIncSubtensor:
+        # Don't apply to subclasses
         return
 
     if node.op.ignore_duplicates:
@@ -1321,7 +1321,9 @@ def local_useless_inc_subtensor_alloc(fgraph, node):
         if isinstance(node.op, IncSubtensor):
             xi = Subtensor(node.op.idx_list)(x, *i)
         elif isinstance(node.op, AdvancedIncSubtensor):
-            xi = advanced_subtensor(x, *i)
+            # Use the same idx_list as the original operation to ensure correct shape
+            op = AdvancedSubtensor(node.op.idx_list)
+            xi = op.make_node(x, *i).outputs[0]
         elif isinstance(node.op, AdvancedIncSubtensor1):
             xi = advanced_subtensor1(x, *i)
         else:
@@ -1771,10 +1773,11 @@ def local_blockwise_inc_subtensor(fgraph, node):
 
 
 @node_rewriter(tracks=[AdvancedSubtensor, AdvancedIncSubtensor])
-def bool_idx_to_nonzero(fgraph, node):
-    """Convert boolean indexing into equivalent vector boolean index, supported by our dispatch
+def ravel_multidimensional_bool_idx(fgraph, node):
+    """Convert multidimensional boolean indexing into equivalent vector boolean index, supported by Numba
 
-    x[1:, eye(3, dtype=bool), 1:] -> x[1:, *eye(3).nonzero()]
+    x[eye(3, dtype=bool)] -> x.ravel()[eye(3).ravel()]
+    x[eye(3, dtype=bool)].set(y) -> x.ravel()[eye(3).ravel()].set(y).reshape(x.shape)
     """
 
     if isinstance(node.op, AdvancedSubtensor):
@@ -1787,26 +1790,53 @@ def bool_idx_to_nonzero(fgraph, node):
     # Reconstruct indices from idx_list and tensor inputs
     idxs = indices_from_subtensor(tensor_inputs, node.op.idx_list)
 
-    bool_pos = {
-        i
+    if any(
+        (
+            (isinstance(idx.type, TensorType) and idx.type.dtype in integer_dtypes)
+            or isinstance(idx.type, NoneTypeT)
+        )
+        for idx in idxs
+    ):
+        # Get out if there are any other advanced indexes or np.newaxis
+        return None
+
+    bool_idxs = [
+        (i, idx)
         for i, idx in enumerate(idxs)
         if (isinstance(idx.type, TensorType) and idx.dtype == "bool")
-    }
+    ]
 
-    if not bool_pos:
+    if len(bool_idxs) != 1:
+        # Get out if there are no or multiple boolean idxs
+        return None
+    [(bool_idx_pos, bool_idx)] = bool_idxs
+    bool_idx_ndim = bool_idx.type.ndim
+    if bool_idx.type.ndim < 2:
+        # No need to do anything if it's a vector or scalar, as it's already supported by Numba
         return None
 
-    new_idxs = []
-    for i, idx in enumerate(idxs):
-        if i in bool_pos:
-            new_idxs.extend(idx.nonzero())
-        else:
-            new_idxs.append(idx)
+    x_shape = x.shape
+    raveled_x = x.reshape(
+        (*x_shape[:bool_idx_pos], -1, *x_shape[bool_idx_pos + bool_idx_ndim :])
+    )
+
+    raveled_bool_idx = bool_idx.ravel()
+    new_idxs = list(idxs)
+    new_idxs[bool_idx_pos] = raveled_bool_idx
 
     if isinstance(node.op, AdvancedSubtensor):
-        new_out = node.op(x, *new_idxs)
+        new_out = raveled_x[tuple(new_idxs)]
     else:
-        new_out = node.op(x, y, *new_idxs)
+        sub = raveled_x[tuple(new_idxs)]
+        new_out = inc_subtensor(
+            sub,
+            y,
+            set_instead_of_inc=node.op.set_instead_of_inc,
+            ignore_duplicates=node.op.ignore_duplicates,
+            inplace=node.op.inplace,
+        )
+        new_out = new_out.reshape(x_shape)
+
     return [copy_stack_trace(node.outputs[0], new_out)]
 
 
@@ -1941,10 +1971,16 @@ def ravel_multidimensional_int_idx(fgraph, node):
 
 
 optdb["specialize"].register(
-    bool_idx_to_nonzero.__name__,
-    bool_idx_to_nonzero,
+    ravel_multidimensional_bool_idx.__name__,
+    ravel_multidimensional_bool_idx,
+    "numba",
+    use_db_name_as_tag=False,  # Not included if only "specialize" is requested
+)
+
+optdb["specialize"].register(
+    ravel_multidimensional_int_idx.__name__,
+    ravel_multidimensional_int_idx,
     "numba",
-    "shape_unsafe",  # It can mask invalid mask sizes
     use_db_name_as_tag=False,  # Not included if only "specialize" is requested
 )
 
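For reference, below is a minimal NumPy sketch (not part of the commit) of the equivalence that ravel_multidimensional_bool_idx relies on: when a multidimensional boolean mask is the only advanced index, indexing the partially raveled array with the raveled mask gives the same result, and a set-style update can be done the same way and reshaped back. The shapes and variable names are illustrative assumptions.

import numpy as np

# x[mask] with a 2D boolean mask covering the two leading axes ...
x = np.arange(24).reshape(2, 3, 4)
mask = np.array([[True, False, True], [False, True, False]])

# ... equals indexing the partially raveled array with the raveled (1D) mask
assert np.array_equal(x[mask], x.reshape(-1, 4)[mask.ravel()])

# Set-style update: write y into the masked rows of a raveled copy,
# then reshape back to the original shape
y = np.full((int(mask.sum()), 4), -1)
expected = x.copy()
expected[mask] = y

out = x.reshape(-1, 4).copy()
out[mask.ravel()] = y
out = out.reshape(x.shape)
assert np.array_equal(out, expected)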