@@ -2257,6 +2257,12 @@ class AdvancedIncSubtensor1(COp):
22572257 check_input = False
22582258 params_type = ParamsType (inplace = ps .bool , set_instead_of_inc = ps .bool )
22592259
# Message raised when the second input (y) would have to be broadcast at
# runtime along a dimension that is not statically marked as broadcastable.
# The text is also interpolated verbatim into a C string literal by `c_code`,
# so it must not contain double quotes or newlines.
_runtime_broadcast_error_msg = " ".join(
    [
        "Runtime broadcasting not allowed.",
        "AdvancedIncSubtensor1 was asked to broadcast the second input (y) along a dimension that was not marked as broadcastable.",
        "If broadcasting was intended, use `specify_broadcastable` on the relevant dimension(s).",
    ]
)
2265+
22602266 def __init__ (self , inplace = False , set_instead_of_inc = False ):
22612267 self .inplace = bool (inplace )
22622268 self .set_instead_of_inc = bool (set_instead_of_inc )
@@ -2328,6 +2334,9 @@ def copy_of_x(self, x):
23282334 NPY_ARRAY_ENSURECOPY, NULL)"""
23292335
23302336 def c_support_code (self , ** kwargs ):
2337+ if numpy_version < "1.8.0" or using_numpy_2 :
2338+ return None
2339+
23312340 types = [
23322341 "npy_" + t
23332342 for t in [
@@ -2518,15 +2527,104 @@ def gen_num(typen):
25182527 return code
25192528
25202529 def c_code (self , node , name , input_names , output_names , sub ):
2521- if numpy_version < "1.8.0" or using_numpy_2 :
2522- raise NotImplementedError
2523-
25242530 x , y , idx = input_names
2525- out = output_names [ 0 ]
2531+ [ out ] = output_names
25262532 copy_of_x = self .copy_of_x (x )
25272533 params = sub ["params" ]
25282534 fail = sub ["fail" ]
25292535
2536+ x_ , y_ , idx_ = node .inputs
2537+ y_dtype = y_ .type .dtype_specs ()[1 ]
2538+ idx_dtype = idx_ .type .dtype_specs ()[1 ]
2539+ out_dtype = node .outputs [0 ].type .dtype_specs ()[1 ]
2540+ y_bcast = y_ .type .broadcastable != idx_ .type .broadcastable
2541+ if (
2542+ x_ .type .ndim == 1
2543+ and x_ .type .dtype not in complex_dtypes
2544+ and not y_bcast
2545+ and y_ .type .dtype not in complex_dtypes
2546+ ):
2547+ # Simple implementation for vector x, y cases
2548+ idx_may_be_neg = not (isinstance (idx_ , Constant ) and idx_ .data .min () >= 0 )
2549+ idx_may_be_invalid = AdvancedSubtensor1 ._idx_may_be_invalid (x_ , idx_ )
2550+ shape0 = x_ .type .shape [0 ]
2551+ # This is used to make sure that when we trust the indices to be valid
2552+ # we are not fooled by a wrong static shape
2553+ unexpected_shape0 = (
2554+ f"PyArray_SHAPE({ x } )[0] != { shape0 } " if shape0 is not None else "0"
2555+ )
2556+
2557+ op = "=" if self .set_instead_of_inc else "+="
2558+ code = f"""
2559+ if ({ params } ->inplace)
2560+ {{
2561+ if ({ x } != { out } )
2562+ {{
2563+ Py_XDECREF({ out } );
2564+ Py_INCREF({ x } );
2565+ { out } = { x } ;
2566+ }}
2567+ }}
2568+ else
2569+ {{
2570+ Py_XDECREF({ out } );
2571+ { out } = { copy_of_x } ;
2572+ if (!{ out } ) {{
2573+ // Exception already set
2574+ { fail }
2575+ }}
2576+ }}
2577+
2578+ if ((PyArray_NDIM({ out } ) != 1) || ({ unexpected_shape0 } )) {{
2579+ PyErr_SetString(PyExc_ValueError, "AdvancedIncSubtensor1: fist input (x) does not have right shape or ndim");
2580+ { fail }
2581+ }}
2582+ if (PyArray_NDIM({ idx } ) != 1) {{
2583+ PyErr_SetString(PyExc_ValueError, "AdvancedIncSubtensor1: indices ndim != 1");
2584+ { fail }
2585+ }}
2586+ if ((PyArray_NDIM({ y } ) != 1) || (PyArray_SHAPE({ y } )[0] != PyArray_SHAPE({ idx } )[0])) {{
2587+ if ((PyArray_NDIM({ y } ) == 1) && (PyArray_SHAPE({ y } )[0] == 1)){{
2588+ PyErr_SetString(PyExc_ValueError, "{ self ._runtime_broadcast_error_msg } ");
2589+ }} else {{
2590+ PyErr_SetString(PyExc_ValueError, "AdvancedIncSubtensor1: Shapes of second input (y) and indices do not match");
2591+ }}
2592+ { fail }
2593+ }}
2594+
2595+ {{
2596+ npy_intp out_shape0 = PyArray_SHAPE({ out } )[0];
2597+ { out_dtype } * out_data = ({ out_dtype } *)PyArray_DATA({ out } );
2598+ { y_dtype } * y_data = ({ y_dtype } *)PyArray_DATA({ y } );
2599+ { idx_dtype } * idx_data = ({ idx_dtype } *)PyArray_DATA({ idx } );
2600+ npy_intp n = PyArray_SHAPE({ idx } )[0];
2601+ npy_intp out_jump = PyArray_STRIDES({ out } )[0] / PyArray_ITEMSIZE({ out } );
2602+ npy_intp y_jump = PyArray_STRIDES({ y } )[0] / PyArray_ITEMSIZE({ y } );
2603+ npy_intp idx_jump = PyArray_STRIDES({ idx } )[0] / PyArray_ITEMSIZE({ idx } );
2604+
2605+ for(int i = 0; i < n; i++){{
2606+ { idx_dtype } idx = idx_data[i * idx_jump];
2607+ if ({ int (idx_may_be_neg )} ){{
2608+ if (idx < 0) {{
2609+ idx += out_shape0;
2610+ }}
2611+ }}
2612+ if ({ int (idx_may_be_invalid )} ){{
2613+ if ((idx < 0) || (idx >= out_shape0)) {{
2614+ PyErr_Format(PyExc_IndexError,"index out of bounds");
2615+ { fail }
2616+ }}
2617+ }}
2618+ out_data[idx * out_jump] { op } y_data[i * y_jump];
2619+ }}
2620+
2621+ }}
2622+ """
2623+ return code
2624+
2625+ if numpy_version < "1.8.0" or using_numpy_2 :
2626+ raise NotImplementedError
2627+
25302628 return f"""
25312629 PyObject* rval = NULL;
25322630 if ({ params } ->inplace)
@@ -2554,22 +2652,43 @@ def c_code(self, node, name, input_names, output_names, sub):
25542652 """
25552653
def c_code_cache_version(self):
    """Return the version tuple identifying the C code generated by this Op.

    The value must be bumped whenever the generated C code changes.
    """
    version = (9,)
    return version
2656+
2657+ def _check_runtime_broadcasting (self , node , x , y , idx ):
2658+ if y .ndim > 0 :
2659+ y_pt_bcast = node .inputs [1 ].broadcastable
2660+
2661+ if not y_pt_bcast [0 ] and y .shape [0 ] == 1 and y .shape [0 ] != idx .shape [0 ]:
2662+ # Attempting to broadcast with index
2663+ raise ValueError (self ._runtime_broadcast_error_msg )
2664+ if any (
2665+ not y_bcast and y_dim == 1 and y_dim != x_dim
2666+ for y_bcast , y_dim , x_dim in zip (
2667+ reversed (y_pt_bcast ),
2668+ reversed (y .shape ),
2669+ reversed (x .shape ),
2670+ strict = False ,
2671+ )
2672+ ):
2673+ # Attempting to broadcast with buffer
2674+ raise ValueError (self ._runtime_broadcast_error_msg )
2675+
def perform(self, node, inputs, output_storage):
    """Python implementation: set or increment `x` at the rows given by `idx`.

    `inputs` holds ``(x, y, idx)``.  The result is stored in
    ``output_storage[0][0]``.  Unless the Op is in-place, a copy of `x` is
    updated so that the input buffer is left untouched.
    """
    buffer, update, indices = inputs

    if not self.inplace:
        buffer = buffer.copy()

    self._check_runtime_broadcasting(node, buffer, update, indices)

    if not self.set_instead_of_inc:
        # `buffer[indices] += update` would apply a repeated index only once;
        # `np.add.at` accumulates every occurrence.
        np.add.at(buffer, indices, update)
    else:
        buffer[indices] = update

    (out_cell,) = output_storage
    out_cell[0] = buffer
25732692
25742693 def infer_shape (self , fgraph , node , ishapes ):
25752694 x , y , ilist = ishapes
0 commit comments