Skip to content

Commit fc08ab6

Browse files
committed
Make it possible for users to use a custom NA value
1 parent 6537f42 commit fc08ab6

File tree

6 files changed

+109
-56
lines changed

6 files changed

+109
-56
lines changed

stringdtype/stringdtype/src/casts.c

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -22,11 +22,8 @@ string_to_string_resolve_descriptors(PyObject *NPY_UNUSED(self),
2222
npy_intp *view_offset)
2323
{
2424
if (given_descrs[1] == NULL) {
25-
StringDTypeObject *new = new_stringdtype_instance();
26-
if (new == NULL) {
27-
return (NPY_CASTING)-1;
28-
}
29-
loop_descrs[1] = (PyArray_Descr *)new;
25+
Py_INCREF(given_descrs[0]);
26+
loop_descrs[1] = given_descrs[0];
3027
}
3128
else {
3229
Py_INCREF(given_descrs[1]);
@@ -89,7 +86,7 @@ unicode_to_string_resolve_descriptors(PyObject *NPY_UNUSED(self),
8986
npy_intp *NPY_UNUSED(view_offset))
9087
{
9188
if (given_descrs[1] == NULL) {
92-
StringDTypeObject *new = new_stringdtype_instance();
89+
StringDTypeObject *new = new_stringdtype_instance(NA_OBJ);
9390
if (new == NULL) {
9491
return (NPY_CASTING)-1;
9592
}

stringdtype/stringdtype/src/dtype.c

Lines changed: 43 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -5,19 +5,21 @@
55

66
PyTypeObject *StringScalar_Type = NULL;
77
static PyTypeObject *StringNA_Type = NULL;
8-
static PyObject *NA_OBJ = NULL;
8+
PyObject *NA_OBJ = NULL;
99

1010
/*
1111
* Internal helper to create new instances
1212
*/
1313
StringDTypeObject *
14-
new_stringdtype_instance(void)
14+
new_stringdtype_instance(PyObject *na_object)
1515
{
1616
StringDTypeObject *new = (StringDTypeObject *)PyArrayDescr_Type.tp_new(
1717
(PyTypeObject *)&StringDType, NULL, NULL);
1818
if (new == NULL) {
1919
return NULL;
2020
}
21+
Py_INCREF(na_object);
22+
new->na_object = na_object;
2123
new->base.elsize = sizeof(ss);
2224
new->base.alignment = _Alignof(ss);
2325
new->base.flags |= NPY_NEEDS_INIT;
@@ -72,15 +74,15 @@ string_discover_descriptor_from_pyobject(PyArray_DTypeMeta *NPY_UNUSED(cls),
7274
return NULL;
7375
}
7476

75-
PyArray_Descr *ret = (PyArray_Descr *)new_stringdtype_instance();
77+
PyArray_Descr *ret = (PyArray_Descr *)new_stringdtype_instance(NA_OBJ);
7678
if (ret == NULL) {
7779
return NULL;
7880
}
7981
return ret;
8082
}
8183

8284
static PyObject *
83-
get_value(PyObject *scalar)
85+
get_value(PyObject *scalar, PyObject *na_object)
8486
{
8587
PyObject *ret = NULL;
8688
PyTypeObject *scalar_type = Py_TYPE(scalar);
@@ -96,7 +98,7 @@ get_value(PyObject *scalar)
9698
return NULL;
9799
}
98100
}
99-
else if (scalar_type == StringNA_Type) {
101+
else if (scalar == na_object) {
100102
ret = scalar;
101103
Py_INCREF(ret);
102104
}
@@ -107,7 +109,7 @@ get_value(PyObject *scalar)
107109
return NULL;
108110
}
109111
if (npy_isnan(scalar_val)) {
110-
ret = NA_OBJ;
112+
ret = na_object;
111113
Py_INCREF(ret);
112114
}
113115
else {
@@ -128,10 +130,9 @@ get_value(PyObject *scalar)
128130
// Take a python object `obj` and insert it into the array of dtype `descr` at
129131
// the position given by dataptr.
130132
static int
131-
stringdtype_setitem(StringDTypeObject *NPY_UNUSED(descr), PyObject *obj,
132-
char **dataptr)
133+
stringdtype_setitem(StringDTypeObject *descr, PyObject *obj, char **dataptr)
133134
{
134-
PyObject *val_obj = get_value(obj);
135+
PyObject *val_obj = get_value(obj, descr->na_object);
135136

136137
if (val_obj == NULL) {
137138
return -1;
@@ -143,15 +144,9 @@ stringdtype_setitem(StringDTypeObject *NPY_UNUSED(descr), PyObject *obj,
143144
// ssfree does a NULL check
144145
ssfree(sdata);
145146

146-
// RichCompareBool short-circuits to a pointer comparison fast-path
147-
// so no need to do pointer comparison first
148-
int eq_res = PyObject_RichCompareBool(val_obj, NA_OBJ, Py_EQ);
149-
150-
if (eq_res < 0) {
151-
goto error;
152-
}
153-
154-
if (eq_res == 1) {
147+
// setting NA *must* check pointer equality since NA types might not
148+
// allow equality
149+
if (val_obj == descr->na_object) {
155150
// do nothing, ssfree already NULLed the struct ssdata points to
156151
// so it already contains a NA value
157152
}
@@ -185,14 +180,14 @@ stringdtype_setitem(StringDTypeObject *NPY_UNUSED(descr), PyObject *obj,
185180
}
186181

187182
static PyObject *
188-
stringdtype_getitem(StringDTypeObject *NPY_UNUSED(descr), char **dataptr)
183+
stringdtype_getitem(StringDTypeObject *descr, char **dataptr)
189184
{
190185
PyObject *val_obj = NULL;
191186
ss *sdata = (ss *)dataptr;
192187

193188
if (ss_isnull(sdata)) {
194-
Py_INCREF(NA_OBJ);
195-
val_obj = NA_OBJ;
189+
Py_INCREF(descr->na_object);
190+
val_obj = descr->na_object;
196191
}
197192
else {
198193
char *data = sdata->buf;
@@ -359,28 +354,48 @@ static PyType_Slot StringDType_Slots[] = {
359354
static PyObject *
360355
stringdtype_new(PyTypeObject *NPY_UNUSED(cls), PyObject *args, PyObject *kwds)
361356
{
362-
static char *kwargs_strs[] = {"size", NULL};
357+
static char *kwargs_strs[] = {"size", "na_object", NULL};
363358

364359
long size = 0;
360+
PyObject *na_object = NULL;
365361

366-
if (!PyArg_ParseTupleAndKeywords(args, kwds, "|l:StringDType", kwargs_strs,
367-
&size)) {
362+
if (!PyArg_ParseTupleAndKeywords(args, kwds, "|lO:StringDType",
363+
kwargs_strs, &size, &na_object)) {
368364
return NULL;
369365
}
370366

371-
return (PyObject *)new_stringdtype_instance();
367+
if (na_object == NULL) {
368+
na_object = NA_OBJ;
369+
}
370+
371+
Py_INCREF(na_object);
372+
373+
PyObject *ret = (PyObject *)new_stringdtype_instance(na_object);
374+
375+
Py_DECREF(na_object);
376+
377+
return ret;
372378
}
373379

374380
static void
375381
stringdtype_dealloc(StringDTypeObject *self)
376382
{
383+
Py_DECREF(self->na_object);
377384
PyArrayDescr_Type.tp_dealloc((PyObject *)self);
378385
}
379386

380387
static PyObject *
381-
stringdtype_repr(StringDTypeObject *NPY_UNUSED(self))
388+
stringdtype_repr(StringDTypeObject *self)
382389
{
383-
return PyUnicode_FromString("StringDType()");
390+
PyObject *ret = NULL;
391+
if (self->na_object != NA_OBJ) {
392+
ret = PyUnicode_FromFormat("StringDType(na_object=%R)",
393+
self->na_object);
394+
}
395+
else {
396+
ret = PyUnicode_FromString("StringDType()");
397+
}
398+
return ret;
384399
}
385400

386401
static int PICKLE_VERSION = 1;
@@ -485,6 +500,7 @@ init_string_dtype(void)
485500
PyArrayMethod_Spec **casts = get_casts();
486501

487502
PyArrayDTypeMeta_Spec StringDType_DTypeSpec = {
503+
.flags = NPY_DT_PARAMETRIC,
488504
.typeobj = StringScalar_Type,
489505
.slots = StringDType_Slots,
490506
.casts = casts,

stringdtype/stringdtype/src/dtype.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,15 @@
1616

1717
/*
 * Instance struct for StringDType descriptors.
 *
 * base      - the NumPy descriptor header; must stay the first member so the
 *             struct can be cast to and from PyArray_Descr.
 * na_object - Python object this descriptor uses as its missing-value
 *             sentinel. The instance holds a reference to it, taken in
 *             new_stringdtype_instance and released in stringdtype_dealloc.
 */
typedef struct {
1818
PyArray_Descr base;
19+
PyObject *na_object;
1920
} StringDTypeObject;
2021

2122
extern PyArray_DTypeMeta StringDType;
2223
extern PyTypeObject *StringScalar_Type;
24+
extern PyObject *NA_OBJ;
2325

2426
StringDTypeObject *
25-
new_stringdtype_instance(void);
27+
new_stringdtype_instance(PyObject *na_object);
2628

2729
int
2830
init_string_dtype(void);
@@ -33,7 +35,6 @@ compare(void *, void *, void *);
3335
int
3436
init_string_na_object(PyObject *mod);
3537

36-
3738
// from dtypemeta.h, not public in numpy
3839
#define NPY_DTYPE(descr) ((PyArray_DTypeMeta *)Py_TYPE(descr))
3940

stringdtype/stringdtype/src/main.c

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -113,11 +113,11 @@ PyInit__main(void)
113113
goto error;
114114
}
115115

116-
if (init_string_dtype() < 0) {
116+
if (init_string_na_object(mod) < 0) {
117117
goto error;
118118
}
119119

120-
if (init_string_na_object(mod) < 0) {
120+
if (init_string_dtype() < 0) {
121121
goto error;
122122
}
123123

stringdtype/stringdtype/src/umath.c

Lines changed: 47 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,39 @@
1313
#include "string.h"
1414
#include "umath.h"
1515

16+
static NPY_CASTING
17+
binary_resolve_descriptors(struct PyArrayMethodObject_tag *NPY_UNUSED(method),
18+
PyArray_DTypeMeta *NPY_UNUSED(dtypes[]),
19+
PyArray_Descr *given_descrs[],
20+
PyArray_Descr *loop_descrs[],
21+
npy_intp *NPY_UNUSED(view_offset))
22+
{
23+
PyObject *na_obj1 = ((StringDTypeObject *)given_descrs[0])->na_object;
24+
PyObject *na_obj2 = ((StringDTypeObject *)given_descrs[1])->na_object;
25+
26+
int eq_res = PyObject_RichCompareBool(na_obj1, na_obj2, Py_EQ);
27+
28+
if (eq_res < 0) {
29+
return (NPY_CASTING)-1;
30+
}
31+
32+
if (eq_res != 1) {
33+
PyErr_SetString(PyExc_TypeError,
34+
"Can only do binary operations with identical "
35+
"StringDType instances.");
36+
return (NPY_CASTING)-1;
37+
}
38+
39+
Py_INCREF(given_descrs[0]);
40+
loop_descrs[0] = given_descrs[0];
41+
Py_INCREF(given_descrs[1]);
42+
loop_descrs[1] = given_descrs[1];
43+
Py_INCREF(given_descrs[1]);
44+
loop_descrs[2] = given_descrs[1];
45+
46+
return NPY_NO_CASTING;
47+
}
48+
1649
static int
1750
add_strided_loop(PyArrayMethod_Context *NPY_UNUSED(context),
1851
char *const data[], npy_intp const dimensions[],
@@ -27,9 +60,9 @@ add_strided_loop(PyArrayMethod_Context *NPY_UNUSED(context),
2760
npy_intp out_stride = strides[2];
2861

2962
ss *s1 = NULL, *s2 = NULL, *os = NULL;
30-
int newlen = 0;
3163

3264
while (N--) {
65+
int newlen = 0;
3366
s1 = (ss *)in1;
3467
s2 = (ss *)in2;
3568
os = (ss *)out;
@@ -312,16 +345,10 @@ init_ufunc(PyObject *numpy, const char *ufunc_name, PyArray_DTypeMeta **dtypes,
312345
.dtypes = dtypes,
313346
};
314347

315-
if (resolve_func == NULL) {
316-
PyType_Slot slots[] = {{NPY_METH_strided_loop, loop_func}, {0, NULL}};
317-
spec.slots = slots;
318-
}
319-
else {
320-
PyType_Slot slots[] = {{NPY_METH_resolve_descriptors, resolve_func},
321-
{NPY_METH_strided_loop, loop_func},
322-
{0, NULL}};
323-
spec.slots = slots;
324-
}
348+
PyType_Slot slots[] = {{NPY_METH_resolve_descriptors, resolve_func},
349+
{NPY_METH_strided_loop, loop_func},
350+
{0, NULL}};
351+
spec.slots = slots;
325352

326353
if (PyUFunc_AddLoopFromSpec(ufunc, &spec) < 0) {
327354
Py_DECREF(ufunc);
@@ -406,21 +433,22 @@ init_ufuncs(void)
406433

407434
PyArray_DTypeMeta *minmax_dtypes[] = {&StringDType, &StringDType,
408435
&StringDType};
409-
if (init_ufunc(numpy, "maximum", minmax_dtypes, NULL,
410-
&maximum_strided_loop, "string_maximum", 2, 1,
411-
NPY_NO_CASTING, 0) < 0) {
436+
if (init_ufunc(numpy, "maximum", minmax_dtypes,
437+
&binary_resolve_descriptors, &maximum_strided_loop,
438+
"string_maximum", 2, 1, NPY_NO_CASTING, 0) < 0) {
412439
goto error;
413440
}
414-
if (init_ufunc(numpy, "minimum", minmax_dtypes, NULL,
415-
&minimum_strided_loop, "string_minimum", 2, 1,
416-
NPY_NO_CASTING, 0) < 0) {
441+
if (init_ufunc(numpy, "minimum", minmax_dtypes,
442+
&binary_resolve_descriptors, &minimum_strided_loop,
443+
"string_minimum", 2, 1, NPY_NO_CASTING, 0) < 0) {
417444
goto error;
418445
}
419446

420447
PyArray_DTypeMeta *add_types[] = {&StringDType, &StringDType,
421448
&StringDType};
422-
if (init_ufunc(numpy, "add", add_types, NULL, &add_strided_loop,
423-
"string_add", 2, 1, NPY_NO_CASTING, 0) < 0) {
449+
if (init_ufunc(numpy, "add", add_types, &binary_resolve_descriptors,
450+
&add_strided_loop, "string_add", 2, 1, NPY_NO_CASTING,
451+
0) < 0) {
424452
goto error;
425453
}
426454

stringdtype/tests/test_stringdtype.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -338,3 +338,14 @@ def test_create_with_na(na_val):
338338
== "array(['hello', stringdtype.NA, 'world'], dtype=StringDType())"
339339
)
340340
assert arr[1] == NA and arr[1] is NA
341+
342+
343+
def test_custom_na():
    """A StringDType created with ``na_object=None`` uses None as its NA."""
    custom_dtype = StringDType(na_object=None)
    arr = np.array(["hello", None, "world"], dtype=custom_dtype)

    expected_repr = (
        "array(['hello', None, 'world'], dtype=StringDType(na_object=None))"
    )
    assert repr(arr) == expected_repr
    assert arr[1] is None

0 commit comments

Comments
 (0)