@@ -5865,15 +5865,13 @@ ufunc_at(PyUFuncObject *ufunc, PyObject *args)
58655865 PyArrayObject * op2_array = NULL ;
58665866 PyArrayMapIterObject * iter = NULL ;
58675867 PyArrayIterObject * iter2 = NULL ;
5868- PyArray_Descr * dtypes [3 ] = {NULL , NULL , NULL };
58695868 PyArrayObject * operands [3 ] = {NULL , NULL , NULL };
58705869 PyArrayObject * array_operands [3 ] = {NULL , NULL , NULL };
58715870
5872- int needs_api = 0 ;
5871+ PyArray_DTypeMeta * signature [3 ] = {NULL , NULL , NULL };
5872+ PyArray_DTypeMeta * operand_DTypes [3 ] = {NULL , NULL , NULL };
5873+ PyArray_Descr * operation_descrs [3 ] = {NULL , NULL , NULL };
58735874
5874- PyUFuncGenericFunction innerloop ;
5875- void * innerloopdata ;
5876- npy_intp i ;
58775875 int nop ;
58785876
58795877 /* override vars */
@@ -5886,6 +5884,10 @@ ufunc_at(PyUFuncObject *ufunc, PyObject *args)
58865884 int buffersize ;
58875885 int errormask = 0 ;
58885886 char * err_msg = NULL ;
5887+
5888+ PyArrayMethod_StridedLoop * strided_loop ;
5889+ NpyAuxData * auxdata = NULL ;
5890+
58895891 NPY_BEGIN_THREADS_DEF ;
58905892
58915893 if (ufunc -> nin > 2 ) {
@@ -5973,26 +5975,51 @@ ufunc_at(PyUFuncObject *ufunc, PyObject *args)
59735975
59745976 /*
59755977 * Create dtypes array for either one or two input operands.
5976- * The output operand is set to the first input operand
5978+ * Compare to the logic in `convert_ufunc_arguments`.
5979+ * TODO: It may be good to review some of this behaviour, since the
5980+ * operand array is special (it is written to) similar to reductions.
5981+ * Using unsafe-casting as done here, is likely not desirable.
59775982 */
59785983 operands [0 ] = op1_array ;
5984+ operand_DTypes [0 ] = NPY_DTYPE (PyArray_DESCR (op1_array ));
5985+ Py_INCREF (operand_DTypes [0 ]);
5986+ int force_legacy_promotion = 0 ;
5987+ int allow_legacy_promotion = NPY_DT_is_legacy (operand_DTypes [0 ]);
5988+
59795989 if (op2_array != NULL ) {
59805990 operands [1 ] = op2_array ;
5981- operands [2 ] = op1_array ;
5991+ operand_DTypes [1 ] = NPY_DTYPE (PyArray_DESCR (op2_array ));
5992+ Py_INCREF (operand_DTypes [1 ]);
5993+ allow_legacy_promotion &= NPY_DT_is_legacy (operand_DTypes [1 ]);
5994+ operands [2 ] = operands [0 ];
5995+ operand_DTypes [2 ] = operand_DTypes [0 ];
5996+ Py_INCREF (operand_DTypes [2 ]);
5997+
59825998 nop = 3 ;
5999+ if (allow_legacy_promotion && ((PyArray_NDIM (op1_array ) == 0 )
6000+ != (PyArray_NDIM (op2_array ) == 0 ))) {
6001+ /* both are legacy and only one is 0-D: force legacy */
6002+ force_legacy_promotion = should_use_min_scalar (2 , operands , 0 , NULL );
6003+ }
59836004 }
59846005 else {
5985- operands [1 ] = op1_array ;
6006+ operands [1 ] = operands [0 ];
6007+ operand_DTypes [1 ] = operand_DTypes [0 ];
6008+ Py_INCREF (operand_DTypes [1 ]);
59866009 operands [2 ] = NULL ;
59876010 nop = 2 ;
59886011 }
59896012
5990- if (ufunc -> type_resolver (ufunc , NPY_UNSAFE_CASTING ,
5991- operands , NULL , dtypes ) < 0 ) {
6013+ PyArrayMethodObject * ufuncimpl = promote_and_get_ufuncimpl (ufunc ,
6014+ operands , signature , operand_DTypes ,
6015+ force_legacy_promotion , allow_legacy_promotion );
6016+ if (ufuncimpl == NULL ) {
59926017 goto fail ;
59936018 }
5994- if (ufunc -> legacy_inner_loop_selector (ufunc , dtypes ,
5995- & innerloop , & innerloopdata , & needs_api ) < 0 ) {
6019+
6020+ /* Find the correct descriptors for the operation */
6021+ if (resolve_descriptors (nop , ufunc , ufuncimpl ,
6022+ operands , operation_descrs , signature , NPY_UNSAFE_CASTING ) < 0 ) {
59966023 goto fail ;
59976024 }
59986025
@@ -6053,21 +6080,44 @@ ufunc_at(PyUFuncObject *ufunc, PyObject *args)
60536080 NPY_ITER_GROWINNER |
60546081 NPY_ITER_DELAY_BUFALLOC ,
60556082 NPY_KEEPORDER , NPY_UNSAFE_CASTING ,
6056- op_flags , dtypes ,
6083+ op_flags , operation_descrs ,
60576084 -1 , NULL , NULL , buffersize );
60586085
60596086 if (iter_buffer == NULL ) {
60606087 goto fail ;
60616088 }
60626089
6063- needs_api = needs_api | NpyIter_IterationNeedsAPI (iter_buffer );
6064-
60656090 iternext = NpyIter_GetIterNext (iter_buffer , NULL );
60666091 if (iternext == NULL ) {
60676092 NpyIter_Deallocate (iter_buffer );
60686093 goto fail ;
60696094 }
60706095
6096+ PyArrayMethod_Context context = {
6097+ .caller = (PyObject * )ufunc ,
6098+ .method = ufuncimpl ,
6099+ .descriptors = operation_descrs ,
6100+ };
6101+
6102+ NPY_ARRAYMETHOD_FLAGS flags ;
6103+ /* Use contiguous strides; if there is such a loop it may be faster */
6104+ npy_intp strides [3 ] = {
6105+ operation_descrs [0 ]-> elsize , operation_descrs [1 ]-> elsize , 0 };
6106+ if (nop == 3 ) {
6107+ strides [2 ] = operation_descrs [2 ]-> elsize ;
6108+ }
6109+
6110+ if (ufuncimpl -> get_strided_loop (& context , 1 , 0 , strides ,
6111+ & strided_loop , & auxdata , & flags ) < 0 ) {
6112+ goto fail ;
6113+ }
6114+ int needs_api = (flags & NPY_METH_REQUIRES_PYAPI ) != 0 ;
6115+ needs_api |= NpyIter_IterationNeedsAPI (iter_buffer );
6116+ if (!(flags & NPY_METH_NO_FLOATINGPOINT_ERRORS )) {
6117+ /* Start with the floating-point exception flags cleared */
6118+ npy_clear_floatstatus_barrier ((char * )& iter );
6119+ }
6120+
60716121 if (!needs_api ) {
60726122 NPY_BEGIN_THREADS ;
60736123 }
@@ -6076,14 +6126,13 @@ ufunc_at(PyUFuncObject *ufunc, PyObject *args)
60766126 * Iterate over first and second operands and call ufunc
60776127 * for each pair of inputs
60786128 */
6079- i = iter -> size ;
6080- while ( i > 0 )
6129+ int res = 0 ;
6130+ for ( npy_intp i = iter -> size ; i > 0 ; i -- )
60816131 {
60826132 char * dataptr [3 ];
60836133 char * * buffer_dataptr ;
60846134 /* one element at a time, no stride required but read by innerloop */
6085- npy_intp count [3 ] = {1 , 0xDEADBEEF , 0xDEADBEEF };
6086- npy_intp stride [3 ] = {0xDEADBEEF , 0xDEADBEEF , 0xDEADBEEF };
6135+ npy_intp count = 1 ;
60876136
60886137 /*
60896138 * Set up data pointers for either one or two input operands.
@@ -6102,14 +6151,14 @@ ufunc_at(PyUFuncObject *ufunc, PyObject *args)
61026151 /* Reset NpyIter data pointers which will trigger a buffer copy */
61036152 NpyIter_ResetBasePointers (iter_buffer , dataptr , & err_msg );
61046153 if (err_msg ) {
6154+ res = -1 ;
61056155 break ;
61066156 }
61076157
61086158 buffer_dataptr = NpyIter_GetDataPtrArray (iter_buffer );
61096159
6110- innerloop (buffer_dataptr , count , stride , innerloopdata );
6111-
6112- if (needs_api && PyErr_Occurred ()) {
6160+ res = strided_loop (& context , buffer_dataptr , & count , strides , auxdata );
6161+ if (res != 0 ) {
61136162 break ;
61146163 }
61156164
@@ -6123,32 +6172,35 @@ ufunc_at(PyUFuncObject *ufunc, PyObject *args)
61236172 if (iter2 != NULL ) {
61246173 PyArray_ITER_NEXT (iter2 );
61256174 }
6126-
6127- i -- ;
61286175 }
61296176
61306177 NPY_END_THREADS ;
61316178
6132- if (err_msg ) {
6179+ if (res != 0 && err_msg ) {
61336180 PyErr_SetString (PyExc_ValueError , err_msg );
61346181 }
6182+ if (res == 0 && !(flags & NPY_METH_NO_FLOATINGPOINT_ERRORS )) {
6183+ /* NOTE: We could check float errors even when `res < 0` */
6184+ res = _check_ufunc_fperr (errormask , NULL , "at" );
6185+ }
61356186
6187+ NPY_AUXDATA_FREE (auxdata );
61366188 NpyIter_Deallocate (iter_buffer );
61376189
61386190 Py_XDECREF (op2_array );
61396191 Py_XDECREF (iter );
61406192 Py_XDECREF (iter2 );
6141- for (i = 0 ; i < 3 ; i ++ ) {
6142- Py_XDECREF (dtypes [i ]);
6193+ for (int i = 0 ; i < 3 ; i ++ ) {
6194+ Py_XDECREF (operation_descrs [i ]);
61436195 Py_XDECREF (array_operands [i ]);
61446196 }
61456197
61466198 /*
6147- * An error should only be possible if needs_api is true, but this is not
6148- * strictly correct for old-style ufuncs (e.g. `power` released the GIL
6149- * but manually set an Exception).
6199+ * An error should only be possible if needs_api is true or `res != 0`,
6200+ * but this is not strictly correct for old-style ufuncs
6201+ * (e.g. `power` released the GIL but manually set an Exception).
61506202 */
6151- if (PyErr_Occurred ()) {
6203+ if (res != 0 || PyErr_Occurred ()) {
61526204 return NULL ;
61536205 }
61546206 else {
@@ -6163,10 +6215,11 @@ ufunc_at(PyUFuncObject *ufunc, PyObject *args)
61636215 Py_XDECREF (op2_array );
61646216 Py_XDECREF (iter );
61656217 Py_XDECREF (iter2 );
6166- for (i = 0 ; i < 3 ; i ++ ) {
6167- Py_XDECREF (dtypes [i ]);
6218+ for (int i = 0 ; i < 3 ; i ++ ) {
6219+ Py_XDECREF (operation_descrs [i ]);
61686220 Py_XDECREF (array_operands [i ]);
61696221 }
6222+ NPY_AUXDATA_FREE (auxdata );
61706223
61716224 return NULL ;
61726225}
0 commit comments