Skip to content

Commit ddce4d3

Browse files
committed
flesh out NULL handling
1 parent a7724a1 commit ddce4d3

File tree

5 files changed

+511
-187
lines changed

5 files changed

+511
-187
lines changed

stringdtype/stringdtype/src/casts.c

Lines changed: 27 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -360,6 +360,10 @@ string_to_unicode(PyArrayMethod_Context *context, char *const data[],
360360
npy_intp const dimensions[], npy_intp const strides[],
361361
NpyAuxData *auxdata)
362362
{
363+
StringDTypeObject *descr = (StringDTypeObject *)context->descriptors[0];
364+
int has_null = descr->na_object != NULL;
365+
int has_string_na = descr->has_string_na;
366+
ss default_string = descr->default_string;
363367
npy_intp N = dimensions[0];
364368
char *in = data[0];
365369
Py_UCS4 *out = (Py_UCS4 *)data[1];
@@ -376,9 +380,15 @@ string_to_unicode(PyArrayMethod_Context *context, char *const data[],
376380
unsigned char *this_string = NULL;
377381
size_t n_bytes;
378382
if (ss_isnull(s)) {
379-
// lossy but not much else we can do
380-
this_string = (unsigned char *)((s2u_auxdata *)auxdata)->na_name;
381-
n_bytes = ((s2u_auxdata *)auxdata)->len;
383+
if (has_null && !has_string_na) {
384+
// lossy but not much else we can do
385+
this_string =
386+
(unsigned char *)((s2u_auxdata *)auxdata)->na_name;
387+
n_bytes = ((s2u_auxdata *)auxdata)->len;
388+
}
389+
else {
390+
this_string = (unsigned char *)(default_string.buf);
391+
}
382392
}
383393
else {
384394
this_string = (unsigned char *)(s->buf);
@@ -470,8 +480,11 @@ string_to_bool(PyArrayMethod_Context *context, char *const data[],
470480
npy_intp const dimensions[], npy_intp const strides[],
471481
NpyAuxData *NPY_UNUSED(auxdata))
472482
{
473-
int hasnull = (((StringDTypeObject *)context->descriptors[0])->na_object !=
474-
NULL);
483+
StringDTypeObject *descr = (StringDTypeObject *)context->descriptors[0];
484+
int has_null = descr->na_object != NULL;
485+
int has_string_na = descr->has_string_na;
486+
ss default_string = descr->default_string;
487+
475488
npy_intp N = dimensions[0];
476489
char *in = data[0];
477490
char *out = data[1];
@@ -484,13 +497,12 @@ string_to_bool(PyArrayMethod_Context *context, char *const data[],
484497
while (N--) {
485498
s = (ss *)in;
486499
if (ss_isnull(s)) {
487-
if (hasnull) {
500+
if (has_null && !has_string_na) {
488501
// numpy treats NaN as truthy, following python
489502
*out = (npy_bool)1;
490503
}
491504
else {
492-
// empty string is falsey
493-
*out = (npy_bool)0;
505+
*out = (npy_bool)(default_string.len == 0);
494506
}
495507
}
496508
else if (s->len == 0) {
@@ -1016,8 +1028,11 @@ string_to_datetime(PyArrayMethod_Context *context, char *const data[],
10161028
npy_intp const dimensions[], npy_intp const strides[],
10171029
NpyAuxData *NPY_UNUSED(auxdata))
10181030
{
1019-
int hasnull = (((StringDTypeObject *)context->descriptors[0])->na_object !=
1020-
NULL);
1031+
StringDTypeObject *descr = (StringDTypeObject *)context->descriptors[0];
1032+
int has_null = descr->na_object != NULL;
1033+
int has_string_na = descr->has_string_na;
1034+
ss default_string = descr->default_string;
1035+
10211036
npy_intp N = dimensions[0];
10221037
char *in = data[0];
10231038
npy_datetime *out = (npy_datetime *)data[1];
@@ -1038,11 +1053,11 @@ string_to_datetime(PyArrayMethod_Context *context, char *const data[],
10381053
while (N--) {
10391054
s = (ss *)in;
10401055
if (ss_isnull(s)) {
1041-
if (hasnull) {
1056+
if (has_null && !has_string_na) {
10421057
*out = NPY_DATETIME_NAT;
10431058
goto next_step;
10441059
}
1045-
s = &EMPTY_STRING;
1060+
s = &default_string;
10461061
}
10471062
if (NpyDatetime_ParseISO8601Datetime(
10481063
(const char *)s->buf, s->len, in_unit, NPY_UNSAFE_CASTING,

stringdtype/stringdtype/src/dtype.c

Lines changed: 155 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,38 @@ new_stringdtype_instance(PyObject *na_object, int coerce)
2020

2121
Py_XINCREF(na_object);
2222
((StringDTypeObject *)new)->na_object = na_object;
23+
int hasnull = na_object != NULL;
24+
int has_nan_na = 0;
25+
int has_string_na = 0;
26+
ss default_string = EMPTY_STRING;
27+
if (hasnull) {
28+
double na_float = PyFloat_AsDouble(na_object);
29+
if (na_float == -1.0 && PyErr_Occurred()) {
30+
// not a float, still treat as nan if PyObject_IsTrue raises
31+
// (e.g. pandas.NA)
32+
PyErr_Clear();
33+
int is_truthy = PyObject_IsTrue(na_object);
34+
if (is_truthy == -1) {
35+
PyErr_Clear();
36+
has_nan_na = 1;
37+
}
38+
}
39+
else if (npy_isnan(na_float)) {
40+
has_nan_na = 1;
41+
}
42+
43+
if (PyUnicode_Check(na_object)) {
44+
has_string_na = 1;
45+
Py_ssize_t size = 0;
46+
const char *buf = PyUnicode_AsUTF8AndSize(na_object, &size);
47+
default_string.len = size;
48+
// discards const, how to avoid?
49+
default_string.buf = (char *)buf;
50+
}
51+
}
52+
((StringDTypeObject *)new)->has_nan_na = has_nan_na;
53+
((StringDTypeObject *)new)->has_string_na = has_string_na;
54+
((StringDTypeObject *)new)->default_string = default_string;
2355
((StringDTypeObject *)new)->coerce = coerce;
2456

2557
PyArray_Descr *base = (PyArray_Descr *)new;
@@ -28,6 +60,9 @@ new_stringdtype_instance(PyObject *na_object, int coerce)
2860
base->flags |= NPY_NEEDS_INIT;
2961
base->flags |= NPY_LIST_PICKLE;
3062
base->flags |= NPY_ITEM_REFCOUNT;
63+
if (hasnull && !(has_string_na && has_nan_na)) {
64+
base->flags |= NPY_NEEDS_PYAPI;
65+
}
3166

3267
return new;
3368
}
@@ -227,25 +262,43 @@ int
227262
_compare(void *a, void *b, StringDTypeObject *descr)
228263
{
229264
int hasnull = descr->na_object != NULL;
265+
int has_string_na = descr->has_string_na;
266+
int has_nan_na = descr->has_nan_na;
267+
if (hasnull && !(has_string_na && has_nan_na)) {
268+
// check if an error occured already to avoid setting an error again
269+
if (PyErr_Occurred()) {
270+
return 0;
271+
}
272+
}
273+
const ss *default_string = &descr->default_string;
230274
const ss *ss_a = (ss *)a;
231275
const ss *ss_b = (ss *)b;
232276
int a_is_null = ss_isnull(ss_a);
233277
int b_is_null = ss_isnull(ss_b);
234278
if (NPY_UNLIKELY(a_is_null || b_is_null)) {
235-
if (hasnull) {
236-
if (a_is_null) {
237-
return 1;
279+
if (hasnull && !has_string_na) {
280+
if (has_nan_na) {
281+
if (a_is_null) {
282+
return 1;
283+
}
284+
else if (b_is_null) {
285+
return -1;
286+
}
238287
}
239-
else if (b_is_null) {
240-
return -1;
288+
else {
289+
// we must hold the GIL in this branch
290+
PyErr_SetString(
291+
PyExc_ValueError,
292+
"Cannot compare null this is not a nan-like value");
293+
return 0;
241294
}
242295
}
243296
else {
244297
if (a_is_null) {
245-
ss_a = &EMPTY_STRING;
298+
ss_a = default_string;
246299
}
247300
if (b_is_null) {
248-
ss_b = &EMPTY_STRING;
301+
ss_b = default_string;
249302
}
250303
}
251304
}
@@ -349,6 +402,94 @@ stringdtype_get_fill_zero_loop(void *NPY_UNUSED(traverse_context),
349402
return 0;
350403
}
351404

405+
static int
406+
stringdtype_is_known_scalar_type(PyArray_DTypeMeta *NPY_UNUSED(cls),
407+
PyTypeObject *pytype)
408+
{
409+
if (pytype == &PyFloat_Type) {
410+
return 1;
411+
}
412+
if (pytype == &PyLong_Type) {
413+
return 1;
414+
}
415+
if (pytype == &PyBool_Type) {
416+
return 1;
417+
}
418+
if (pytype == &PyComplex_Type) {
419+
return 1;
420+
}
421+
if (pytype == &PyUnicode_Type) {
422+
return 1;
423+
}
424+
if (pytype == &PyBytes_Type) {
425+
return 1;
426+
}
427+
if (pytype == &PyBoolArrType_Type) {
428+
return 1;
429+
}
430+
if (pytype == &PyByteArrType_Type) {
431+
return 1;
432+
}
433+
if (pytype == &PyShortArrType_Type) {
434+
return 1;
435+
}
436+
if (pytype == &PyIntArrType_Type) {
437+
return 1;
438+
}
439+
if (pytype == &PyLongArrType_Type) {
440+
return 1;
441+
}
442+
if (pytype == &PyLongLongArrType_Type) {
443+
return 1;
444+
}
445+
if (pytype == &PyUByteArrType_Type) {
446+
return 1;
447+
}
448+
if (pytype == &PyUShortArrType_Type) {
449+
return 1;
450+
}
451+
if (pytype == &PyUIntArrType_Type) {
452+
return 1;
453+
}
454+
if (pytype == &PyULongArrType_Type) {
455+
return 1;
456+
}
457+
if (pytype == &PyULongLongArrType_Type) {
458+
return 1;
459+
}
460+
if (pytype == &PyHalfArrType_Type) {
461+
return 1;
462+
}
463+
if (pytype == &PyFloatArrType_Type) {
464+
return 1;
465+
}
466+
if (pytype == &PyDoubleArrType_Type) {
467+
return 1;
468+
}
469+
if (pytype == &PyLongDoubleArrType_Type) {
470+
return 1;
471+
}
472+
if (pytype == &PyCFloatArrType_Type) {
473+
return 1;
474+
}
475+
if (pytype == &PyCDoubleArrType_Type) {
476+
return 1;
477+
}
478+
if (pytype == &PyCLongDoubleArrType_Type) {
479+
return 1;
480+
}
481+
if (pytype == &PyIntpArrType_Type) {
482+
return 1;
483+
}
484+
if (pytype == &PyUIntpArrType_Type) {
485+
return 1;
486+
}
487+
if (pytype == &PyDatetimeArrType_Type) {
488+
return 1;
489+
}
490+
return 0;
491+
}
492+
352493
static PyType_Slot StringDType_Slots[] = {
353494
{NPY_DT_common_instance, &common_instance},
354495
{NPY_DT_common_dtype, &common_dtype},
@@ -363,6 +504,7 @@ static PyType_Slot StringDType_Slots[] = {
363504
{NPY_DT_PyArray_ArrFuncs_argmin, &argmin},
364505
{NPY_DT_get_clear_loop, &stringdtype_get_clear_loop},
365506
{NPY_DT_get_fill_zero_loop, &stringdtype_get_fill_zero_loop},
507+
{_NPY_DT_is_known_scalar_type, &stringdtype_is_known_scalar_type},
366508
{0, NULL}};
367509

368510
static PyObject *
@@ -530,7 +672,7 @@ StringDType_richcompare(PyObject *self, PyObject *other, int op)
530672
// pointer equality catches pandas.NA and other NA singletons
531673
eq = 1;
532674
}
533-
else {
675+
else if (PyFloat_Check(sna) && PyFloat_Check(ona)) {
534676
// nan check catches np.nan and float('nan')
535677
double sna_float = PyFloat_AsDouble(sna);
536678
if (sna_float == -1.0 && PyErr_Occurred()) {
@@ -543,13 +685,12 @@ StringDType_richcompare(PyObject *self, PyObject *other, int op)
543685
if (npy_isnan(sna_float) && npy_isnan(ona_float)) {
544686
eq = 1;
545687
}
546-
688+
}
689+
else {
547690
// finally check if a python equals comparison returns True
548-
else if (PyObject_RichCompareBool(sna, ona, Py_EQ) == 1) {
549-
eq = 1;
550-
}
551-
else {
552-
eq = 0;
691+
eq = PyObject_RichCompareBool(sna, ona, Py_EQ);
692+
if (eq == -1) {
693+
return NULL;
553694
}
554695
}
555696

stringdtype/stringdtype/src/dtype.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
#include "structmember.h"
77
// clang-format on
88

9+
#include "static_string.h"
10+
911
#define PY_ARRAY_UNIQUE_SYMBOL stringdtype_ARRAY_API
1012
#define NPY_NO_DEPRECATED_API NPY_2_0_API_VERSION
1113
#define NPY_TARGET_VERSION NPY_2_0_API_VERSION
@@ -21,6 +23,9 @@ typedef struct {
2123
PyArray_Descr base;
2224
PyObject *na_object;
2325
int coerce;
26+
int has_nan_na;
27+
int has_string_na;
28+
ss default_string;
2429
} StringDTypeObject;
2530

2631
typedef struct {

0 commit comments

Comments
 (0)