@@ -164,8 +164,47 @@ string_equal_strided_loop(PyArrayMethod_Context *NPY_UNUSED(context),
164
164
return 0 ;
165
165
}
166
166
167
+ static int
168
+ string_not_equal_strided_loop (PyArrayMethod_Context * NPY_UNUSED (context ),
169
+ char * const data [], npy_intp const dimensions [],
170
+ npy_intp const strides [],
171
+ NpyAuxData * NPY_UNUSED (auxdata ))
172
+ {
173
+ npy_intp N = dimensions [0 ];
174
+ char * in1 = data [0 ];
175
+ char * in2 = data [1 ];
176
+ npy_bool * out = (npy_bool * )data [2 ];
177
+ npy_intp in1_stride = strides [0 ];
178
+ npy_intp in2_stride = strides [1 ];
179
+ npy_intp out_stride = strides [2 ];
180
+
181
+ ss * s1 = NULL , * s2 = NULL ;
182
+
183
+ while (N -- ) {
184
+ s1 = (ss * )in1 ;
185
+ s2 = (ss * )in2 ;
186
+ if (ss_isnull (s1 ) || ss_isnull (s2 )) {
187
+ // s1 or s2 is NA
188
+ * out = (npy_bool )0 ;
189
+ }
190
+ else if (s1 -> len == s2 -> len &&
191
+ strncmp (s1 -> buf , s2 -> buf , s1 -> len ) == 0 ) {
192
+ * out = (npy_bool )0 ;
193
+ }
194
+ else {
195
+ * out = (npy_bool )1 ;
196
+ }
197
+
198
+ in1 += in1_stride ;
199
+ in2 += in2_stride ;
200
+ out += out_stride ;
201
+ }
202
+
203
+ return 0 ;
204
+ }
205
+
167
206
static NPY_CASTING
168
- string_equal_resolve_descriptors (
207
+ string_comparison_resolve_descriptors (
169
208
struct PyArrayMethodObject_tag * NPY_UNUSED (method ),
170
209
PyArray_DTypeMeta * NPY_UNUSED (dtypes []), PyArray_Descr * given_descrs [],
171
210
PyArray_Descr * loop_descrs [], npy_intp * NPY_UNUSED (view_offset ))
@@ -227,9 +266,10 @@ string_isnan_resolve_descriptors(
227
266
* Copied from NumPy, because NumPy doesn't always use it :)
228
267
*/
229
268
static int
230
- default_ufunc_promoter (PyUFuncObject * ufunc , PyArray_DTypeMeta * op_dtypes [],
231
- PyArray_DTypeMeta * signature [],
232
- PyArray_DTypeMeta * new_op_dtypes [])
269
+ ufunc_promoter_internal (PyUFuncObject * ufunc , PyArray_DTypeMeta * op_dtypes [],
270
+ PyArray_DTypeMeta * signature [],
271
+ PyArray_DTypeMeta * new_op_dtypes [],
272
+ PyArray_DTypeMeta * final_dtype )
233
273
{
234
274
/* If nin < 2 promotion is a no-op, so it should not be registered */
235
275
assert (ufunc -> nin > 1 );
@@ -261,19 +301,11 @@ default_ufunc_promoter(PyUFuncObject *ufunc, PyArray_DTypeMeta *op_dtypes[],
261
301
}
262
302
}
263
303
}
264
- /* Otherwise, use the common DType of all input operands */
265
- if (common == NULL ) {
266
- common = PyArray_PromoteDTypeSequence (ufunc -> nin , op_dtypes );
267
- if (common == NULL ) {
268
- if (PyErr_ExceptionMatches (PyExc_TypeError )) {
269
- PyErr_Clear (); /* Do not propagate normal promotion errors */
270
- }
271
- return -1 ;
272
- }
273
- }
304
+ Py_XDECREF (common );
274
305
306
+ /* Otherwise, set all input operands to StringDType */
275
307
for (int i = 0 ; i < ufunc -> nargs ; i ++ ) {
276
- PyArray_DTypeMeta * tmp = common ;
308
+ PyArray_DTypeMeta * tmp = final_dtype ;
277
309
if (signature [i ]) {
278
310
tmp = signature [i ]; /* never replace a fixed one. */
279
311
}
@@ -285,10 +317,27 @@ default_ufunc_promoter(PyUFuncObject *ufunc, PyArray_DTypeMeta *op_dtypes[],
285
317
new_op_dtypes [i ] = op_dtypes [i ];
286
318
}
287
319
288
- Py_DECREF (common );
289
320
return 0 ;
290
321
}
291
322
323
+ static int
324
+ string_ufunc_promoter (PyUFuncObject * ufunc , PyArray_DTypeMeta * op_dtypes [],
325
+ PyArray_DTypeMeta * signature [],
326
+ PyArray_DTypeMeta * new_op_dtypes [])
327
+ {
328
+ return ufunc_promoter_internal (ufunc , op_dtypes , signature , new_op_dtypes ,
329
+ (PyArray_DTypeMeta * )& StringDType );
330
+ }
331
+
332
+ static int
333
+ pandas_string_ufunc_promoter (PyUFuncObject * ufunc , PyArray_DTypeMeta * op_dtypes [],
334
+ PyArray_DTypeMeta * signature [],
335
+ PyArray_DTypeMeta * new_op_dtypes [])
336
+ {
337
+ return ufunc_promoter_internal (ufunc , op_dtypes , signature , new_op_dtypes ,
338
+ (PyArray_DTypeMeta * )& PandasStringDType );
339
+ }
340
+
292
341
// Register a ufunc.
293
342
//
294
343
// Pass NULL for resolve_func to use the default_resolve_descriptors.
@@ -334,23 +383,33 @@ init_ufunc(PyObject *numpy, const char *ufunc_name, PyArray_DTypeMeta **dtypes,
334
383
335
384
int
336
385
add_promoter (PyObject * numpy , const char * ufunc_name ,
337
- PyArray_DTypeMeta * * dtypes )
386
+ PyArray_DTypeMeta * ldtype , PyArray_DTypeMeta * rdtype ,
387
+ PyArray_DTypeMeta * edtype , int is_pandas )
338
388
{
339
389
PyObject * ufunc = PyObject_GetAttrString (numpy , ufunc_name );
340
390
341
391
if (ufunc == NULL ) {
342
392
return -1 ;
343
393
}
344
394
345
- PyObject * DType_tuple = PyTuple_Pack (3 , dtypes [ 0 ], dtypes [ 1 ], dtypes [ 2 ] );
395
+ PyObject * DType_tuple = PyTuple_Pack (3 , ldtype , rdtype , edtype );
346
396
347
397
if (DType_tuple == NULL ) {
348
398
Py_DECREF (ufunc );
349
399
return -1 ;
350
400
}
351
401
352
- PyObject * promoter_capsule = PyCapsule_New ((void * )& default_ufunc_promoter ,
353
- "numpy._ufunc_promoter" , NULL );
402
+ PyObject * promoter_capsule = NULL ;
403
+
404
+ if (is_pandas == 0 ) {
405
+ promoter_capsule = PyCapsule_New ((void * )& string_ufunc_promoter ,
406
+ "numpy._ufunc_promoter" , NULL );
407
+ }
408
+ else {
409
+ promoter_capsule = PyCapsule_New ((void * )& pandas_string_ufunc_promoter ,
410
+ "numpy._ufunc_promoter" , NULL );
411
+ }
412
+
354
413
355
414
if (promoter_capsule == NULL ) {
356
415
Py_DECREF (ufunc );
@@ -380,30 +439,46 @@ init_ufuncs(void)
380
439
return -1 ;
381
440
}
382
441
383
- PyArray_DTypeMeta * eq_dtypes [] = {(PyArray_DTypeMeta * )& StringDType ,
384
- (PyArray_DTypeMeta * )& StringDType ,
385
- & PyArray_BoolDType };
442
+ PyArray_DTypeMeta * comparison_dtypes [] = {(PyArray_DTypeMeta * )& StringDType ,
443
+ (PyArray_DTypeMeta * )& StringDType ,
444
+ & PyArray_BoolDType };
386
445
387
- if (init_ufunc (numpy , "equal" , eq_dtypes ,
388
- & string_equal_resolve_descriptors ,
446
+ if (init_ufunc (numpy , "equal" , comparison_dtypes ,
447
+ & string_comparison_resolve_descriptors ,
389
448
& string_equal_strided_loop , "string_equal" , 2 , 1 ,
390
449
NPY_NO_CASTING , 0 ) < 0 ) {
391
450
goto error ;
392
451
}
393
452
394
- PyArray_DTypeMeta * promoter_dtypes [2 ][3 ] = {
395
- {(PyArray_DTypeMeta * )& StringDType , & PyArray_UnicodeDType ,
396
- & PyArray_BoolDType },
397
- {& PyArray_UnicodeDType , (PyArray_DTypeMeta * )& StringDType ,
398
- & PyArray_BoolDType },
399
- };
400
-
401
- if (add_promoter (numpy , "equal" , promoter_dtypes [0 ]) < 0 ) {
453
+ if (init_ufunc (numpy , "not_equal" , comparison_dtypes ,
454
+ & string_comparison_resolve_descriptors ,
455
+ & string_not_equal_strided_loop , "string_not_equal" , 2 , 1 ,
456
+ NPY_NO_CASTING , 0 ) < 0 ) {
402
457
goto error ;
403
458
}
404
459
405
- if (add_promoter (numpy , "equal" , promoter_dtypes [1 ]) < 0 ) {
406
- goto error ;
460
+ char * ufunc_names [2 ] = {"equal" , "not_equal" };
461
+
462
+ for (int i = 0 ; i < 2 ; i ++ ) {
463
+ if (add_promoter (numpy , ufunc_names [i ], (PyArray_DTypeMeta * )& StringDType ,
464
+ & PyArray_UnicodeDType , & PyArray_BoolDType , 0 ) < 0 ) {
465
+ goto error ;
466
+ }
467
+
468
+ if (add_promoter (numpy , ufunc_names [i ], & PyArray_UnicodeDType ,
469
+ (PyArray_DTypeMeta * )& StringDType , & PyArray_BoolDType , 0 ) < 0 ) {
470
+ goto error ;
471
+ }
472
+
473
+ if (add_promoter (numpy , ufunc_names [i ], & PyArray_ObjectDType ,
474
+ (PyArray_DTypeMeta * )& StringDType , & PyArray_BoolDType , 0 ) < 0 ) {
475
+ goto error ;
476
+ }
477
+
478
+ if (add_promoter (numpy , ufunc_names [i ], (PyArray_DTypeMeta * )& StringDType ,
479
+ & PyArray_ObjectDType , & PyArray_BoolDType , 0 ) < 0 ) {
480
+ goto error ;
481
+ }
407
482
}
408
483
409
484
PyArray_DTypeMeta * isnan_dtypes [] = {(PyArray_DTypeMeta * )& StringDType ,
@@ -448,30 +523,45 @@ init_ufuncs(void)
448
523
goto finish ;
449
524
}
450
525
451
- PyArray_DTypeMeta * peq_dtypes [] = {(PyArray_DTypeMeta * )& PandasStringDType ,
452
- (PyArray_DTypeMeta * )& PandasStringDType ,
453
- & PyArray_BoolDType };
526
+ PyArray_DTypeMeta * p_comparison_dtypes [] =
527
+ {(PyArray_DTypeMeta * )& PandasStringDType ,
528
+ (PyArray_DTypeMeta * )& PandasStringDType ,
529
+ & PyArray_BoolDType };
454
530
455
- if (init_ufunc (numpy , "equal" , peq_dtypes ,
456
- & string_equal_resolve_descriptors ,
531
+ if (init_ufunc (numpy , "equal" , p_comparison_dtypes ,
532
+ & string_comparison_resolve_descriptors ,
457
533
& string_equal_strided_loop , "string_equal" , 2 , 1 ,
458
534
NPY_NO_CASTING , 0 ) < 0 ) {
459
535
goto error ;
460
536
}
461
537
462
- PyArray_DTypeMeta * p_promoter_dtypes [2 ][3 ] = {
463
- {(PyArray_DTypeMeta * )& PandasStringDType , & PyArray_UnicodeDType ,
464
- & PyArray_BoolDType },
465
- {& PyArray_UnicodeDType , (PyArray_DTypeMeta * )& PandasStringDType ,
466
- & PyArray_BoolDType },
467
- };
468
-
469
- if (add_promoter (numpy , "equal" , p_promoter_dtypes [0 ]) < 0 ) {
538
+ if (init_ufunc (numpy , "not_equal" , p_comparison_dtypes ,
539
+ & string_comparison_resolve_descriptors ,
540
+ & string_not_equal_strided_loop , "string_not_equal" , 2 , 1 ,
541
+ NPY_NO_CASTING , 0 ) < 0 ) {
470
542
goto error ;
471
543
}
472
544
473
- if (add_promoter (numpy , "equal" , p_promoter_dtypes [1 ]) < 0 ) {
474
- goto error ;
545
+ for (int i = 0 ; i < 2 ; i ++ ) {
546
+ if (add_promoter (numpy , ufunc_names [i ], (PyArray_DTypeMeta * )& PandasStringDType ,
547
+ & PyArray_UnicodeDType , & PyArray_BoolDType , 1 ) < 0 ) {
548
+ goto error ;
549
+ }
550
+
551
+ if (add_promoter (numpy , ufunc_names [i ], & PyArray_UnicodeDType ,
552
+ (PyArray_DTypeMeta * )& PandasStringDType , & PyArray_BoolDType , 1 ) < 0 ) {
553
+ goto error ;
554
+ }
555
+
556
+ if (add_promoter (numpy , ufunc_names [i ], & PyArray_ObjectDType ,
557
+ (PyArray_DTypeMeta * )& PandasStringDType , & PyArray_BoolDType , 1 ) < 0 ) {
558
+ goto error ;
559
+ }
560
+
561
+ if (add_promoter (numpy , ufunc_names [i ], (PyArray_DTypeMeta * )& PandasStringDType ,
562
+ & PyArray_ObjectDType , & PyArray_BoolDType , 1 ) < 0 ) {
563
+ goto error ;
564
+ }
475
565
}
476
566
477
567
PyArray_DTypeMeta * p_isnan_dtypes [] = {
0 commit comments