@@ -2029,18 +2029,41 @@ def ravel_multidimensional_bool_idx(fgraph, node):
     return [copy_stack_trace(node.outputs[0], new_out)]


-@node_rewriter(tracks=[AdvancedSubtensor])
+@node_rewriter(tracks=[AdvancedSubtensor, AdvancedIncSubtensor])
 def ravel_multidimensional_int_idx(fgraph, node):
-    """Convert multidimensional integer indexing into equivalent vector integer index, supported by Numba
-
-    x[eye(3, dtype=int)] -> x[eye(3).ravel()].reshape((3, 3))
+    """Convert multidimensional integer indexing into an equivalent consecutive vector integer index,
+    supported by Numba or by our specialized dispatchers

+    x[eye(3)] -> x[eye(3).ravel()].reshape((3, 3))

     NOTE: This is very similar to the rewrite `local_replace_AdvancedSubtensor` except it also handles non-full slices

-    x[eye(3, dtype=int), 2:] -> x[eye(3).ravel(), 2:].reshape((3, 3, ...)), where ... are the remaining output shapes
+    x[eye(3), 2:] -> x[eye(3).ravel(), 2:].reshape((3, 3, ...)), where ... are the remaining output shapes
+
+    It also handles multiple integer indices, but only if they don't broadcast
+
+    x[eye(3), 2:, eye(3)] -> x[eye(3).ravel(), eye(3).ravel(), 2:].reshape((3, 3, ...)), where ... are the remaining output shapes
+
+    Also handles AdvancedIncSubtensor, but only if the advanced indices are consecutive and neither the indices nor y broadcast
+
+    x[eye(3), 2:].set(y) -> x[eye(3).ravel(), 2:].set(y.reshape((-1, *y.shape[1:])))
+
     """
-    x, *idxs = node.inputs
+    op = node.op
+    non_consecutive_adv_indexing = op.non_consecutive_adv_indexing(node)
+    is_inc_subtensor = isinstance(op, AdvancedIncSubtensor)
+
+    if is_inc_subtensor:
+        x, y, *idxs = node.inputs
+        # Inc/SetSubtensor is harder to reason about due to y
+        # We get out if it's broadcasting or if the advanced indices are non-consecutive
+        if non_consecutive_adv_indexing or (
+            y.type.broadcastable != x[tuple(idxs)].type.broadcastable
+        ):
+            return None
+
+    else:
+        x, *idxs = node.inputs

     if any(
         (
@@ -2049,50 +2072,103 @@ def ravel_multidimensional_int_idx(fgraph, node):
         )
         for idx in idxs
     ):
-        # Get out if there are any other advanced indexes or np.newaxis
+        # Get out if there are any other advanced indices or np.newaxis
         return None

-    int_idxs = [
+    int_idxs_and_pos = [
         (i, idx)
         for i, idx in enumerate(idxs)
         if (isinstance(idx.type, TensorType) and idx.dtype in integer_dtypes)
     ]

-    if len(int_idxs) != 1:
-        # Get out if there are no or multiple integer idxs
+    if not int_idxs_and_pos:
         return None

-    [(int_idx_pos, int_idx)] = int_idxs
-    if int_idx.type.ndim < 2:
-        # No need to do anything if it's a vector or scalar, as it's already supported by Numba
+    int_idxs_pos, int_idxs = zip(
+        *int_idxs_and_pos, strict=False
+    )  # strict=False because by definition it's true
+
+    first_int_idx_pos = int_idxs_pos[0]
+    first_int_idx = int_idxs[0]
+    first_int_idx_bcast = first_int_idx.type.broadcastable
+
+    if any(int_idx.type.broadcastable != first_int_idx_bcast for int_idx in int_idxs):
+        # We don't have a view-only broadcasting operation
+        # Explicitly broadcasting the indices can incur a memory / copy overhead
         return None

-    raveled_int_idx = int_idx.ravel()
-    new_idxs = list(idxs)
-    new_idxs[int_idx_pos] = raveled_int_idx
-    raveled_subtensor = x[tuple(new_idxs)]
-
-    # Reshape into correct shape
-    # Because we only allow one advanced indexing, the output dimension corresponding to the raveled integer indexing
-    # must match the input position. If there were multiple advanced indexes, this could have been forcefully moved to the front
-    raveled_shape = raveled_subtensor.shape
-    unraveled_shape = (
-        *raveled_shape[:int_idx_pos],
-        *int_idx.shape,
-        *raveled_shape[int_idx_pos + 1 :],
-    )
-    new_out = raveled_subtensor.reshape(unraveled_shape)
+    int_idxs_ndim = len(first_int_idx_bcast)
+    if (
+        int_idxs_ndim == 0
+    ):  # This should be a basic indexing operation, rewrite elsewhere
+        return None
+
+    int_idxs_need_raveling = int_idxs_ndim > 1
+    if not (int_idxs_need_raveling or non_consecutive_adv_indexing):
+        # Numba or our dispatch natively supports consecutive vector indices, nothing needs to be done
+        return None
+
+    # Reorder non-consecutive indices
+    if non_consecutive_adv_indexing:
+        assert not is_inc_subtensor  # Sanity check that we got out if this was the case
+        # This case works as if all the advanced indices were on the front
+        transposition = list(int_idxs_pos) + [
+            i for i in range(len(idxs)) if i not in int_idxs_pos
+        ]
+        idxs = tuple(idxs[a] for a in transposition)
+        x = x.transpose(transposition)
+        first_int_idx_pos = 0
+        del int_idxs_pos  # Make sure they are not wrongly used
+
+    # Ravel multidimensional indices
+    if int_idxs_need_raveling:
+        idxs = list(idxs)
+        for idx_pos, int_idx in enumerate(int_idxs, start=first_int_idx_pos):
+            idxs[idx_pos] = int_idx.ravel()
+
+    # Index with reordered and/or raveled indices
+    new_subtensor = x[tuple(idxs)]
+
+    if is_inc_subtensor:
+        y_shape = tuple(y.shape)
+        y_raveled_shape = (
+            *y_shape[:first_int_idx_pos],
+            -1,
+            *y_shape[first_int_idx_pos + int_idxs_ndim :],
+        )
+        y_raveled = y.reshape(y_raveled_shape)
+
+        new_out = inc_subtensor(
+            new_subtensor,
+            y_raveled,
+            set_instead_of_inc=op.set_instead_of_inc,
+            ignore_duplicates=op.ignore_duplicates,
+            inplace=op.inplace,
+        )
+
+    else:
+        # Unravel advanced indexing dimensions
+        raveled_shape = tuple(new_subtensor.shape)
+        unraveled_shape = (
+            *raveled_shape[:first_int_idx_pos],
+            *first_int_idx.shape,
+            *raveled_shape[first_int_idx_pos + 1 :],
+        )
+        new_out = new_subtensor.reshape(unraveled_shape)
+
     return [copy_stack_trace(node.outputs[0], new_out)]


 optdb["specialize"].register(
     ravel_multidimensional_bool_idx.__name__,
     ravel_multidimensional_bool_idx,
     "numba",
+    use_db_name_as_tag=False,  # Not included if only "specialize" is requested
 )

 optdb["specialize"].register(
     ravel_multidimensional_int_idx.__name__,
     ravel_multidimensional_int_idx,
     "numba",
+    use_db_name_as_tag=False,  # Not included if only "specialize" is requested
 )
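And a sketch of why the non-consecutive branch transposes `x` first: when advanced indices are separated by a slice, NumPy (and PyTensor) place the advanced dimensions at the front of the output, which is exactly what indexing the transposed array with the grouped, raveled indices reproduces (again a hypothetical NumPy illustration, not part of the commit):

```python
import numpy as np

x = np.arange(4 * 5 * 4).reshape(4, 5, 4)
idx = np.array([[0, 1], [2, 3]])  # used on axes 0 and 2, separated by a slice

# Non-consecutive advanced indices: their broadcast dims land at the front
out = x[idx, 2:, idx]
assert out.shape == (2, 2, 3)

# Equivalent: move the advanced axes to the front, ravel, index, unravel
transposition = [0, 2, 1]  # advanced-index positions first, then the rest
x_t = x.transpose(transposition)
out2 = x_t[idx.ravel(), idx.ravel(), 2:].reshape((*idx.shape, 3))
assert np.array_equal(out, out2)
```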