Skip to content

Commit 1f1e5b9

Browse files
committed
Add static PandasStringDType that uses pandas NA object
1 parent c50a701 commit 1f1e5b9

File tree

6 files changed

+113
-4
lines changed

6 files changed

+113
-4
lines changed

stringdtype/stringdtype/__init__.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,18 @@
33
"""
44

55
from .missing import NA # isort: skip
6-
from .scalar import StringScalar # isort: skip
7-
from ._main import StringDType, _memory_usage
6+
from .scalar import StringScalar, PandasStringScalar # isort: skip
7+
from ._main import PandasStringDType, StringDType, _memory_usage
88

99
__all__ = [
1010
"NA",
1111
"StringDType",
1212
"StringScalar",
1313
"_memory_usage",
1414
]
15+
16+
# this happens when pandas isn't importable
17+
if StringDType is PandasStringDType:
18+
del PandasStringDType
19+
else:
20+
# Use append, not extend: extend() iterates its argument, so passing a
# string would add each character as a separate entry in __all__.
__all__.append("PandasStringDType")

stringdtype/stringdtype/scalar.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
1-
"""A scalar type needed by the dtype machinery."""
1+
"""Scalar types needed by the dtype machinery."""
22

33

44
class StringScalar(str):
55
pass
6+
7+
8+
class PandasStringScalar(str):
9+
pass

stringdtype/stringdtype/src/dtype.c

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
#include "static_string.h"
55

66
PyTypeObject *StringScalar_Type = NULL;
7+
PyTypeObject *PandasStringScalar_Type = NULL;
78
static PyTypeObject *StringNA_Type = NULL;
89
PyObject *NA_OBJ = NULL;
910

@@ -488,6 +489,30 @@ PyArray_DTypeMeta StringDType = {
488489
/* rest, filled in during DTypeMeta initialization */
489490
};
490491

492+
/*
493+
* Ideally we don't need the copy/pasted type below and we allow the
494+
* StringDType class to be initialized dynamically; however, that requires that
495+
* the class is a heap type, and that likely only works well with some changes
496+
* in numpy in Python 3.12 or newer. In practice we really only need a version
497+
* with a numpy-specific NA and a version that uses Pandas' NA object,
498+
* so we just define PandasStringDType statically as well and live with a bit
499+
* of copy/paste
500+
*/
501+
502+
PyArray_DTypeMeta PandasStringDType = {
503+
{{
504+
PyVarObject_HEAD_INIT(NULL, 0).tp_name =
505+
"stringdtype.PandasStringDType",
506+
.tp_basicsize = sizeof(StringDTypeObject),
507+
.tp_new = stringdtype_new,
508+
.tp_dealloc = (destructor)stringdtype_dealloc,
509+
.tp_repr = (reprfunc)stringdtype_repr,
510+
.tp_str = (reprfunc)stringdtype_repr,
511+
.tp_methods = StringDType_methods,
512+
}},
513+
/* rest, filled in during DTypeMeta initialization */
514+
};
515+
491516
int
492517
init_string_dtype(void)
493518
{
@@ -522,6 +547,51 @@ init_string_dtype(void)
522547

523548
StringDType.singleton = singleton;
524549

550+
/* and once again for PandasStringDType */
551+
552+
PyObject *mod = PyImport_ImportModule("pandas");
553+
554+
if (mod != NULL) {
555+
PyArrayDTypeMeta_Spec PandasStringDType_DTypeSpec = {
556+
.typeobj = PandasStringScalar_Type,
557+
.slots = StringDType_Slots,
558+
.casts = casts,
559+
};
560+
561+
PyObject *pandas_na_obj = PyObject_GetAttrString(mod, "NA");
562+
563+
Py_DECREF(mod);
564+
565+
if (pandas_na_obj == NULL) {
566+
return -1;
567+
}
568+
569+
((PyObject *)&PandasStringDType)->ob_type = &PyArrayDTypeMeta_Type;
570+
((PyTypeObject *)&PandasStringDType)->tp_base = &PyArrayDescr_Type;
571+
((PyTypeObject *)&PandasStringDType)->tp_dict = PyDict_New();
572+
PyDict_SetItemString(((PyTypeObject *)&PandasStringDType)->tp_dict,
573+
"na_object", pandas_na_obj);
574+
if (PyType_Ready((PyTypeObject *)&PandasStringDType) < 0) {
575+
return -1;
576+
}
577+
578+
if (PyArrayInitDTypeMeta_FromSpec(&PandasStringDType,
579+
&PandasStringDType_DTypeSpec) < 0) {
580+
return -1;
581+
}
582+
583+
singleton = PyArray_GetDefaultDescr(&PandasStringDType);
584+
585+
if (singleton == NULL) {
586+
return -1;
587+
}
588+
589+
PandasStringDType.singleton = singleton;
590+
}
591+
else {
592+
/*
 * pandas is not importable: PyImport_ImportModule returned NULL with an
 * exception set. Clear it so module initialization does not continue
 * (and eventually return) with a stale error indicator pending.
 */
PyErr_Clear();
PandasStringDType = StringDType;
593+
}
594+
525595
for (int i = 0; casts[i] != NULL; i++) {
526596
free(casts[i]->dtypes);
527597
free(casts[i]);

stringdtype/stringdtype/src/dtype.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,9 @@ typedef struct {
1919
} StringDTypeObject;
2020

2121
extern PyArray_DTypeMeta StringDType;
22+
extern PyArray_DTypeMeta PandasStringDType;
2223
extern PyTypeObject *StringScalar_Type;
24+
extern PyTypeObject *PandasStringScalar_Type;
2325
extern PyObject *NA_OBJ;
2426

2527
PyObject *

stringdtype/stringdtype/src/main.c

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,12 +107,20 @@ PyInit__main(void)
107107

108108
StringScalar_Type =
109109
(PyTypeObject *)PyObject_GetAttrString(mod, "StringScalar");
110-
Py_DECREF(mod);
111110

112111
if (StringScalar_Type == NULL) {
113112
goto error;
114113
}
115114

115+
PandasStringScalar_Type =
116+
(PyTypeObject *)PyObject_GetAttrString(mod, "PandasStringScalar");
117+
118+
Py_DECREF(mod);
119+
120+
if (PandasStringScalar_Type == NULL) {
121+
goto error;
122+
}
123+
116124
if (init_string_na_object(mod) < 0) {
117125
goto error;
118126
}
@@ -127,6 +135,13 @@ PyInit__main(void)
127135
goto error;
128136
}
129137

138+
Py_INCREF((PyObject *)&PandasStringDType);
139+
if (PyModule_AddObject(m, "PandasStringDType",
140+
(PyObject *)&PandasStringDType) < 0) {
141+
Py_DECREF((PyObject *)&PandasStringDType);
142+
goto error;
143+
}
144+
130145
if (init_ufuncs() < 0) {
131146
goto error;
132147
}

stringdtype/tests/test_stringdtype.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -338,3 +338,15 @@ def test_create_with_na(na_val):
338338
== "array(['hello', stringdtype.NA, 'world'], dtype=StringDType())"
339339
)
340340
assert arr[1] == NA and arr[1] is NA
341+
342+
343+
def test_pandas_string_dtype():
344+
pandas = pytest.importorskip("pandas")
345+
from stringdtype import PandasStringDType
346+
347+
assert PandasStringDType.na_object is pandas.NA
348+
349+
string_list = ["hello", np.nan, "world"]
350+
arr = np.array(string_list, dtype=PandasStringDType())
351+
352+
assert arr[1] is pandas.NA

0 commit comments

Comments
 (0)