@@ -20,6 +20,38 @@ new_stringdtype_instance(PyObject *na_object, int coerce)
20
20
21
21
Py_XINCREF (na_object );
22
22
((StringDTypeObject * )new )-> na_object = na_object ;
23
+ int hasnull = na_object != NULL ;
24
+ int has_nan_na = 0 ;
25
+ int has_string_na = 0 ;
26
+ ss default_string = EMPTY_STRING ;
27
+ if (hasnull ) {
28
+ double na_float = PyFloat_AsDouble (na_object );
29
+ if (na_float == -1.0 && PyErr_Occurred ()) {
30
+ // not a float, still treat as nan if PyObject_IsTrue raises
31
+ // (e.g. pandas.NA)
32
+ PyErr_Clear ();
33
+ int is_truthy = PyObject_IsTrue (na_object );
34
+ if (is_truthy == -1 ) {
35
+ PyErr_Clear ();
36
+ has_nan_na = 1 ;
37
+ }
38
+ }
39
+ else if (npy_isnan (na_float )) {
40
+ has_nan_na = 1 ;
41
+ }
42
+
43
+ if (PyUnicode_Check (na_object )) {
44
+ has_string_na = 1 ;
45
+ Py_ssize_t size = 0 ;
46
+ const char * buf = PyUnicode_AsUTF8AndSize (na_object , & size );
47
+ default_string .len = size ;
48
+ // discards const, how to avoid?
49
+ default_string .buf = (char * )buf ;
50
+ }
51
+ }
52
+ ((StringDTypeObject * )new )-> has_nan_na = has_nan_na ;
53
+ ((StringDTypeObject * )new )-> has_string_na = has_string_na ;
54
+ ((StringDTypeObject * )new )-> default_string = default_string ;
23
55
((StringDTypeObject * )new )-> coerce = coerce ;
24
56
25
57
PyArray_Descr * base = (PyArray_Descr * )new ;
@@ -28,6 +60,9 @@ new_stringdtype_instance(PyObject *na_object, int coerce)
28
60
base -> flags |= NPY_NEEDS_INIT ;
29
61
base -> flags |= NPY_LIST_PICKLE ;
30
62
base -> flags |= NPY_ITEM_REFCOUNT ;
63
+ if (hasnull && !(has_string_na && has_nan_na )) {
64
+ base -> flags |= NPY_NEEDS_PYAPI ;
65
+ }
31
66
32
67
return new ;
33
68
}
@@ -227,25 +262,43 @@ int
227
262
_compare (void * a , void * b , StringDTypeObject * descr )
228
263
{
229
264
int hasnull = descr -> na_object != NULL ;
265
+ int has_string_na = descr -> has_string_na ;
266
+ int has_nan_na = descr -> has_nan_na ;
267
+ if (hasnull && !(has_string_na && has_nan_na )) {
268
+ // check if an error occured already to avoid setting an error again
269
+ if (PyErr_Occurred ()) {
270
+ return 0 ;
271
+ }
272
+ }
273
+ const ss * default_string = & descr -> default_string ;
230
274
const ss * ss_a = (ss * )a ;
231
275
const ss * ss_b = (ss * )b ;
232
276
int a_is_null = ss_isnull (ss_a );
233
277
int b_is_null = ss_isnull (ss_b );
234
278
if (NPY_UNLIKELY (a_is_null || b_is_null )) {
235
- if (hasnull ) {
236
- if (a_is_null ) {
237
- return 1 ;
279
+ if (hasnull && !has_string_na ) {
280
+ if (has_nan_na ) {
281
+ if (a_is_null ) {
282
+ return 1 ;
283
+ }
284
+ else if (b_is_null ) {
285
+ return -1 ;
286
+ }
238
287
}
239
- else if (b_is_null ) {
240
- return -1 ;
288
+ else {
289
+ // we must hold the GIL in this branch
290
+ PyErr_SetString (
291
+ PyExc_ValueError ,
292
+ "Cannot compare null this is not a nan-like value" );
293
+ return 0 ;
241
294
}
242
295
}
243
296
else {
244
297
if (a_is_null ) {
245
- ss_a = & EMPTY_STRING ;
298
+ ss_a = default_string ;
246
299
}
247
300
if (b_is_null ) {
248
- ss_b = & EMPTY_STRING ;
301
+ ss_b = default_string ;
249
302
}
250
303
}
251
304
}
@@ -349,6 +402,94 @@ stringdtype_get_fill_zero_loop(void *NPY_UNUSED(traverse_context),
349
402
return 0 ;
350
403
}
351
404
405
+ static int
406
+ stringdtype_is_known_scalar_type (PyArray_DTypeMeta * NPY_UNUSED (cls ),
407
+ PyTypeObject * pytype )
408
+ {
409
+ if (pytype == & PyFloat_Type ) {
410
+ return 1 ;
411
+ }
412
+ if (pytype == & PyLong_Type ) {
413
+ return 1 ;
414
+ }
415
+ if (pytype == & PyBool_Type ) {
416
+ return 1 ;
417
+ }
418
+ if (pytype == & PyComplex_Type ) {
419
+ return 1 ;
420
+ }
421
+ if (pytype == & PyUnicode_Type ) {
422
+ return 1 ;
423
+ }
424
+ if (pytype == & PyBytes_Type ) {
425
+ return 1 ;
426
+ }
427
+ if (pytype == & PyBoolArrType_Type ) {
428
+ return 1 ;
429
+ }
430
+ if (pytype == & PyByteArrType_Type ) {
431
+ return 1 ;
432
+ }
433
+ if (pytype == & PyShortArrType_Type ) {
434
+ return 1 ;
435
+ }
436
+ if (pytype == & PyIntArrType_Type ) {
437
+ return 1 ;
438
+ }
439
+ if (pytype == & PyLongArrType_Type ) {
440
+ return 1 ;
441
+ }
442
+ if (pytype == & PyLongLongArrType_Type ) {
443
+ return 1 ;
444
+ }
445
+ if (pytype == & PyUByteArrType_Type ) {
446
+ return 1 ;
447
+ }
448
+ if (pytype == & PyUShortArrType_Type ) {
449
+ return 1 ;
450
+ }
451
+ if (pytype == & PyUIntArrType_Type ) {
452
+ return 1 ;
453
+ }
454
+ if (pytype == & PyULongArrType_Type ) {
455
+ return 1 ;
456
+ }
457
+ if (pytype == & PyULongLongArrType_Type ) {
458
+ return 1 ;
459
+ }
460
+ if (pytype == & PyHalfArrType_Type ) {
461
+ return 1 ;
462
+ }
463
+ if (pytype == & PyFloatArrType_Type ) {
464
+ return 1 ;
465
+ }
466
+ if (pytype == & PyDoubleArrType_Type ) {
467
+ return 1 ;
468
+ }
469
+ if (pytype == & PyLongDoubleArrType_Type ) {
470
+ return 1 ;
471
+ }
472
+ if (pytype == & PyCFloatArrType_Type ) {
473
+ return 1 ;
474
+ }
475
+ if (pytype == & PyCDoubleArrType_Type ) {
476
+ return 1 ;
477
+ }
478
+ if (pytype == & PyCLongDoubleArrType_Type ) {
479
+ return 1 ;
480
+ }
481
+ if (pytype == & PyIntpArrType_Type ) {
482
+ return 1 ;
483
+ }
484
+ if (pytype == & PyUIntpArrType_Type ) {
485
+ return 1 ;
486
+ }
487
+ if (pytype == & PyDatetimeArrType_Type ) {
488
+ return 1 ;
489
+ }
490
+ return 0 ;
491
+ }
492
+
352
493
static PyType_Slot StringDType_Slots [] = {
353
494
{NPY_DT_common_instance , & common_instance },
354
495
{NPY_DT_common_dtype , & common_dtype },
@@ -363,6 +504,7 @@ static PyType_Slot StringDType_Slots[] = {
363
504
{NPY_DT_PyArray_ArrFuncs_argmin , & argmin },
364
505
{NPY_DT_get_clear_loop , & stringdtype_get_clear_loop },
365
506
{NPY_DT_get_fill_zero_loop , & stringdtype_get_fill_zero_loop },
507
+ {_NPY_DT_is_known_scalar_type , & stringdtype_is_known_scalar_type },
366
508
{0 , NULL }};
367
509
368
510
static PyObject *
@@ -530,7 +672,7 @@ StringDType_richcompare(PyObject *self, PyObject *other, int op)
530
672
// pointer equality catches pandas.NA and other NA singletons
531
673
eq = 1 ;
532
674
}
533
- else {
675
+ else if ( PyFloat_Check ( sna ) && PyFloat_Check ( ona )) {
534
676
// nan check catches np.nan and float('nan')
535
677
double sna_float = PyFloat_AsDouble (sna );
536
678
if (sna_float == -1.0 && PyErr_Occurred ()) {
@@ -543,13 +685,12 @@ StringDType_richcompare(PyObject *self, PyObject *other, int op)
543
685
if (npy_isnan (sna_float ) && npy_isnan (ona_float )) {
544
686
eq = 1 ;
545
687
}
546
-
688
+ }
689
+ else {
547
690
// finally check if a python equals comparison returns True
548
- else if (PyObject_RichCompareBool (sna , ona , Py_EQ ) == 1 ) {
549
- eq = 1 ;
550
- }
551
- else {
552
- eq = 0 ;
691
+ eq = PyObject_RichCompareBool (sna , ona , Py_EQ );
692
+ if (eq == -1 ) {
693
+ return NULL ;
553
694
}
554
695
}
555
696
0 commit comments