Skip to content

Commit d85d995

Browse files
committed
fix complex case
1 parent 62511e6 commit d85d995

File tree

2 files changed

+65
-8
lines changed

2 files changed

+65
-8
lines changed

pandas/_libs/src/klib/khash_python.h

Lines changed: 38 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -163,17 +163,50 @@ KHASH_MAP_INIT_COMPLEX128(complex128, size_t)
163163
#define kh_exist_complex128(h, k) (kh_exist(h, k))
164164

165165

166+
int PANDAS_INLINE floatobject_cmp(PyObject* a, PyObject* b){
167+
return Py_IS_NAN(PyFloat_AS_DOUBLE(a)) &&
168+
Py_IS_NAN(PyFloat_AS_DOUBLE(b));
169+
}
170+
171+
172+
int PANDAS_INLINE complexobject_cmp(PyObject* a, PyObject* b){
173+
return (
174+
Py_IS_NAN(PyComplex_RealAsDouble(a)) &&
175+
Py_IS_NAN(PyComplex_RealAsDouble(b)) &&
176+
Py_IS_NAN(PyComplex_ImagAsDouble(a)) &&
177+
Py_IS_NAN(PyComplex_ImagAsDouble(b))
178+
)
179+
||
180+
(
181+
Py_IS_NAN(PyComplex_RealAsDouble(a)) &&
182+
Py_IS_NAN(PyComplex_RealAsDouble(b)) &&
183+
PyComplex_ImagAsDouble(a) == PyComplex_ImagAsDouble(b)
184+
)
185+
||
186+
(
187+
PyComplex_RealAsDouble(a) == PyComplex_RealAsDouble(b) &&
188+
Py_IS_NAN(PyComplex_ImagAsDouble(a)) &&
189+
Py_IS_NAN(PyComplex_ImagAsDouble(b))
190+
);
191+
}
192+
193+
166194
int PANDAS_INLINE pyobject_cmp(PyObject* a, PyObject* b) {
167195
int result = PyObject_RichCompareBool(a, b, Py_EQ);
168196
if (result < 0) {
169197
PyErr_Clear();
170198
return 0;
171199
}
172-
if (result == 0) { // still could be two NaNs
173-
return PyFloat_CheckExact(a) &&
174-
PyFloat_CheckExact(b) &&
175-
Py_IS_NAN(PyFloat_AS_DOUBLE(a)) &&
176-
Py_IS_NAN(PyFloat_AS_DOUBLE(b));
200+
if (result == 0) { // still could be built-ins with NaNs
201+
if (Py_TYPE(a) != Py_TYPE(b)) {
202+
return 0;
203+
}
204+
if (PyFloat_CheckExact(a)) {
205+
return floatobject_cmp(a, b);
206+
}
207+
if (PyComplex_CheckExact(a)) {
208+
return complexobject_cmp(a, b);
209+
}
177210
}
178211
return result;
179212
}

pandas/tests/libs/test_hashtable.py

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -187,14 +187,38 @@ def test_nan_float(self):
187187
table.set_item(nan1, 42)
188188
assert table.get_item(nan2) == 42
189189

190-
def test_nan_complex(self):
191-
nan1 = 1j * float("nan")
192-
nan2 = 1j * float("nan")
190+
def test_nan_complex_both(self):
191+
nan1 = complex(float("nan"), float("nan"))
192+
nan2 = complex(float("nan"), float("nan"))
193193
assert nan1 is not nan2
194194
table = ht.PyObjectHashTable()
195195
table.set_item(nan1, 42)
196196
assert table.get_item(nan2) == 42
197197

198+
def test_nan_complex_real(self):
199+
nan1 = complex(float("nan"), 1)
200+
nan2 = complex(float("nan"), 1)
201+
other = complex(float("nan"), 2)
202+
assert nan1 is not nan2
203+
table = ht.PyObjectHashTable()
204+
table.set_item(nan1, 42)
205+
assert table.get_item(nan2) == 42
206+
with pytest.raises(KeyError, match=None) as error:
207+
table.get_item(other)
208+
assert str(error.value) == str(other)
209+
210+
def test_nan_complex_imag(self):
211+
nan1 = complex(1, float("nan"))
212+
nan2 = complex(1, float("nan"))
213+
other = complex(2, float("nan"))
214+
assert nan1 is not nan2
215+
table = ht.PyObjectHashTable()
216+
table.set_item(nan1, 42)
217+
assert table.get_item(nan2) == 42
218+
with pytest.raises(KeyError, match=None) as error:
219+
table.get_item(other)
220+
assert str(error.value) == str(other)
221+
198222

199223
def test_get_labels_groupby_for_Int64(writable):
200224
table = ht.Int64HashTable()

0 commit comments

Comments
 (0)