from pytensor.link.numba.dispatch.basic import generate_fallback_impl, numba_njit
from pytensor.link.utils import compile_function_src, unique_name_generator
from pytensor.tensor import TensorType
+from pytensor.tensor.rewriting.subtensor import is_full_slice
from pytensor.tensor.subtensor import (
    AdvancedIncSubtensor,
    AdvancedIncSubtensor1,
    IncSubtensor,
    Subtensor,
)
+from pytensor.tensor.type_other import NoneTypeT, SliceType


@numba_funcify.register(Subtensor)
@@ -104,18 +106,73 @@ def {function_name}({", ".join(input_names)}):
@numba_funcify.register(AdvancedSubtensor)
@numba_funcify.register(AdvancedIncSubtensor)
def numba_funcify_AdvancedSubtensor(op, node, **kwargs):
-    idxs = node.inputs[1:] if isinstance(op, AdvancedSubtensor) else node.inputs[2:]
-    adv_idxs_dims = [
-        idx.type.ndim
+    if isinstance(op, AdvancedSubtensor):
+        x, y, idxs = node.inputs[0], None, node.inputs[1:]
+    else:
+        x, y, *idxs = node.inputs
+
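+    # Basic indices that rule out the vectorized special case below:
+    # new-axis (None) entries and any slice that is not a full `:` slice.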
+    basic_idxs = [
+        idx
        for idx in idxs
-        if (isinstance(idx.type, TensorType) and idx.type.ndim > 0)
+        if (
+            isinstance(idx.type, NoneTypeT)
+            or (isinstance(idx.type, SliceType) and not is_full_slice(idx))
+        )
+    ]
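+    # Record the position (axis), dtype, broadcast pattern, and ndim of every tensor index.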
+    adv_idxs = [
+        {
+            "axis": i,
+            "dtype": idx.type.dtype,
+            "bcast": idx.type.broadcastable,
+            "ndim": idx.type.ndim,
+        }
+        for i, idx in enumerate(idxs)
+        if isinstance(idx.type, TensorType)
    ]

+    # Special case for consecutive vector indices
+    def broadcasted_to(x_bcast: tuple[bool, ...], to_bcast: tuple[bool, ...]):
+        # Check whether x_bcast would be broadcast when matched against to_bcast
+        if len(x_bcast) < len(to_bcast):
+            return True
+        for x_bcast_dim, to_bcast_dim in zip(x_bcast, to_bcast, strict=True):
+            if x_bcast_dim and not to_bcast_dim:
+                return True
+        return False
+
+    # Special implementation for consecutive integer vector indices
+    if (
+        not basic_idxs
+        and len(adv_idxs) >= 2
+        # Must be integer vectors
+        # TODO: we could allow shape=(1,) if this is the shape of x
+        and all(
+            (adv_idx["bcast"] == (False,) and adv_idx["dtype"] != "bool")
+            for adv_idx in adv_idxs
+        )
+        # Must be consecutive
+        and not op.non_contiguous_adv_indexing(node)
+        # y in set/inc_subtensor cannot be broadcasted
+        and (
+            y is None
+            or not broadcasted_to(
+                y.type.broadcastable,
+                (
+                    x.type.broadcastable[: adv_idxs[0]["axis"]]
+                    + x.type.broadcastable[adv_idxs[-1]["axis"] :]
+                ),
+            )
+        )
+    ):
+        return numba_funcify_multiple_integer_vector_indexing(op, node, **kwargs)
+
+    # Other cases not natively supported by Numba (fallback to obj-mode)
    if (
        # Numba does not support indexes with more than one dimension
+        any(idx["ndim"] > 1 for idx in adv_idxs)
        # Nor multiple vector indexes
-        (len(adv_idxs_dims) > 1 or adv_idxs_dims[0] > 1)
-        # The default index implementation does not handle duplicate indices correctly
+        or sum(idx["ndim"] > 0 for idx in adv_idxs) > 1
+        # The default PyTensor implementation does not handle duplicate indices correctly
        or (
            isinstance(op, AdvancedIncSubtensor)
            and not op.set_instead_of_inc
@@ -124,9 +181,91 @@ def numba_funcify_AdvancedSubtensor(op, node, **kwargs):
    ):
        return generate_fallback_impl(op, node, **kwargs)

+    # What's left should all be supported natively by numba
    return numba_funcify_default_subtensor(op, node, **kwargs)


+def numba_funcify_multiple_integer_vector_indexing(
+    op: AdvancedSubtensor | AdvancedIncSubtensor, node, **kwargs
+):
+    # Special-case implementation for multiple consecutive vector integer indices (and set/incsubtensor)
+    if isinstance(op, AdvancedSubtensor):
+        y, idxs = None, node.inputs[1:]
+    else:
+        y, *idxs = node.inputs[1:]
+
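+    # Locate the contiguous run of vector indices: idxs[first_axis:after_last_axis]
+    # are the tensor indices; everything outside that range is a full slice.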
+    first_axis = next(
+        i for i, idx in enumerate(idxs) if isinstance(idx.type, TensorType)
+    )
+    try:
+        after_last_axis = next(
+            i
+            for i, idx in enumerate(idxs[first_axis:], start=first_axis)
+            if not isinstance(idx.type, TensorType)
+        )
+    except StopIteration:
+        after_last_axis = len(idxs)
+
+    if isinstance(op, AdvancedSubtensor):
+
+        @numba_njit
+        def advanced_subtensor_multiple_vector(x, *idxs):
+            none_slices = idxs[:first_axis]
+            vec_idxs = idxs[first_axis:after_last_axis]
+
+            x_shape = x.shape
+            idx_shape = vec_idxs[0].shape
+            shape_bef = x_shape[:first_axis]
+            shape_aft = x_shape[after_last_axis:]
+            out_shape = (*shape_bef, *idx_shape, *shape_aft)
+            out_buffer = np.empty(out_shape, dtype=x.dtype)
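+            # Walk the index vectors together: each tuple of scalars selects one
+            # sub-array of x, which is copied into position i of the output.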
+            for i, scalar_idxs in enumerate(zip(*vec_idxs)):  # noqa: B905
+                out_buffer[(*none_slices, i)] = x[(*none_slices, *scalar_idxs)]
+            return out_buffer
+
+        return advanced_subtensor_multiple_vector
+
+    elif op.set_instead_of_inc:
+        inplace = op.inplace
+
+        @numba_njit
+        def advanced_set_subtensor_multiple_vector(x, y, *idxs):
+            vec_idxs = idxs[first_axis:after_last_axis]
+            x_shape = x.shape
+
+            if inplace:
+                out = x
+            else:
+                out = x.copy()
+
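+            # Iterate over the leading (sliced) dimensions, then assign y entry by
+            # entry at the positions picked out by the index vectors.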
+            for outer in np.ndindex(x_shape[:first_axis]):
+                for i, scalar_idxs in enumerate(zip(*vec_idxs)):  # noqa: B905
+                    out[(*outer, *scalar_idxs)] = y[(*outer, i)]
+            return out
+
+        return advanced_set_subtensor_multiple_vector
+
+    else:
+        inplace = op.inplace
+
+        @numba_njit
+        def advanced_inc_subtensor_multiple_vector(x, y, *idxs):
+            vec_idxs = idxs[first_axis:after_last_axis]
+            x_shape = x.shape
+
+            if inplace:
+                out = x
+            else:
+                out = x.copy()
+
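+            # Same loop structure as the set_subtensor branch above, but accumulating
+            # into the output with += instead of overwriting.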
+            for outer in np.ndindex(x_shape[:first_axis]):
+                for i, scalar_idxs in enumerate(zip(*vec_idxs)):  # noqa: B905
+                    out[(*outer, *scalar_idxs)] += y[(*outer, i)]
+            return out
+
+        return advanced_inc_subtensor_multiple_vector
+
+
@numba_funcify.register(AdvancedIncSubtensor1)
def numba_funcify_AdvancedIncSubtensor1(op, node, **kwargs):
    inplace = op.inplace