@@ -2120,16 +2120,12 @@ def make_node(self, x, ilist):
21202120 out_shape = (ilist_ .type .shape [0 ], * x_ .type .shape [1 :])
21212121 return Apply (self , [x_ , ilist_ ], [TensorType (dtype = x .dtype , shape = out_shape )()])
21222122
2123- def perform (self , node , inp , out_ ):
2123+ def perform (self , node , inp , output_storage ):
21242124 x , i = inp
2125- (out ,) = out_
2126- # Copy always implied by numpy advanced indexing semantic.
2127- if out [0 ] is not None and out [0 ].shape == (len (i ),) + x .shape [1 :]:
2128- o = out [0 ]
2129- else :
2130- o = None
21312125
2132- out [0 ] = x .take (i , axis = 0 , out = o )
2126+ # Numpy take is always slower when out is provided
2127+ # https://github.com/numpy/numpy/issues/28636
2128+ output_storage [0 ][0 ] = x .take (i , axis = 0 , out = None )
21332129
21342130 def connection_pattern (self , node ):
21352131 rval = [[True ], * ([False ] for _ in node .inputs [1 :])]
@@ -2174,42 +2170,70 @@ def c_code(self, node, name, input_names, output_names, sub):
21742170 "c_code defined for AdvancedSubtensor1, not for child class" ,
21752171 type (self ),
21762172 )
2173+ x , idxs = node .inputs
2174+ shape0 = x .type .shape [0 ]
2175+ if (
2176+ shape0 is not None
2177+ and isinstance (idxs , Constant )
2178+ and (
2179+ (idxs .data .max () < shape0 )
2180+ and ((idxs .data .min () >= 0 ) or (idxs .data .min () > - shape0 ))
2181+ )
2182+ ):
2183+ # We can know ahead of time that all indices are valid, so we can use a faster mode
2184+ mode = "NPY_WRAP" # This seems to be faster than NPY_CLIP
2185+ else :
2186+ mode = "NPY_RAISE"
21772187 a_name , i_name = input_names [0 ], input_names [1 ]
21782188 output_name = output_names [0 ]
21792189 fail = sub ["fail" ]
2180- return f"""
2181- if ({ output_name } != NULL) {{
2182- npy_intp nd, i, *shape;
2183- nd = PyArray_NDIM({ a_name } ) + PyArray_NDIM({ i_name } ) - 1;
2184- if (PyArray_NDIM({ output_name } ) != nd) {{
2190+ if mode == "NPY_RAISE" :
2191+ # numpy_take always makes an intermediate copy if NPY_RAISE which is slower than just allocating a new buffer
2192+ # We can remove this special case after https://github.com/numpy/numpy/issues/28636
2193+ manage_pre_allocated_out = f"""
2194+ if ({ output_name } != NULL) {{
2195+ // Numpy TakeFrom is always slower when copying
2196+ // https://github.com/numpy/numpy/issues/28636
21852197 Py_CLEAR({ output_name } );
21862198 }}
2187- else {{
2188- shape = PyArray_DIMS( { output_name } );
2189- for (i = 0; i < PyArray_NDIM( { i_name } ); i++) {{
2190- if (shape[i] != PyArray_DIMS( { i_name } )[i] ) {{
2191- Py_CLEAR( { output_name } ) ;
2192- break;
2193- }}
2199+ """
2200+ else :
2201+ manage_pre_allocated_out = f"""
2202+ if ({ output_name } != NULL ) {{
2203+ npy_intp nd = PyArray_NDIM( { a_name } ) + PyArray_NDIM( { i_name } ) - 1 ;
2204+ if (PyArray_NDIM( { output_name } ) != nd) {{
2205+ Py_CLEAR( { output_name } );
21942206 }}
2195- if ({ output_name } != NULL) {{
2196- for (; i < nd; i++) {{
2197- if (shape[i] != PyArray_DIMS({ a_name } )[
2198- i-PyArray_NDIM({ i_name } )+1]) {{
2207+ else {{
2208+ int i;
2209+ npy_intp* shape = PyArray_DIMS({ output_name } );
2210+ for (i = 0; i < PyArray_NDIM({ i_name } ); i++) {{
2211+ if (shape[i] != PyArray_DIMS({ i_name } )[i]) {{
21992212 Py_CLEAR({ output_name } );
22002213 break;
22012214 }}
22022215 }}
2216+ if ({ output_name } != NULL) {{
2217+ for (; i < nd; i++) {{
2218+ if (shape[i] != PyArray_DIMS({ a_name } )[i-PyArray_NDIM({ i_name } )+1]) {{
2219+ Py_CLEAR({ output_name } );
2220+ break;
2221+ }}
2222+ }}
2223+ }}
22032224 }}
22042225 }}
2205- }}
2226+ """
2227+
2228+ return f"""
2229+ { manage_pre_allocated_out }
22062230 { output_name } = (PyArrayObject*)PyArray_TakeFrom(
2207- { a_name } , (PyObject*){ i_name } , 0, { output_name } , NPY_RAISE );
2231+ { a_name } , (PyObject*){ i_name } , 0, { output_name } , { mode } );
22082232 if ({ output_name } == NULL) { fail } ;
22092233 """
22102234
22112235 def c_code_cache_version (self ):
2212- return (4 ,)
2236+ return (5 ,)
22132237
22142238
# Reusable default instance of the AdvancedSubtensor1 Op.
advanced_subtensor1 = AdvancedSubtensor1()
0 commit comments