@@ -5880,15 +5880,13 @@ ufunc_at(PyUFuncObject *ufunc, PyObject *args)
58805880 PyArrayObject * op2_array = NULL ;
58815881 PyArrayMapIterObject * iter = NULL ;
58825882 PyArrayIterObject * iter2 = NULL ;
5883- PyArray_Descr * dtypes [3 ] = {NULL , NULL , NULL };
58845883 PyArrayObject * operands [3 ] = {NULL , NULL , NULL };
58855884 PyArrayObject * array_operands [3 ] = {NULL , NULL , NULL };
58865885
5887- int needs_api = 0 ;
5886+ PyArray_DTypeMeta * signature [3 ] = {NULL , NULL , NULL };
5887+ PyArray_DTypeMeta * operand_DTypes [3 ] = {NULL , NULL , NULL };
5888+ PyArray_Descr * operation_descrs [3 ] = {NULL , NULL , NULL };
58885889
5889- PyUFuncGenericFunction innerloop ;
5890- void * innerloopdata ;
5891- npy_intp i ;
58925890 int nop ;
58935891
58945892 /* override vars */
@@ -5901,6 +5899,10 @@ ufunc_at(PyUFuncObject *ufunc, PyObject *args)
59015899 int buffersize ;
59025900 int errormask = 0 ;
59035901 char * err_msg = NULL ;
5902+
5903+ PyArrayMethod_StridedLoop * strided_loop ;
5904+ NpyAuxData * auxdata = NULL ;
5905+
59045906 NPY_BEGIN_THREADS_DEF ;
59055907
59065908 if (ufunc -> nin > 2 ) {
@@ -5988,26 +5990,51 @@ ufunc_at(PyUFuncObject *ufunc, PyObject *args)
59885990
59895991 /*
59905992 * Create dtypes array for either one or two input operands.
5991- * The output operand is set to the first input operand
5993+ * Compare to the logic in `convert_ufunc_arguments`.
5994+ * TODO: It may be good to review some of this behaviour, since the
5995+ * operand array is special (it is written to) similar to reductions.
5996+ * Using unsafe-casting as done here, is likely not desirable.
59925997 */
59935998 operands [0 ] = op1_array ;
5999+ operand_DTypes [0 ] = NPY_DTYPE (PyArray_DESCR (op1_array ));
6000+ Py_INCREF (operand_DTypes [0 ]);
6001+ int force_legacy_promotion = 0 ;
6002+ int allow_legacy_promotion = NPY_DT_is_legacy (operand_DTypes [0 ]);
6003+
59946004 if (op2_array != NULL ) {
59956005 operands [1 ] = op2_array ;
5996- operands [2 ] = op1_array ;
6006+ operand_DTypes [1 ] = NPY_DTYPE (PyArray_DESCR (op2_array ));
6007+ Py_INCREF (operand_DTypes [1 ]);
6008+ allow_legacy_promotion &= NPY_DT_is_legacy (operand_DTypes [1 ]);
6009+ operands [2 ] = operands [0 ];
6010+ operand_DTypes [2 ] = operand_DTypes [0 ];
6011+ Py_INCREF (operand_DTypes [2 ]);
6012+
59976013 nop = 3 ;
6014+ if (allow_legacy_promotion && ((PyArray_NDIM (op1_array ) == 0 )
6015+ != (PyArray_NDIM (op2_array ) == 0 ))) {
6016+ /* both are legacy and only one is 0-D: force legacy */
6017+ force_legacy_promotion = should_use_min_scalar (2 , operands , 0 , NULL );
6018+ }
59986019 }
59996020 else {
6000- operands [1 ] = op1_array ;
6021+ operands [1 ] = operands [0 ];
6022+ operand_DTypes [1 ] = operand_DTypes [0 ];
6023+ Py_INCREF (operand_DTypes [1 ]);
60016024 operands [2 ] = NULL ;
60026025 nop = 2 ;
60036026 }
60046027
6005- if (ufunc -> type_resolver (ufunc , NPY_UNSAFE_CASTING ,
6006- operands , NULL , dtypes ) < 0 ) {
6028+ PyArrayMethodObject * ufuncimpl = promote_and_get_ufuncimpl (ufunc ,
6029+ operands , signature , operand_DTypes ,
6030+ force_legacy_promotion , allow_legacy_promotion );
6031+ if (ufuncimpl == NULL ) {
60076032 goto fail ;
60086033 }
6009- if (ufunc -> legacy_inner_loop_selector (ufunc , dtypes ,
6010- & innerloop , & innerloopdata , & needs_api ) < 0 ) {
6034+
6035+ /* Find the correct descriptors for the operation */
6036+ if (resolve_descriptors (nop , ufunc , ufuncimpl ,
6037+ operands , operation_descrs , signature , NPY_UNSAFE_CASTING ) < 0 ) {
60116038 goto fail ;
60126039 }
60136040
@@ -6068,21 +6095,44 @@ ufunc_at(PyUFuncObject *ufunc, PyObject *args)
60686095 NPY_ITER_GROWINNER |
60696096 NPY_ITER_DELAY_BUFALLOC ,
60706097 NPY_KEEPORDER , NPY_UNSAFE_CASTING ,
6071- op_flags , dtypes ,
6098+ op_flags , operation_descrs ,
60726099 -1 , NULL , NULL , buffersize );
60736100
60746101 if (iter_buffer == NULL ) {
60756102 goto fail ;
60766103 }
60776104
6078- needs_api = needs_api | NpyIter_IterationNeedsAPI (iter_buffer );
6079-
60806105 iternext = NpyIter_GetIterNext (iter_buffer , NULL );
60816106 if (iternext == NULL ) {
60826107 NpyIter_Deallocate (iter_buffer );
60836108 goto fail ;
60846109 }
60856110
6111+ PyArrayMethod_Context context = {
6112+ .caller = (PyObject * )ufunc ,
6113+ .method = ufuncimpl ,
6114+ .descriptors = operation_descrs ,
6115+ };
6116+
6117+ NPY_ARRAYMETHOD_FLAGS flags ;
6118+ /* Use contiguous strides; if there is such a loop it may be faster */
6119+ npy_intp strides [3 ] = {
6120+ operation_descrs [0 ]-> elsize , operation_descrs [1 ]-> elsize , 0 };
6121+ if (nop == 3 ) {
6122+ strides [2 ] = operation_descrs [2 ]-> elsize ;
6123+ }
6124+
6125+ if (ufuncimpl -> get_strided_loop (& context , 1 , 0 , strides ,
6126+ & strided_loop , & auxdata , & flags ) < 0 ) {
6127+ goto fail ;
6128+ }
6129+ int needs_api = (flags & NPY_METH_REQUIRES_PYAPI ) != 0 ;
6130+ needs_api |= NpyIter_IterationNeedsAPI (iter_buffer );
6131+ if (!(flags & NPY_METH_NO_FLOATINGPOINT_ERRORS )) {
6132+ /* Start with the floating-point exception flags cleared */
6133+ npy_clear_floatstatus_barrier ((char * )& iter );
6134+ }
6135+
60866136 if (!needs_api ) {
60876137 NPY_BEGIN_THREADS ;
60886138 }
@@ -6091,14 +6141,13 @@ ufunc_at(PyUFuncObject *ufunc, PyObject *args)
60916141 * Iterate over first and second operands and call ufunc
60926142 * for each pair of inputs
60936143 */
6094- i = iter -> size ;
6095- while ( i > 0 )
6144+ int res = 0 ;
6145+ for ( npy_intp i = iter -> size ; i > 0 ; i -- )
60966146 {
60976147 char * dataptr [3 ];
60986148 char * * buffer_dataptr ;
60996149 /* one element at a time, no stride required but read by innerloop */
6100- npy_intp count [3 ] = {1 , 0xDEADBEEF , 0xDEADBEEF };
6101- npy_intp stride [3 ] = {0xDEADBEEF , 0xDEADBEEF , 0xDEADBEEF };
6150+ npy_intp count = 1 ;
61026151
61036152 /*
61046153 * Set up data pointers for either one or two input operands.
@@ -6117,14 +6166,14 @@ ufunc_at(PyUFuncObject *ufunc, PyObject *args)
61176166 /* Reset NpyIter data pointers which will trigger a buffer copy */
61186167 NpyIter_ResetBasePointers (iter_buffer , dataptr , & err_msg );
61196168 if (err_msg ) {
6169+ res = -1 ;
61206170 break ;
61216171 }
61226172
61236173 buffer_dataptr = NpyIter_GetDataPtrArray (iter_buffer );
61246174
6125- innerloop (buffer_dataptr , count , stride , innerloopdata );
6126-
6127- if (needs_api && PyErr_Occurred ()) {
6175+ res = strided_loop (& context , buffer_dataptr , & count , strides , auxdata );
6176+ if (res != 0 ) {
61286177 break ;
61296178 }
61306179
@@ -6138,32 +6187,35 @@ ufunc_at(PyUFuncObject *ufunc, PyObject *args)
61386187 if (iter2 != NULL ) {
61396188 PyArray_ITER_NEXT (iter2 );
61406189 }
6141-
6142- i -- ;
61436190 }
61446191
61456192 NPY_END_THREADS ;
61466193
6147- if (err_msg ) {
6194+ if (res != 0 && err_msg ) {
61486195 PyErr_SetString (PyExc_ValueError , err_msg );
61496196 }
6197+ if (res == 0 && !(flags & NPY_METH_NO_FLOATINGPOINT_ERRORS )) {
6198+ /* NOTE: We could check float errors even when `res < 0` */
6199+ res = _check_ufunc_fperr (errormask , NULL , "at" );
6200+ }
61506201
6202+ NPY_AUXDATA_FREE (auxdata );
61516203 NpyIter_Deallocate (iter_buffer );
61526204
61536205 Py_XDECREF (op2_array );
61546206 Py_XDECREF (iter );
61556207 Py_XDECREF (iter2 );
6156- for (i = 0 ; i < 3 ; i ++ ) {
6157- Py_XDECREF (dtypes [i ]);
6208+ for (int i = 0 ; i < 3 ; i ++ ) {
6209+ Py_XDECREF (operation_descrs [i ]);
61586210 Py_XDECREF (array_operands [i ]);
61596211 }
61606212
61616213 /*
6162- * An error should only be possible if needs_api is true, but this is not
6163- * strictly correct for old-style ufuncs (e.g. `power` released the GIL
6164- * but manually set an Exception).
6214+ * An error should only be possible if needs_api is true or `res != 0`,
6215+ * but this is not strictly correct for old-style ufuncs
6216+ * (e.g. `power` released the GIL but manually set an Exception).
61656217 */
6166- if (PyErr_Occurred ()) {
6218+ if (res != 0 || PyErr_Occurred ()) {
61676219 return NULL ;
61686220 }
61696221 else {
@@ -6178,10 +6230,11 @@ ufunc_at(PyUFuncObject *ufunc, PyObject *args)
61786230 Py_XDECREF (op2_array );
61796231 Py_XDECREF (iter );
61806232 Py_XDECREF (iter2 );
6181- for (i = 0 ; i < 3 ; i ++ ) {
6182- Py_XDECREF (dtypes [i ]);
6233+ for (int i = 0 ; i < 3 ; i ++ ) {
6234+ Py_XDECREF (operation_descrs [i ]);
61836235 Py_XDECREF (array_operands [i ]);
61846236 }
6237+ NPY_AUXDATA_FREE (auxdata );
61856238
61866239 return NULL ;
61876240}
0 commit comments