Skip to content

Commit c50a701

Browse files
committed
initialize stringdtype instances using a generic type
1 parent dcca859 commit c50a701

File tree

4 files changed

+29
-32
lines changed

4 files changed

+29
-32
lines changed

stringdtype/stringdtype/src/casts.c

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -80,17 +80,18 @@ static char *s2s_name = "cast_StringDType_to_StringDType";
8080

8181
static NPY_CASTING
8282
unicode_to_string_resolve_descriptors(PyObject *NPY_UNUSED(self),
83-
PyArray_DTypeMeta *NPY_UNUSED(dtypes[2]),
83+
PyArray_DTypeMeta *dtypes[2],
8484
PyArray_Descr *given_descrs[2],
8585
PyArray_Descr *loop_descrs[2],
8686
npy_intp *NPY_UNUSED(view_offset))
8787
{
8888
if (given_descrs[1] == NULL) {
89-
StringDTypeObject *new = new_stringdtype_instance();
89+
PyArray_Descr *new = (PyArray_Descr *)new_stringdtype_instance(
90+
(PyTypeObject *)dtypes[1]);
9091
if (new == NULL) {
9192
return (NPY_CASTING)-1;
9293
}
93-
loop_descrs[1] = (PyArray_Descr *)new;
94+
loop_descrs[1] = new;
9495
}
9596
else {
9697
Py_INCREF(given_descrs[1]);

stringdtype/stringdtype/src/dtype.c

Lines changed: 14 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -10,19 +10,20 @@ PyObject *NA_OBJ = NULL;
1010
/*
1111
* Internal helper to create new instances
1212
*/
13-
StringDTypeObject *
14-
new_stringdtype_instance(void)
13+
PyObject *
14+
new_stringdtype_instance(PyTypeObject *cls)
1515
{
16-
StringDTypeObject *new = (StringDTypeObject *)PyArrayDescr_Type.tp_new(
17-
(PyTypeObject *)&StringDType, NULL, NULL);
16+
PyObject *new = PyArrayDescr_Type.tp_new((PyTypeObject *)cls, NULL, NULL);
1817
if (new == NULL) {
1918
return NULL;
2019
}
21-
new->base.elsize = sizeof(ss);
22-
new->base.alignment = _Alignof(ss);
23-
new->base.flags |= NPY_NEEDS_INIT;
24-
new->base.flags |= NPY_LIST_PICKLE;
25-
new->base.flags |= NPY_ITEM_REFCOUNT;
20+
21+
PyArray_Descr *base = &((StringDTypeObject *)new)->base;
22+
base->elsize = sizeof(ss);
23+
base->alignment = _Alignof(ss);
24+
base->flags |= NPY_NEEDS_INIT;
25+
base->flags |= NPY_LIST_PICKLE;
26+
base->flags |= NPY_ITEM_REFCOUNT;
2627

2728
return new;
2829
}
@@ -63,16 +64,15 @@ common_dtype(PyArray_DTypeMeta *cls, PyArray_DTypeMeta *other)
6364
// For a given python object, this function returns a borrowed reference
6465
// to the dtype property of the array
6566
static PyArray_Descr *
66-
string_discover_descriptor_from_pyobject(PyArray_DTypeMeta *NPY_UNUSED(cls),
67-
PyObject *obj)
67+
string_discover_descriptor_from_pyobject(PyTypeObject *cls, PyObject *obj)
6868
{
6969
if (Py_TYPE(obj) != StringScalar_Type) {
7070
PyErr_SetString(PyExc_TypeError,
7171
"Can only store StringScalar in a StringDType array.");
7272
return NULL;
7373
}
7474

75-
PyArray_Descr *ret = (PyArray_Descr *)new_stringdtype_instance();
75+
PyArray_Descr *ret = (PyArray_Descr *)new_stringdtype_instance(cls);
7676
if (ret == NULL) {
7777
return NULL;
7878
}
@@ -354,7 +354,7 @@ static PyType_Slot StringDType_Slots[] = {
354354
{0, NULL}};
355355

356356
static PyObject *
357-
stringdtype_new(PyTypeObject *NPY_UNUSED(cls), PyObject *args, PyObject *kwds)
357+
stringdtype_new(PyTypeObject *cls, PyObject *args, PyObject *kwds)
358358
{
359359
static char *kwargs_strs[] = {"size", NULL};
360360

@@ -365,7 +365,7 @@ stringdtype_new(PyTypeObject *NPY_UNUSED(cls), PyObject *args, PyObject *kwds)
365365
return NULL;
366366
}
367367

368-
PyObject *ret = (PyObject *)new_stringdtype_instance();
368+
PyObject *ret = new_stringdtype_instance(cls);
369369

370370
return ret;
371371
}
@@ -494,7 +494,6 @@ init_string_dtype(void)
494494
PyArrayMethod_Spec **casts = get_casts();
495495

496496
PyArrayDTypeMeta_Spec StringDType_DTypeSpec = {
497-
.flags = NPY_DT_PARAMETRIC,
498497
.typeobj = StringScalar_Type,
499498
.slots = StringDType_Slots,
500499
.casts = casts,

stringdtype/stringdtype/src/dtype.h

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,15 +16,14 @@
1616

1717
typedef struct {
1818
PyArray_Descr base;
19-
PyObject *na_object;
2019
} StringDTypeObject;
2120

2221
extern PyArray_DTypeMeta StringDType;
2322
extern PyTypeObject *StringScalar_Type;
2423
extern PyObject *NA_OBJ;
2524

26-
StringDTypeObject *
27-
new_stringdtype_instance(void);
25+
PyObject *
26+
new_stringdtype_instance(PyTypeObject *cls);
2827

2928
int
3029
init_string_dtype(void);

stringdtype/stringdtype/src/umath.c

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -15,21 +15,19 @@
1515

1616
static NPY_CASTING
1717
binary_resolve_descriptors(struct PyArrayMethodObject_tag *NPY_UNUSED(method),
18-
PyArray_DTypeMeta *NPY_UNUSED(dtypes[]),
18+
PyArray_DTypeMeta *dtypes[],
1919
PyArray_Descr *given_descrs[],
2020
PyArray_Descr *loop_descrs[],
2121
npy_intp *NPY_UNUSED(view_offset))
2222
{
23-
PyObject *na_obj1 = ((StringDTypeObject *)given_descrs[0])->na_object;
24-
PyObject *na_obj2 = ((StringDTypeObject *)given_descrs[1])->na_object;
25-
26-
int eq_res = PyObject_RichCompareBool(na_obj1, na_obj2, Py_EQ);
27-
28-
if (eq_res < 0) {
29-
return (NPY_CASTING)-1;
30-
}
31-
32-
if (eq_res != 1) {
23+
// technically incorrect to cast to StringDType if we have a
24+
// PandasStringDType but they have the same layout so this should be fine.
25+
PyObject *na_obj1 = PyDict_GetItemString(
26+
((PyTypeObject *)dtypes[0])->tp_dict, "na_object");
27+
PyObject *na_obj2 = PyDict_GetItemString(
28+
((PyTypeObject *)dtypes[1])->tp_dict, "na_object");
29+
30+
if (na_obj1 != na_obj2) {
3331
PyErr_SetString(PyExc_TypeError,
3432
"Can only do binary operations with identical "
3533
"StringDType instances.");

0 commit comments

Comments
 (0)