@@ -2326,6 +2326,9 @@ def copy_of_x(self, x):
23262326 NPY_ARRAY_ENSURECOPY, NULL)"""
23272327
23282328 def c_support_code (self , ** kwargs ):
2329+ if numpy_version < "1.8.0" or using_numpy_2 :
2330+ return None
2331+
23292332 types = [
23302333 "npy_" + t
23312334 for t in [
@@ -2516,15 +2519,100 @@ def gen_num(typen):
25162519 return code
25172520
25182521 def c_code (self , node , name , input_names , output_names , sub ):
2519- if numpy_version < "1.8.0" or using_numpy_2 :
2520- raise NotImplementedError
2521-
25222522 x , y , idx = input_names
2523- out = output_names [ 0 ]
2523+ [ out ] = output_names
25242524 copy_of_x = self .copy_of_x (x )
25252525 params = sub ["params" ]
25262526 fail = sub ["fail" ]
25272527
2528+ x_ , y_ , idx_ = node .inputs
2529+ y_dtype = y_ .type .dtype_specs ()[1 ]
2530+ idx_dtype = idx_ .type .dtype_specs ()[1 ]
2531+ out_dtype = node .outputs [0 ].type .dtype_specs ()[1 ]
2532+ y_bcast = y_ .type .broadcastable != idx_ .type .broadcastable
2533+ if (
2534+ x_ .type .ndim == 1
2535+ and x_ .type .dtype not in complex_dtypes
2536+ and not y_bcast
2537+ and y_ .type .dtype not in complex_dtypes
2538+ ):
2539+ # Simple implementation for vector x, y cases
2540+ idx_may_be_neg = not (isinstance (idx_ , Constant ) and idx_ .data .min () >= 0 )
2541+ idx_may_be_invalid = AdvancedSubtensor1 ._idx_may_be_invalid (x_ , idx_ )
2542+ shape0 = x_ .type .shape [0 ]
2543+ # This is used to make sure that when we trust the indices to be valid
2544+ # we are not fooled by a wrong static shape
2545+ unexpected_shape0 = (
2546+ f"PyArray_SHAPE({ x } )[0] != { shape0 } " if shape0 is not None else "0"
2547+ )
2548+
2549+ op = "=" if self .set_instead_of_inc else "+="
2550+ code = f"""
2551+ if ({ params } ->inplace)
2552+ {{
2553+ if ({ x } != { out } )
2554+ {{
2555+ Py_XDECREF({ out } );
2556+ Py_INCREF({ x } );
2557+ { out } = { x } ;
2558+ }}
2559+ }}
2560+ else
2561+ {{
2562+ Py_XDECREF({ out } );
2563+ { out } = { copy_of_x } ;
2564+ if (!{ out } ) {{
2565+ // Exception already set
2566+ { fail }
2567+ }}
2568+ }}
2569+
2570+ if ((PyArray_NDIM({ out } ) != 1) || ({ unexpected_shape0 } )) {{
2571+ PyErr_SetString(PyExc_ValueError, "Input x to AdvancedIncSubtensor1 does not have right shape or ndim");
2572+ { fail }
2573+ }}
2574+ if (PyArray_NDIM({ idx } ) != 1) {{
2575+ PyErr_SetString(PyExc_ValueError, "Input idx to AdvancedIncSubtensor1 ndim != 1");
2576+ { fail }
2577+ }}
2578+ if ((PyArray_NDIM({ y } ) != 1) || (PyArray_SHAPE({ y } )[0] != PyArray_SHAPE({ idx } )[0])) {{
2579+ PyErr_SetString(PyExc_ValueError, "Input y to AdvancedIncSubtensor1 does not have right shape or ndim");
2580+ { fail }
2581+ }}
2582+
2583+ {{
2584+ npy_intp out_shape0 = PyArray_SHAPE({ out } )[0];
2585+ { out_dtype } * out_data = ({ out_dtype } *)PyArray_DATA({ out } );
2586+ { y_dtype } * y_data = ({ y_dtype } *)PyArray_DATA({ y } );
2587+ { idx_dtype } * idx_data = ({ idx_dtype } *)PyArray_DATA({ idx } );
2588+ npy_intp n = PyArray_SHAPE({ idx } )[0];
2589+ npy_intp out_jump = PyArray_STRIDES({ out } )[0] / PyArray_ITEMSIZE({ out } );
2590+ npy_intp y_jump = PyArray_STRIDES({ y } )[0] / PyArray_ITEMSIZE({ y } );
2591+ npy_intp idx_jump = PyArray_STRIDES({ idx } )[0] / PyArray_ITEMSIZE({ idx } );
2592+
2593+ for(int i = 0; i < n; i++){{
2594+ { idx_dtype } idx = idx_data[i * idx_jump];
2595+ if ({ int (idx_may_be_neg )} ){{
2596+ if (idx < 0) {{
2597+ idx += out_shape0;
2598+ }}
2599+ }}
2600+ if ({ int (idx_may_be_invalid )} ){{
2601+ if ((idx < 0) || (idx >= out_shape0)) {{
2602+ PyErr_Format(PyExc_IndexError,"index out of bounds");
2603+ { fail }
2604+ }}
2605+ }}
2606+ out_data[idx * out_jump] { op } y_data[i * y_jump];
2607+ }}
2608+
2609+ }}
2610+ """
2611+ return code
2612+
2613+ if numpy_version < "1.8.0" or using_numpy_2 :
2614+ raise NotImplementedError
2615+
25282616 return f"""
25292617 PyObject* rval = NULL;
25302618 if ({ params } ->inplace)
@@ -2552,7 +2640,7 @@ def c_code(self, node, name, input_names, output_names, sub):
25522640 """
25532641
def c_code_cache_version(self):
    """Return the version tuple of the C code generated by this Op.

    Returns
    -------
    tuple of int
        Currently ``(9,)``.
    """
    # Bumped from (8,): ``c_code`` now emits a dedicated C loop for the
    # non-complex vector case, so the generated source differs from what
    # version 8 produced.
    # NOTE(review): presumably this tuple keys the compiled-module cache,
    # so it must change whenever the emitted C changes -- confirm.
    return (9,)
25562644
25572645 def perform (self , node , inp , out_ ):
25582646 x , y , idx = inp
0 commit comments