@@ -60,7 +60,71 @@ string_equal_resolve_descriptors(PyObject *NPY_UNUSED(self),
60
60
return NPY_SAFE_CASTING ;
61
61
}
62
62
63
- static char * equal_name = "string_equal" ;
63
+ /*
64
+ * Copied from NumPy, because NumPy doesn't always use it :)
65
+ */
66
+ static int
67
+ default_ufunc_promoter (PyUFuncObject * ufunc , PyArray_DTypeMeta * op_dtypes [],
68
+ PyArray_DTypeMeta * signature [],
69
+ PyArray_DTypeMeta * new_op_dtypes [])
70
+ {
71
+ /* If nin < 2 promotion is a no-op, so it should not be registered */
72
+ assert (ufunc -> nin > 1 );
73
+ if (op_dtypes [0 ] == NULL ) {
74
+ assert (ufunc -> nin == 2 && ufunc -> nout == 1 ); /* must be reduction */
75
+ Py_INCREF (op_dtypes [1 ]);
76
+ new_op_dtypes [0 ] = op_dtypes [1 ];
77
+ Py_INCREF (op_dtypes [1 ]);
78
+ new_op_dtypes [1 ] = op_dtypes [1 ];
79
+ Py_INCREF (op_dtypes [1 ]);
80
+ new_op_dtypes [2 ] = op_dtypes [1 ];
81
+ return 0 ;
82
+ }
83
+ PyArray_DTypeMeta * common = NULL ;
84
+ /*
85
+ * If a signature is used and homogeneous in its outputs use that
86
+ * (Could/should likely be rather applied to inputs also, although outs
87
+ * only could have some advantage and input dtypes are rarely enforced.)
88
+ */
89
+ for (int i = ufunc -> nin ; i < ufunc -> nargs ; i ++ ) {
90
+ if (signature [i ] != NULL ) {
91
+ if (common == NULL ) {
92
+ Py_INCREF (signature [i ]);
93
+ common = signature [i ];
94
+ }
95
+ else if (common != signature [i ]) {
96
+ Py_CLEAR (common ); /* Not homogeneous, unset common */
97
+ break ;
98
+ }
99
+ }
100
+ }
101
+ /* Otherwise, use the common DType of all input operands */
102
+ if (common == NULL ) {
103
+ common = PyArray_PromoteDTypeSequence (ufunc -> nin , op_dtypes );
104
+ if (common == NULL ) {
105
+ if (PyErr_ExceptionMatches (PyExc_TypeError )) {
106
+ PyErr_Clear (); /* Do not propagate normal promotion errors */
107
+ }
108
+ return -1 ;
109
+ }
110
+ }
111
+
112
+ for (int i = 0 ; i < ufunc -> nargs ; i ++ ) {
113
+ PyArray_DTypeMeta * tmp = common ;
114
+ if (signature [i ]) {
115
+ tmp = signature [i ]; /* never replace a fixed one. */
116
+ }
117
+ Py_INCREF (tmp );
118
+ new_op_dtypes [i ] = tmp ;
119
+ }
120
+ for (int i = ufunc -> nin ; i < ufunc -> nargs ; i ++ ) {
121
+ Py_XINCREF (op_dtypes [i ]);
122
+ new_op_dtypes [i ] = op_dtypes [i ];
123
+ }
124
+
125
+ Py_DECREF (common );
126
+ return 0 ;
127
+ }
64
128
65
129
int
66
130
init_equal_ufunc (PyObject * numpy )
@@ -71,38 +135,66 @@ init_equal_ufunc(PyObject *numpy)
71
135
}
72
136
73
137
/*
74
- * Initialize spec for equality
138
+ * Initialize spec for equality
75
139
*/
76
- PyArray_DTypeMeta * * eq_dtypes = malloc (3 * sizeof (PyArray_DTypeMeta * ));
77
- eq_dtypes [0 ] = & StringDType ;
78
- eq_dtypes [1 ] = & StringDType ;
79
- eq_dtypes [2 ] = & PyArray_BoolDType ;
140
+ PyArray_DTypeMeta * eq_dtypes [3 ] = {& StringDType , & StringDType ,
141
+ & PyArray_BoolDType };
80
142
81
143
static PyType_Slot eq_slots [] = {
82
144
{NPY_METH_resolve_descriptors , & string_equal_resolve_descriptors },
83
145
{NPY_METH_strided_loop , & string_equal_strided_loop },
84
146
{0 , NULL }};
85
147
86
- PyArrayMethod_Spec * EqualSpec = malloc (sizeof (PyArrayMethod_Spec ));
148
+ PyArrayMethod_Spec EqualSpec = {
149
+ .name = "string_equal" ,
150
+ .nin = 2 ,
151
+ .nout = 1 ,
152
+ .casting = NPY_NO_CASTING ,
153
+ .flags = 0 ,
154
+ .dtypes = eq_dtypes ,
155
+ .slots = eq_slots ,
156
+ };
157
+
158
+ if (PyUFunc_AddLoopFromSpec (equal , & EqualSpec ) < 0 ) {
159
+ Py_DECREF (equal );
160
+ return -1 ;
161
+ }
87
162
88
- EqualSpec -> name = equal_name ;
89
- EqualSpec -> nin = 2 ;
90
- EqualSpec -> nout = 1 ;
91
- EqualSpec -> casting = NPY_SAFE_CASTING ;
92
- EqualSpec -> flags = 0 ;
93
- EqualSpec -> dtypes = eq_dtypes ;
94
- EqualSpec -> slots = eq_slots ;
163
+ /*
164
+ * Add promoter to ufunc, ensures operations that mix StringDType and
165
+ * UnicodeDType cast the unicode argument to string.
166
+ */
95
167
96
- if (PyUFunc_AddLoopFromSpec (equal , EqualSpec ) < 0 ) {
168
+ PyObject * DTypes [] = {
169
+ PyTuple_Pack (3 , & StringDType , & PyArray_UnicodeDType ,
170
+ & PyArray_BoolDType ),
171
+ PyTuple_Pack (3 , & PyArray_UnicodeDType , & StringDType ,
172
+ & PyArray_BoolDType ),
173
+ };
174
+
175
+ if ((DTypes [0 ] == NULL ) || (DTypes [1 ] == NULL )) {
97
176
Py_DECREF (equal );
98
- free (eq_dtypes );
99
- free (EqualSpec );
100
177
return -1 ;
101
178
}
102
179
180
+ PyObject * promoter_capsule = PyCapsule_New ((void * )& default_ufunc_promoter ,
181
+ "numpy._ufunc_promoter" , NULL );
182
+
183
+ for (int i = 0 ; i < 2 ; i ++ ) {
184
+ if (PyUFunc_AddPromoter (equal , DTypes [i ], promoter_capsule ) < 0 ) {
185
+ Py_DECREF (promoter_capsule );
186
+ Py_DECREF (DTypes [0 ]);
187
+ Py_DECREF (DTypes [1 ]);
188
+ Py_DECREF (equal );
189
+ return -1 ;
190
+ }
191
+ }
192
+
193
+ Py_DECREF (promoter_capsule );
194
+ Py_DECREF (DTypes [0 ]);
195
+ Py_DECREF (DTypes [1 ]);
103
196
Py_DECREF (equal );
104
- free (eq_dtypes );
105
- free (EqualSpec );
197
+
106
198
return 0 ;
107
199
}
108
200
0 commit comments