@@ -25,21 +25,7 @@ new_stringdtype_instance(PyObject *na_object, int coerce)
25
25
int has_string_na = 0 ;
26
26
ss default_string = EMPTY_STRING ;
27
27
if (hasnull ) {
28
- double na_float = PyFloat_AsDouble (na_object );
29
- if (na_float == -1.0 && PyErr_Occurred ()) {
30
- // not a float, still treat as nan if PyObject_IsTrue raises
31
- // (e.g. pandas.NA)
32
- PyErr_Clear ();
33
- int is_truthy = PyObject_IsTrue (na_object );
34
- if (is_truthy == -1 ) {
35
- PyErr_Clear ();
36
- has_nan_na = 1 ;
37
- }
38
- }
39
- else if (npy_isnan (na_float )) {
40
- has_nan_na = 1 ;
41
- }
42
-
28
+ // first check for a string
43
29
if (PyUnicode_Check (na_object )) {
44
30
has_string_na = 1 ;
45
31
Py_ssize_t size = 0 ;
@@ -48,6 +34,25 @@ new_stringdtype_instance(PyObject *na_object, int coerce)
48
34
// discards const, how to avoid?
49
35
default_string .buf = (char * )buf ;
50
36
}
37
+ else {
38
+ // treat as nan-like if != comparison returns an object whose truth
39
+ // value raises an error (pd.NA) or a truthy value (e.g. a
40
+ // NaN-like object)
41
+ PyObject * eq = PyObject_RichCompare (na_object , na_object , Py_NE );
42
+ if (eq == NULL ) {
43
+ Py_DECREF (new );
44
+ return NULL ;
45
+ }
46
+ int is_truthy = PyObject_IsTrue (na_object );
47
+ if (is_truthy == -1 ) {
48
+ PyErr_Clear ();
49
+ has_nan_na = 1 ;
50
+ }
51
+ else if (is_truthy == 1 ) {
52
+ has_nan_na = 1 ;
53
+ }
54
+ Py_DECREF (eq );
55
+ }
51
56
}
52
57
((StringDTypeObject * )new )-> has_nan_na = has_nan_na ;
53
58
((StringDTypeObject * )new )-> has_string_na = has_string_na ;
@@ -60,6 +65,9 @@ new_stringdtype_instance(PyObject *na_object, int coerce)
60
65
base -> flags |= NPY_NEEDS_INIT ;
61
66
base -> flags |= NPY_LIST_PICKLE ;
62
67
base -> flags |= NPY_ITEM_REFCOUNT ;
68
+ // this is only because of error propagation in sorting; once this dtype
69
+ // lives inside numpy we can relax this and patch the sorting code
70
+ // directly.
63
71
if (hasnull && !(has_string_na && has_nan_na )) {
64
72
base -> flags |= NPY_NEEDS_PYAPI ;
65
73
}
@@ -302,7 +310,7 @@ _compare(void *a, void *b, StringDTypeObject *descr)
302
310
}
303
311
}
304
312
}
305
- return strcmp (ss_a -> buf , ss_b -> buf );
313
+ return sscmp (ss_a , ss_b );
306
314
}
307
315
308
316
// PyArray_ArgFunc
0 commit comments