@@ -60,7 +60,71 @@ string_equal_resolve_descriptors(PyObject *NPY_UNUSED(self),
60
60
return NPY_SAFE_CASTING ;
61
61
}
62
62
63
- static char * equal_name = "string_equal" ;
63
+ /*
64
+ * Copied from NumPy, because NumPy doesn't always use it :)
65
+ */
66
+ static int
67
+ default_ufunc_promoter (PyUFuncObject * ufunc , PyArray_DTypeMeta * op_dtypes [],
68
+ PyArray_DTypeMeta * signature [],
69
+ PyArray_DTypeMeta * new_op_dtypes [])
70
+ {
71
+ /* If nin < 2 promotion is a no-op, so it should not be registered */
72
+ assert (ufunc -> nin > 1 );
73
+ if (op_dtypes [0 ] == NULL ) {
74
+ assert (ufunc -> nin == 2 && ufunc -> nout == 1 ); /* must be reduction */
75
+ Py_INCREF (op_dtypes [1 ]);
76
+ new_op_dtypes [0 ] = op_dtypes [1 ];
77
+ Py_INCREF (op_dtypes [1 ]);
78
+ new_op_dtypes [1 ] = op_dtypes [1 ];
79
+ Py_INCREF (op_dtypes [1 ]);
80
+ new_op_dtypes [2 ] = op_dtypes [1 ];
81
+ return 0 ;
82
+ }
83
+ PyArray_DTypeMeta * common = NULL ;
84
+ /*
85
+ * If a signature is used and homogeneous in its outputs use that
86
+ * (Could/should likely be rather applied to inputs also, although outs
87
+ * only could have some advantage and input dtypes are rarely enforced.)
88
+ */
89
+ for (int i = ufunc -> nin ; i < ufunc -> nargs ; i ++ ) {
90
+ if (signature [i ] != NULL ) {
91
+ if (common == NULL ) {
92
+ Py_INCREF (signature [i ]);
93
+ common = signature [i ];
94
+ }
95
+ else if (common != signature [i ]) {
96
+ Py_CLEAR (common ); /* Not homogeneous, unset common */
97
+ break ;
98
+ }
99
+ }
100
+ }
101
+ /* Otherwise, use the common DType of all input operands */
102
+ if (common == NULL ) {
103
+ common = PyArray_PromoteDTypeSequence (ufunc -> nin , op_dtypes );
104
+ if (common == NULL ) {
105
+ if (PyErr_ExceptionMatches (PyExc_TypeError )) {
106
+ PyErr_Clear (); /* Do not propagate normal promotion errors */
107
+ }
108
+ return -1 ;
109
+ }
110
+ }
111
+
112
+ for (int i = 0 ; i < ufunc -> nargs ; i ++ ) {
113
+ PyArray_DTypeMeta * tmp = common ;
114
+ if (signature [i ]) {
115
+ tmp = signature [i ]; /* never replace a fixed one. */
116
+ }
117
+ Py_INCREF (tmp );
118
+ new_op_dtypes [i ] = tmp ;
119
+ }
120
+ for (int i = ufunc -> nin ; i < ufunc -> nargs ; i ++ ) {
121
+ Py_XINCREF (op_dtypes [i ]);
122
+ new_op_dtypes [i ] = op_dtypes [i ];
123
+ }
124
+
125
+ Py_DECREF (common );
126
+ return 0 ;
127
+ }
64
128
65
129
int
66
130
init_equal_ufunc (PyObject * numpy )
@@ -73,36 +137,74 @@ init_equal_ufunc(PyObject *numpy)
73
137
/*
74
138
* Initialize spec for equality
75
139
*/
76
- PyArray_DTypeMeta * * eq_dtypes = malloc (3 * sizeof (PyArray_DTypeMeta * ));
77
- eq_dtypes [0 ] = & StringDType ;
78
- eq_dtypes [1 ] = & StringDType ;
79
- eq_dtypes [2 ] = & PyArray_BoolDType ;
140
+ PyArray_DTypeMeta * eq_dtypes [3 ] = {& StringDType , & StringDType ,
141
+ & PyArray_BoolDType };
80
142
81
143
static PyType_Slot eq_slots [] = {
82
144
{NPY_METH_resolve_descriptors , & string_equal_resolve_descriptors },
83
145
{NPY_METH_strided_loop , & string_equal_strided_loop },
84
146
{0 , NULL }};
85
147
86
- PyArrayMethod_Spec * EqualSpec = malloc (sizeof (PyArrayMethod_Spec ));
148
+ PyArrayMethod_Spec EqualSpec = {
149
+ .name = "string_equal" ,
150
+ .nin = 2 ,
151
+ .nout = 1 ,
152
+ .casting = NPY_NO_CASTING ,
153
+ .flags = 0 ,
154
+ .dtypes = eq_dtypes ,
155
+ .slots = eq_slots ,
156
+ };
157
+
158
+ if (PyUFunc_AddLoopFromSpec (equal , & EqualSpec ) < 0 ) {
159
+ Py_DECREF (equal );
160
+ return -1 ;
161
+ }
87
162
88
- EqualSpec -> name = equal_name ;
89
- EqualSpec -> nin = 2 ;
90
- EqualSpec -> nout = 1 ;
91
- EqualSpec -> casting = NPY_SAFE_CASTING ;
92
- EqualSpec -> flags = 0 ;
93
- EqualSpec -> dtypes = eq_dtypes ;
94
- EqualSpec -> slots = eq_slots ;
163
+ /*
164
+ * This might interfere with NumPy at this time.
165
+ */
166
+ PyObject * promoter_capsule1 = PyCapsule_New (
167
+ (void * )& default_ufunc_promoter , "numpy._ufunc_promoter" , NULL );
168
+ if (promoter_capsule1 == NULL ) {
169
+ return -1 ;
170
+ }
95
171
96
- if (PyUFunc_AddLoopFromSpec (equal , EqualSpec ) < 0 ) {
97
- Py_DECREF (equal );
98
- free (eq_dtypes );
99
- free (EqualSpec );
172
+ PyObject * DTypes1 = PyTuple_Pack (3 , & StringDType , & PyArray_UnicodeDType ,
173
+ & PyArrayDescr_Type );
174
+ if (DTypes1 == 0 ) {
175
+ Py_DECREF (promoter_capsule1 );
176
+ return -1 ;
177
+ }
178
+
179
+ if (PyUFunc_AddPromoter (equal , DTypes1 , promoter_capsule1 ) < 0 ) {
180
+ Py_DECREF (promoter_capsule1 );
181
+ Py_DECREF (DTypes1 );
182
+ return -1 ;
183
+ }
184
+ Py_DECREF (promoter_capsule1 );
185
+ Py_DECREF (DTypes1 );
186
+
187
+ PyObject * promoter_capsule2 = PyCapsule_New (
188
+ (void * )& default_ufunc_promoter , "numpy._ufunc_promoter" , NULL );
189
+ if (promoter_capsule2 == NULL ) {
190
+ return -1 ;
191
+ }
192
+ PyObject * DTypes2 = PyTuple_Pack (3 , & PyArray_UnicodeDType , & StringDType ,
193
+ & PyArrayDescr_Type );
194
+ if (DTypes2 == 0 ) {
195
+ Py_DECREF (promoter_capsule2 );
196
+ return -1 ;
197
+ }
198
+
199
+ if (PyUFunc_AddPromoter (equal , DTypes2 , promoter_capsule2 ) < 0 ) {
200
+ Py_DECREF (promoter_capsule2 );
201
+ Py_DECREF (DTypes2 );
100
202
return -1 ;
101
203
}
204
+ Py_DECREF (promoter_capsule2 );
205
+ Py_DECREF (DTypes2 );
102
206
103
207
Py_DECREF (equal );
104
- free (eq_dtypes );
105
- free (EqualSpec );
106
208
return 0 ;
107
209
}
108
210
0 commit comments