Skip to content

Commit dcca859

Browse files
committed
refactor so the na object lives in the class's tp_dict
1 parent 99c94d7 commit dcca859

File tree

4 files changed

+25
-39
lines changed

4 files changed

+25
-39
lines changed

stringdtype/stringdtype/src/casts.c

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ unicode_to_string_resolve_descriptors(PyObject *NPY_UNUSED(self),
8686
npy_intp *NPY_UNUSED(view_offset))
8787
{
8888
if (given_descrs[1] == NULL) {
89-
StringDTypeObject *new = new_stringdtype_instance(NA_OBJ);
89+
StringDTypeObject *new = new_stringdtype_instance();
9090
if (new == NULL) {
9191
return (NPY_CASTING)-1;
9292
}

stringdtype/stringdtype/src/dtype.c

Lines changed: 23 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -11,15 +11,13 @@ PyObject *NA_OBJ = NULL;
1111
* Internal helper to create new instances
1212
*/
1313
StringDTypeObject *
14-
new_stringdtype_instance(PyObject *na_object)
14+
new_stringdtype_instance(void)
1515
{
1616
StringDTypeObject *new = (StringDTypeObject *)PyArrayDescr_Type.tp_new(
1717
(PyTypeObject *)&StringDType, NULL, NULL);
1818
if (new == NULL) {
1919
return NULL;
2020
}
21-
Py_INCREF(na_object);
22-
new->na_object = na_object;
2321
new->base.elsize = sizeof(ss);
2422
new->base.alignment = _Alignof(ss);
2523
new->base.flags |= NPY_NEEDS_INIT;
@@ -74,7 +72,7 @@ string_discover_descriptor_from_pyobject(PyArray_DTypeMeta *NPY_UNUSED(cls),
7472
return NULL;
7573
}
7674

77-
PyArray_Descr *ret = (PyArray_Descr *)new_stringdtype_instance(NA_OBJ);
75+
PyArray_Descr *ret = (PyArray_Descr *)new_stringdtype_instance();
7876
if (ret == NULL) {
7977
return NULL;
8078
}
@@ -132,7 +130,9 @@ get_value(PyObject *scalar, PyObject *na_object)
132130
static int
133131
stringdtype_setitem(StringDTypeObject *descr, PyObject *obj, char **dataptr)
134132
{
135-
PyObject *val_obj = get_value(obj, descr->na_object);
133+
PyObject *na_object =
134+
PyDict_GetItemString(Py_TYPE(descr)->tp_dict, "na_object");
135+
PyObject *val_obj = get_value(obj, na_object);
136136

137137
if (val_obj == NULL) {
138138
return -1;
@@ -146,7 +146,7 @@ stringdtype_setitem(StringDTypeObject *descr, PyObject *obj, char **dataptr)
146146

147147
// setting NA *must* check pointer equality since NA types might not
148148
// allow equality
149-
if (val_obj == descr->na_object) {
149+
if (val_obj == na_object) {
150150
// do nothing, ssfree already NULLed the struct ssdata points to
151151
// so it already contains a NA value
152152
}
@@ -186,8 +186,10 @@ stringdtype_getitem(StringDTypeObject *descr, char **dataptr)
186186
ss *sdata = (ss *)dataptr;
187187

188188
if (ss_isnull(sdata)) {
189-
Py_INCREF(descr->na_object);
190-
val_obj = descr->na_object;
189+
PyObject *na_object =
190+
PyDict_GetItemString(Py_TYPE(descr)->tp_dict, "na_object");
191+
Py_INCREF(na_object);
192+
val_obj = na_object;
191193
}
192194
else {
193195
char *data = sdata->buf;
@@ -354,43 +356,35 @@ static PyType_Slot StringDType_Slots[] = {
354356
static PyObject *
355357
stringdtype_new(PyTypeObject *NPY_UNUSED(cls), PyObject *args, PyObject *kwds)
356358
{
357-
static char *kwargs_strs[] = {"size", "na_object", NULL};
359+
static char *kwargs_strs[] = {"size", NULL};
358360

359361
long size = 0;
360-
PyObject *na_object = NULL;
361362

362-
if (!PyArg_ParseTupleAndKeywords(args, kwds, "|lO:StringDType",
363-
kwargs_strs, &size, &na_object)) {
363+
if (!PyArg_ParseTupleAndKeywords(args, kwds, "|l:StringDType", kwargs_strs,
364+
&size)) {
364365
return NULL;
365366
}
366367

367-
if (na_object == NULL) {
368-
na_object = NA_OBJ;
369-
}
370-
371-
Py_INCREF(na_object);
372-
373-
PyObject *ret = (PyObject *)new_stringdtype_instance(na_object);
374-
375-
Py_DECREF(na_object);
368+
PyObject *ret = (PyObject *)new_stringdtype_instance();
376369

377370
return ret;
378371
}
379372

380373
static void
381374
stringdtype_dealloc(StringDTypeObject *self)
382375
{
383-
Py_DECREF(self->na_object);
384376
PyArrayDescr_Type.tp_dealloc((PyObject *)self);
385377
}
386378

387379
static PyObject *
388380
stringdtype_repr(StringDTypeObject *self)
389381
{
390382
PyObject *ret = NULL;
391-
if (self->na_object != NA_OBJ) {
392-
ret = PyUnicode_FromFormat("StringDType(na_object=%R)",
393-
self->na_object);
383+
PyObject *na_object =
384+
PyDict_GetItemString(Py_TYPE(self)->tp_dict, "na_object");
385+
386+
if (na_object != NA_OBJ) {
387+
ret = PyUnicode_FromString("PandasStringDType()");
394388
}
395389
else {
396390
ret = PyUnicode_FromString("StringDType()");
@@ -471,7 +465,7 @@ static PyMethodDef StringDType_methods[] = {
471465
METH_O,
472466
"Unpickle an StringDType object",
473467
},
474-
{NULL},
468+
{NULL, NULL, 0, NULL},
475469
};
476470

477471
/*
@@ -509,6 +503,9 @@ init_string_dtype(void)
509503
/* Loaded dynamically, so may need to be set here: */
510504
((PyObject *)&StringDType)->ob_type = &PyArrayDTypeMeta_Type;
511505
((PyTypeObject *)&StringDType)->tp_base = &PyArrayDescr_Type;
506+
((PyTypeObject *)&StringDType)->tp_dict = PyDict_New();
507+
PyDict_SetItemString(((PyTypeObject *)&StringDType)->tp_dict, "na_object",
508+
NA_OBJ);
512509
if (PyType_Ready((PyTypeObject *)&StringDType) < 0) {
513510
return -1;
514511
}

stringdtype/stringdtype/src/dtype.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ extern PyTypeObject *StringScalar_Type;
2424
extern PyObject *NA_OBJ;
2525

2626
StringDTypeObject *
27-
new_stringdtype_instance(PyObject *na_object);
27+
new_stringdtype_instance(void);
2828

2929
int
3030
init_string_dtype(void);

stringdtype/tests/test_stringdtype.py

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -338,14 +338,3 @@ def test_create_with_na(na_val):
338338
== "array(['hello', stringdtype.NA, 'world'], dtype=StringDType())"
339339
)
340340
assert arr[1] == NA and arr[1] is NA
341-
342-
343-
def test_custom_na():
344-
dtype = StringDType(na_object=None)
345-
string_list = ["hello", None, "world"]
346-
arr = np.array(string_list, dtype=dtype)
347-
assert (
348-
repr(arr)
349-
== "array(['hello', None, 'world'], dtype=StringDType(na_object=None))"
350-
)
351-
assert arr[1] is None

0 commit comments

Comments
 (0)