 from pytensor.link.numba.dispatch.basic import generate_fallback_impl, numba_njit
 from pytensor.link.utils import compile_function_src, unique_name_generator
 from pytensor.tensor import TensorType
+from pytensor.tensor.rewriting.subtensor import is_full_slice
 from pytensor.tensor.subtensor import (
     AdvancedIncSubtensor,
     AdvancedIncSubtensor1,
     IncSubtensor,
     Subtensor,
 )
+from pytensor.tensor.type_other import NoneTypeT, SliceType
 
 
 @numba_funcify.register(Subtensor)
@@ -104,18 +106,61 @@ def {function_name}({", ".join(input_names)}):
 @numba_funcify.register(AdvancedSubtensor)
 @numba_funcify.register(AdvancedIncSubtensor)
 def numba_funcify_AdvancedSubtensor(op, node, **kwargs):
-    idxs = node.inputs[1:] if isinstance(op, AdvancedSubtensor) else node.inputs[2:]
-    adv_idxs_dims = [
-        idx.type.ndim
+    if isinstance(op, AdvancedSubtensor):
+        x, y, idxs = node.inputs[0], None, node.inputs[1:]
+    else:
+        x, y, *idxs = node.inputs
+
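+    # Non-trivial basic indices (np.newaxis or partial slices); any of these rules out the special case below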
+    basic_idxs = [
+        idx
         for idx in idxs
-        if (isinstance(idx.type, TensorType) and idx.type.ndim > 0)
+        if (
+            isinstance(idx.type, NoneTypeT)
+            or (isinstance(idx.type, SliceType) and not is_full_slice(idx))
+        )
+    ]
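+    # Axis position, dtype, broadcast pattern, and ndim of every advanced (tensor) index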
+    adv_idxs = [
+        {
+            "axis": i,
+            "dtype": idx.type.dtype,
+            "bcast": idx.type.broadcastable,
+            "ndim": idx.type.ndim,
+        }
+        for i, idx in enumerate(idxs)
+        if isinstance(idx.type, TensorType)
     ]
 
+    # Special case for consecutive vector indices
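+    # Numba has no native support for these, so they get a dedicated loop-based implementation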
+    if (
+        not basic_idxs
+        and len(adv_idxs) >= 2
+        # Must be integer vectors
+        # Todo: we could allow shape=(1,) if this is the shape of x
+        and all(
+            (adv_idx["bcast"] == (False,) and adv_idx["dtype"] != "bool")
+            for adv_idx in adv_idxs
+        )
+        # Must be consecutive
+        and not op.non_contiguous_adv_indexing(node)
+        # y in set/inc_subtensor cannot be broadcasted
+        and (
+            y is None
+            or y.type.broadcastable
+            == (
+                x.type.broadcastable[: adv_idxs[0]["axis"]]
+                + x.type.broadcastable[adv_idxs[-1]["axis"] :]
+            )
+        )
+    ):
+        return numba_funcify_multiple_vector_indexing(op, node, **kwargs)
+
+    # Cases natively supported by Numba
     if (
         # Numba does not support indexes with more than one dimension
+        any(idx["ndim"] > 1 for idx in adv_idxs)
         # Nor multiple vector indexes
-        (len(adv_idxs_dims) > 1 or adv_idxs_dims[0] > 1)
-        # The default index implementation does not handle duplicate indices correctly
+        or sum(idx["ndim"] > 0 for idx in adv_idxs) > 1
+        # The default PyTensor implementation does not handle duplicate indices correctly
         or (
             isinstance(op, AdvancedIncSubtensor)
             and not op.set_instead_of_inc
@@ -127,6 +172,87 @@ def numba_funcify_AdvancedSubtensor(op, node, **kwargs):
     return numba_funcify_default_subtensor(op, node, **kwargs)
 
 
+def numba_funcify_multiple_vector_indexing(
+    op: AdvancedSubtensor | AdvancedIncSubtensor, node, **kwargs
+):
+    # Special-case implementation for multiple consecutive vector indices (and set/inc_subtensor)
+    if isinstance(op, AdvancedSubtensor):
+        y, idxs = None, node.inputs[1:]
+    else:
+        y, *idxs = node.inputs[1:]
+
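+    # Locate the run of consecutive vector indices: first_axis is the first tensor index,
+    # after_last_axis is one past the last tensor index in that run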
+    first_axis = next(
+        i for i, idx in enumerate(idxs) if isinstance(idx.type, TensorType)
+    )
+    try:
+        after_last_axis = next(
+            i
+            for i, idx in enumerate(idxs[first_axis:], start=first_axis)
+            if not isinstance(idx.type, TensorType)
+        )
+    except StopIteration:
+        after_last_axis = len(idxs)
+
+    if isinstance(op, AdvancedSubtensor):
+
+        @numba_njit
+        def advanced_subtensor_multiple_vector(x, *idxs):
+            none_slices = idxs[:first_axis]
+            vec_idxs = idxs[first_axis:after_last_axis]
+
+            x_shape = x.shape
+            idx_shape = vec_idxs[0].shape
+            shape_bef = x_shape[:first_axis]
+            shape_aft = x_shape[after_last_axis:]
+            out_shape = (*shape_bef, *idx_shape, *shape_aft)
+            out_buffer = np.empty(out_shape, dtype=x.dtype)
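+            # The vector indices collapse into a single output axis: entry i gathers x at the i-th tuple of scalar indices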
+            for i, scalar_idxs in enumerate(zip(*vec_idxs)):  # noqa: B905
+                out_buffer[(*none_slices, i)] = x[(*none_slices, *scalar_idxs)]
+            return out_buffer
+
+        return advanced_subtensor_multiple_vector
+
+    elif op.set_instead_of_inc:
+        inplace = op.inplace
+
+        @numba_njit
+        def advanced_set_subtensor_multiple_vector(x, y, *idxs):
+            vec_idxs = idxs[first_axis:after_last_axis]
+            x_shape = x.shape
+
+            if inplace:
+                out = x
+            else:
+                out = x.copy()
+
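+            # Write y into the selected entries; with duplicate index tuples the last write wins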
+            for outer in np.ndindex(x_shape[:first_axis]):
+                for i, scalar_idxs in enumerate(zip(*vec_idxs)):  # noqa: B905
+                    out[(*outer, *scalar_idxs)] = y[*outer, i]
+            return out
+
+        return advanced_set_subtensor_multiple_vector
+
+    else:
+        inplace = op.inplace
+
+        @numba_njit
+        def advanced_inc_subtensor_multiple_vector(x, y, *idxs):
+            vec_idxs = idxs[first_axis:after_last_axis]
+            x_shape = x.shape
+
+            if inplace:
+                out = x
+            else:
+                out = x.copy()
+
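+            # Accumulate y into the selected entries; duplicate index tuples are incremented once per occurrence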
+            for outer in np.ndindex(x_shape[:first_axis]):
+                for i, scalar_idxs in enumerate(zip(*vec_idxs)):  # noqa: B905
+                    out[(*outer, *scalar_idxs)] += y[*outer, i]
+            return out
+
+        return advanced_inc_subtensor_multiple_vector
+
+
 @numba_funcify.register(AdvancedIncSubtensor1)
 def numba_funcify_AdvancedIncSubtensor1(op, node, **kwargs):
     inplace = op.inplace