@@ -2320,6 +2320,9 @@ def copy_of_x(self, x):
23202320 NPY_ARRAY_ENSURECOPY, NULL)"""
23212321
23222322 def c_support_code (self , ** kwargs ):
2323+ if numpy_version < "1.8.0" or using_numpy_2 :
2324+ return None
2325+
23232326 types = [
23242327 "npy_" + t
23252328 for t in [
@@ -2510,15 +2513,105 @@ def gen_num(typen):
25102513 return code
25112514
25122515 def c_code (self , node , name , input_names , output_names , sub ):
2513- if numpy_version < "1.8.0" or using_numpy_2 :
2514- raise NotImplementedError
2515-
25162516 x , y , idx = input_names
2517- out = output_names [ 0 ]
2517+ [ out ] = output_names
25182518 copy_of_x = self .copy_of_x (x )
25192519 params = sub ["params" ]
25202520 fail = sub ["fail" ]
25212521
2522+ x_ , y_ , idx_ = node .inputs
2523+ y_dtype = y_ .type .dtype_specs ()[1 ]
2524+ idx_dtype = idx_ .type .dtype_specs ()[1 ]
2525+ out_dtype = node .outputs [0 ].type .dtype_specs ()[1 ]
2526+ y_bcast = y_ .type .broadcastable != idx_ .type .broadcastable
2527+ if (
2528+ x_ .type .ndim == 1
2529+ and x_ .type .dtype not in complex_dtypes
2530+ and not y_bcast
2531+ and y_ .type .dtype not in complex_dtypes
2532+ ):
2533+ # Simple implementation for vector x, y cases
2534+ idx_may_be_neg = not (isinstance (idx_ , Constant ) and idx_ .data .min () >= 0 )
2535+ shape0 = x_ .type .shape [0 ]
2536+ idx_may_be_invalid = not (
2537+ shape0 is not None
2538+ and isinstance (idx_ , Constant )
2539+ and (idx_ .data .min () > 0 or idx_ .data .min () >= - shape0 )
2540+ and (idx_ .data .max () < 0 or idx_ .data .max () < shape0 )
2541+ )
2542+ # This is used to make sure that when we trust the indices to be valid
2543+ # we are not fooled by a wrong static shape
2544+ unexpected_shape0 = (
2545+ f"PyArray_SHAPE({ x } )[0] != { shape0 } " if shape0 is not None else "0"
2546+ )
2547+
2548+ op = "=" if self .set_instead_of_inc else "+="
2549+ code = f"""
2550+ if ({ params } ->inplace)
2551+ {{
2552+ if ({ x } != { out } )
2553+ {{
2554+ Py_XDECREF({ out } );
2555+ Py_INCREF({ x } );
2556+ { out } = { x } ;
2557+ }}
2558+ }}
2559+ else
2560+ {{
2561+ Py_XDECREF({ out } );
2562+ { out } = { copy_of_x } ;
2563+ if (!{ out } ) {{
2564+ // Exception already set
2565+ { fail }
2566+ }}
2567+ }}
2568+
2569+ if ((PyArray_NDIM({ out } ) != 1) || ({ unexpected_shape0 } )) {{
2570+ PyErr_SetString(PyExc_ValueError, "Input x to AdvancedIncSubtensor1 does not have right shape or ndim");
2571+ { fail }
2572+ }}
2573+ if (PyArray_NDIM({ idx } ) != 1) {{
2574+ PyErr_SetString(PyExc_ValueError, "Input idx to AdvancedIncSubtensor1 ndim != 1");
2575+ { fail }
2576+ }}
2577+ if ((PyArray_NDIM({ y } ) != 1) || (PyArray_SHAPE({ y } )[0] != PyArray_SHAPE({ idx } )[0])) {{
2578+ PyErr_SetString(PyExc_ValueError, "Input y to AdvancedIncSubtensor1 does not have right shape or ndim");
2579+ { fail }
2580+ }}
2581+
2582+ {{
2583+ npy_intp out_shape0 = PyArray_SHAPE({ out } )[0];
2584+ { out_dtype } * out_data = ({ out_dtype } *)PyArray_DATA({ out } );
2585+ { y_dtype } * y_data = ({ y_dtype } *)PyArray_DATA({ y } );
2586+ { idx_dtype } * idx_data = ({ idx_dtype } *)PyArray_DATA({ idx } );
2587+ npy_intp n = PyArray_SHAPE({ idx } )[0];
2588+ npy_intp out_jump = PyArray_STRIDES({ out } )[0] / PyArray_ITEMSIZE({ out } );
2589+ npy_intp y_jump = PyArray_STRIDES({ y } )[0] / PyArray_ITEMSIZE({ y } );
2590+ npy_intp idx_jump = PyArray_STRIDES({ idx } )[0] / PyArray_ITEMSIZE({ idx } );
2591+
2592+ for(int i = 0; i < n; i++){{
2593+ { idx_dtype } idx = idx_data[i * idx_jump];
2594+ if ({ int (idx_may_be_neg )} ){{
2595+ if (idx < 0) {{
2596+ idx += out_shape0;
2597+ }}
2598+ }}
2599+ if ({ int (idx_may_be_invalid )} ){{
2600+ if ((idx < 0) || (idx >= out_shape0)) {{
2601+ PyErr_Format(PyExc_IndexError,"index out of bounds");
2602+ { fail }
2603+ }}
2604+ }}
2605+ out_data[idx * out_jump] { op } y_data[i * y_jump];
2606+ }}
2607+
2608+ }}
2609+ """
2610+ return code
2611+
2612+ if numpy_version < "1.8.0" or using_numpy_2 :
2613+ raise NotImplementedError
2614+
25222615 return f"""
25232616 PyObject* rval = NULL;
25242617 if ({ params } ->inplace)
@@ -2546,7 +2639,8 @@ def c_code(self, node, name, input_names, output_names, sub):
25462639 """
25472640
def c_code_cache_version(self):
    """Return the version tuple identifying this op's generated C code.

    The version was bumped from ``(8,)`` to ``(9,)`` together with the
    rewrite of ``c_code``. A stray ``return None`` placed ahead of the
    tuple made the new version unreachable, so the method always
    returned ``None``; it is removed here.

    NOTE(review): ``None`` is presumably what tells the compilation
    cache not to reuse this op's C module (handy while iterating on the
    C code, undesirable once merged) -- confirm against the COp docs.
    """
    return (9,)
25502644
25512645 def perform (self , node , inp , out_ ):
25522646 x , y , idx = inp
0 commit comments