Skip to content

Commit a7724a1

Browse files
committed
make StringDType_richcompare treat dtypes as equal when both na_objects are NaN (na_object=nan)
1 parent e437ff8 commit a7724a1

File tree

2 files changed

+36
-3
lines changed

2 files changed

+36
-3
lines changed

stringdtype/stringdtype/src/dtype.c

Lines changed: 33 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -519,8 +519,39 @@ StringDType_richcompare(PyObject *self, PyObject *other, int op)
519519
StringDTypeObject *sself = (StringDTypeObject *)self;
520520
StringDTypeObject *sother = (StringDTypeObject *)other;
521521

522-
int eq = (sself->na_object == sother->na_object) &&
523-
(sself->coerce == sother->coerce);
522+
int eq;
523+
PyObject *sna = sself->na_object;
524+
PyObject *ona = sother->na_object;
525+
526+
if (sself->coerce != sother->coerce) {
527+
eq = 0;
528+
}
529+
else if (sna == ona) {
530+
// pointer equality catches pandas.NA and other NA singletons
531+
eq = 1;
532+
}
533+
else {
534+
// nan check catches np.nan and float('nan')
535+
double sna_float = PyFloat_AsDouble(sna);
536+
if (sna_float == -1.0 && PyErr_Occurred()) {
537+
return NULL;
538+
}
539+
double ona_float = PyFloat_AsDouble(ona);
540+
if (ona_float == -1.0 && PyErr_Occurred()) {
541+
return NULL;
542+
}
543+
if (npy_isnan(sna_float) && npy_isnan(ona_float)) {
544+
eq = 1;
545+
}
546+
547+
// finally check if a python equals comparison returns True
548+
else if (PyObject_RichCompareBool(sna, ona, Py_EQ) == 1) {
549+
eq = 1;
550+
}
551+
else {
552+
eq = 0;
553+
}
554+
}
524555

525556
PyObject *ret = Py_NotImplemented;
526557
if ((op == Py_EQ && eq) || (op == Py_NE && !eq)) {

stringdtype/tests/test_stringdtype.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,14 +32,16 @@ def coerce(request):
3232

3333

3434
@pytest.fixture(
35-
params=["unset", None, pd_param], ids=["unset", "None", "pandas.NA"]
35+
params=["unset", None, pd_param, np.nan, float("nan")],
36+
ids=["unset", "None", "pandas.NA", "np.nan", "float('nan')"],
3637
)
3738
def na_object(request):
3839
return request.param
3940

4041

4142
@pytest.fixture()
4243
def dtype(na_object, coerce):
44+
# explicit `is` check for pd_NA because `!=` with pd_NA returns pd_NA rather than a bool
4345
if na_object is pd_NA or na_object != "unset":
4446
return StringDType(na_object=na_object, coerce=coerce)
4547
else:

0 commit comments

Comments
 (0)