@@ -2262,6 +2262,12 @@ class AdvancedIncSubtensor1(COp):
22622262 check_input = False
22632263 params_type = ParamsType (inplace = ps .bool , set_instead_of_inc = ps .bool )
22642264
2265+ _runtime_broadcast_error_msg = (
2266+ "Runtime broadcasting not allowed. "
2267+ "AdvancedIncSubtensor1 was asked to broadcast the second input (y) along a dimension that was not marked as broadcastable. "
2268+ "If broadcasting was intended, use `specify_broadcastable` on the relevant dimension(s)."
2269+ )
2270+
22652271 def __init__ (self , inplace = False , set_instead_of_inc = False ):
22662272 self .inplace = bool (inplace )
22672273 self .set_instead_of_inc = bool (set_instead_of_inc )
@@ -2333,6 +2339,9 @@ def copy_of_x(self, x):
23332339 NPY_ARRAY_ENSURECOPY, NULL)"""
23342340
23352341 def c_support_code (self , ** kwargs ):
2342+ if numpy_version < "1.8.0" or using_numpy_2 :
2343+ return None
2344+
23362345 types = [
23372346 "npy_" + t
23382347 for t in [
@@ -2523,15 +2532,104 @@ def gen_num(typen):
25232532 return code
25242533
25252534 def c_code (self , node , name , input_names , output_names , sub ):
2526- if numpy_version < "1.8.0" or using_numpy_2 :
2527- raise NotImplementedError
2528-
25292535 x , y , idx = input_names
2530- out = output_names [ 0 ]
2536+ [ out ] = output_names
25312537 copy_of_x = self .copy_of_x (x )
25322538 params = sub ["params" ]
25332539 fail = sub ["fail" ]
25342540
2541+ x_ , y_ , idx_ = node .inputs
2542+ y_dtype = y_ .type .dtype_specs ()[1 ]
2543+ idx_dtype = idx_ .type .dtype_specs ()[1 ]
2544+ out_dtype = node .outputs [0 ].type .dtype_specs ()[1 ]
2545+ y_bcast = y_ .type .broadcastable != idx_ .type .broadcastable
2546+ if (
2547+ x_ .type .ndim == 1
2548+ and x_ .type .dtype not in complex_dtypes
2549+ and not y_bcast
2550+ and y_ .type .dtype not in complex_dtypes
2551+ ):
2552+ # Simple implementation for vector x, y cases
2553+ idx_may_be_neg = not (isinstance (idx_ , Constant ) and idx_ .data .min () >= 0 )
2554+ idx_may_be_invalid = AdvancedSubtensor1 ._idx_may_be_invalid (x_ , idx_ )
2555+ shape0 = x_ .type .shape [0 ]
2556+ # This is used to make sure that when we trust the indices to be valid
2557+ # we are not fooled by a wrong static shape
2558+ unexpected_shape0 = (
2559+ f"PyArray_SHAPE({ x } )[0] != { shape0 } " if shape0 is not None else "0"
2560+ )
2561+
2562+ op = "=" if self .set_instead_of_inc else "+="
2563+ code = f"""
2564+ if ({ params } ->inplace)
2565+ {{
2566+ if ({ x } != { out } )
2567+ {{
2568+ Py_XDECREF({ out } );
2569+ Py_INCREF({ x } );
2570+ { out } = { x } ;
2571+ }}
2572+ }}
2573+ else
2574+ {{
2575+ Py_XDECREF({ out } );
2576+ { out } = { copy_of_x } ;
2577+ if (!{ out } ) {{
2578+ // Exception already set
2579+ { fail }
2580+ }}
2581+ }}
2582+
2583+ if ((PyArray_NDIM({ out } ) != 1) || ({ unexpected_shape0 } )) {{
2584+ PyErr_SetString(PyExc_ValueError, "AdvancedIncSubtensor1: first input (x) does not have right shape or ndim");
2585+ { fail }
2586+ }}
2587+ if (PyArray_NDIM({ idx } ) != 1) {{
2588+ PyErr_SetString(PyExc_ValueError, "AdvancedIncSubtensor1: indices ndim != 1");
2589+ { fail }
2590+ }}
2591+ if ((PyArray_NDIM({ y } ) != 1) || (PyArray_SHAPE({ y } )[0] != PyArray_SHAPE({ idx } )[0])) {{
2592+ if ((PyArray_NDIM({ y } ) == 1) && (PyArray_SHAPE({ y } )[0] == 1)){{
2593+ PyErr_SetString(PyExc_ValueError, "{ self ._runtime_broadcast_error_msg } ");
2594+ }} else {{
2595+ PyErr_SetString(PyExc_ValueError, "AdvancedIncSubtensor1: Shapes of second input (y) and indices do not match");
2596+ }}
2597+ { fail }
2598+ }}
2599+
2600+ {{
2601+ npy_intp out_shape0 = PyArray_SHAPE({ out } )[0];
2602+ { out_dtype } * out_data = ({ out_dtype } *)PyArray_DATA({ out } );
2603+ { y_dtype } * y_data = ({ y_dtype } *)PyArray_DATA({ y } );
2604+ { idx_dtype } * idx_data = ({ idx_dtype } *)PyArray_DATA({ idx } );
2605+ npy_intp n = PyArray_SHAPE({ idx } )[0];
2606+ npy_intp out_jump = PyArray_STRIDES({ out } )[0] / PyArray_ITEMSIZE({ out } );
2607+ npy_intp y_jump = PyArray_STRIDES({ y } )[0] / PyArray_ITEMSIZE({ y } );
2608+ npy_intp idx_jump = PyArray_STRIDES({ idx } )[0] / PyArray_ITEMSIZE({ idx } );
2609+
2610+ for(int i = 0; i < n; i++){{
2611+ { idx_dtype } idx = idx_data[i * idx_jump];
2612+ if ({ int (idx_may_be_neg )} ){{
2613+ if (idx < 0) {{
2614+ idx += out_shape0;
2615+ }}
2616+ }}
2617+ if ({ int (idx_may_be_invalid )} ){{
2618+ if ((idx < 0) || (idx >= out_shape0)) {{
2619+ PyErr_Format(PyExc_IndexError,"index %d out of bounds for array with shape %d", idx, out_shape0);
2620+ { fail }
2621+ }}
2622+ }}
2623+ out_data[idx * out_jump] { op } y_data[i * y_jump];
2624+ }}
2625+
2626+ }}
2627+ """
2628+ return code
2629+
2630+ if numpy_version < "1.8.0" or using_numpy_2 :
2631+ raise NotImplementedError
2632+
25352633 return f"""
25362634 PyObject* rval = NULL;
25372635 if ({ params } ->inplace)
@@ -2559,22 +2657,43 @@ def c_code(self, node, name, input_names, output_names, sub):
25592657 """
25602658
25612659 def c_code_cache_version (self ):
2562- return (8 ,)
2660+ return (9 ,)
2661+
2662+ def _check_runtime_broadcasting (self , node , x , y , idx ):
2663+ if y .ndim > 0 :
2664+ y_pt_bcast = node .inputs [1 ].broadcastable
2665+
2666+ if not y_pt_bcast [0 ] and y .shape [0 ] == 1 and y .shape [0 ] != idx .shape [0 ]:
2667+ # Attempting to broadcast with index
2668+ raise ValueError (self ._runtime_broadcast_error_msg )
2669+ if any (
2670+ not y_bcast and y_dim == 1 and y_dim != x_dim
2671+ for y_bcast , y_dim , x_dim in zip (
2672+ reversed (y_pt_bcast ),
2673+ reversed (y .shape ),
2674+ reversed (x .shape ),
2675+ strict = False ,
2676+ )
2677+ ):
2678+ # Attempting to broadcast with buffer
2679+ raise ValueError (self ._runtime_broadcast_error_msg )
2680+
2681+ def perform (self , node , inputs , output_storage ):
2682+ x , y , idx = inputs
25632683
2564- def perform (self , node , inp , out_ ):
2565- x , y , idx = inp
2566- (out ,) = out_
25672684 if not self .inplace :
25682685 x = x .copy ()
25692686
2687+ self ._check_runtime_broadcasting (node , x , y , idx )
2688+
25702689 if self .set_instead_of_inc :
25712690 x [idx ] = y
25722691 else :
25732692 # In Numpy, `x[idx] += y` doesn't work if the same index is present
25742693 # many times: it does it only once.
25752694 np .add .at (x , idx , y )
25762695
2577- out [0 ] = x
2696+ output_storage [ 0 ] [0 ] = x
25782697
25792698 def infer_shape (self , fgraph , node , ishapes ):
25802699 x , y , ilist = ishapes
0 commit comments