Skip to content

Commit 02d6e26

Browse files
committed
expand tests for PandasStringDType
1 parent cb12e99 commit 02d6e26

File tree

5 files changed

+221
-118
lines changed

5 files changed

+221
-118
lines changed

stringdtype/pyproject.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,10 @@ build-backend = "mesonpy"
1111
[tool.black]
1212
line-length = 79
1313

14+
[tool.isort]
15+
profile = "black"
16+
line_length = 79
17+
1418
[project]
1519
name = "stringdtype"
1620
description = "A dtype for storing UTF-8 strings"

stringdtype/stringdtype/src/dtype.c

Lines changed: 30 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -63,32 +63,16 @@ common_dtype(PyArray_DTypeMeta *cls, PyArray_DTypeMeta *other)
6363
return (PyArray_DTypeMeta *)Py_NotImplemented;
6464
}
6565

66-
// For a given python object, this function returns a borrowed reference
67-
// to the dtype property of the array
68-
static PyArray_Descr *
69-
string_discover_descriptor_from_pyobject(PyTypeObject *cls, PyObject *obj)
70-
{
71-
if (Py_TYPE(obj) != StringScalar_Type) {
72-
PyErr_SetString(PyExc_TypeError,
73-
"Can only store StringScalar in a StringDType array.");
74-
return NULL;
75-
}
76-
77-
PyArray_Descr *ret = (PyArray_Descr *)new_stringdtype_instance(cls);
78-
if (ret == NULL) {
79-
return NULL;
80-
}
81-
return ret;
82-
}
83-
8466
static PyObject *
85-
get_value(PyObject *scalar, PyObject *na_object)
67+
get_value(PyObject *scalar, StringDType_type *cls)
8668
{
69+
PyObject *na_object = cls->na_object;
8770
PyObject *ret = NULL;
71+
PyTypeObject *expected_scalar_type = cls->base.scalar_type;
8872
PyTypeObject *scalar_type = Py_TYPE(scalar);
8973
// FIXME: handle bytes too
9074
if ((scalar_type == &PyUnicode_Type) ||
91-
(scalar_type == StringScalar_Type)) {
75+
(scalar_type == expected_scalar_type)) {
9276
// attempt to decode as UTF8
9377
ret = PyUnicode_AsUTF8String(scalar);
9478
if (ret == NULL) {
@@ -127,14 +111,31 @@ get_value(PyObject *scalar, PyObject *na_object)
127111
return ret;
128112
}
129113

114+
// For a given python object, this function returns a borrowed reference
115+
// to the dtype property of the array
116+
static PyArray_Descr *
117+
string_discover_descriptor_from_pyobject(PyTypeObject *cls, PyObject *obj)
118+
{
119+
PyObject *val = get_value(obj, (StringDType_type *)cls);
120+
if (val == NULL) {
121+
return NULL;
122+
}
123+
124+
PyArray_Descr *ret = (PyArray_Descr *)new_stringdtype_instance(cls);
125+
if (ret == NULL) {
126+
return NULL;
127+
}
128+
return ret;
129+
}
130+
130131
// Take a python object `obj` and insert it into the array of dtype `descr` at
131132
// the position given by dataptr.
132133
static int
133134
stringdtype_setitem(StringDTypeObject *descr, PyObject *obj, char **dataptr)
134135
{
135136
// borrow reference
136137
PyObject *na_object = ((StringDType_type *)Py_TYPE(descr))->na_object;
137-
PyObject *val_obj = get_value(obj, na_object);
138+
PyObject *val_obj = get_value(obj, (StringDType_type *)Py_TYPE(descr));
138139

139140
if (val_obj == NULL) {
140141
return -1;
@@ -396,9 +397,10 @@ stringdtype_repr(StringDTypeObject *self)
396397
static int PICKLE_VERSION = 1;
397398

398399
static PyObject *
399-
stringdtype__reduce__(StringDTypeObject *NPY_UNUSED(self))
400+
stringdtype__reduce__(StringDTypeObject *self)
400401
{
401402
PyObject *ret, *mod, *obj, *state;
403+
StringDType_type *s_type = (StringDType_type *)Py_TYPE(self);
402404

403405
ret = PyTuple_New(3);
404406
if (ret == NULL) {
@@ -411,7 +413,12 @@ stringdtype__reduce__(StringDTypeObject *NPY_UNUSED(self))
411413
return NULL;
412414
}
413415

414-
obj = PyObject_GetAttrString(mod, "StringDType");
416+
if (s_type->na_object == NA_OBJ) {
417+
obj = PyObject_GetAttrString(mod, "StringDType");
418+
}
419+
else {
420+
obj = PyObject_GetAttrString(mod, "PandasStringDType");
421+
}
415422
Py_DECREF(mod);
416423
if (obj == NULL) {
417424
Py_DECREF(ret);

stringdtype/stringdtype/src/main.c

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,8 @@ _memory_usage(PyObject *NPY_UNUSED(self), PyObject *obj)
2323
PyArray_Descr *descr = PyArray_DESCR(arr);
2424
PyArray_DTypeMeta *dtype = NPY_DTYPE(descr);
2525

26-
if (dtype != (PyArray_DTypeMeta *)&StringDType) {
26+
if (dtype != (PyArray_DTypeMeta *)&StringDType &&
27+
dtype != (PyArray_DTypeMeta *)&PandasStringDType) {
2728
PyErr_SetString(PyExc_TypeError,
2829
"can only be called with a StringDType array");
2930
return NULL;

stringdtype/stringdtype/src/umath.c

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -391,6 +391,17 @@ init_ufuncs(void)
391391
goto error;
392392
}
393393

394+
PyArray_DTypeMeta *peq_dtypes[] = {(PyArray_DTypeMeta *)&PandasStringDType,
395+
(PyArray_DTypeMeta *)&PandasStringDType,
396+
&PyArray_BoolDType};
397+
398+
if (init_ufunc(numpy, "equal", peq_dtypes,
399+
&string_equal_resolve_descriptors,
400+
&string_equal_strided_loop, "string_equal", 2, 1,
401+
NPY_NO_CASTING, 0) < 0) {
402+
goto error;
403+
}
404+
394405
PyArray_DTypeMeta *promoter_dtypes[2][3] = {
395406
{(PyArray_DTypeMeta *)&StringDType, &PyArray_UnicodeDType,
396407
&PyArray_BoolDType},
@@ -406,6 +417,21 @@ init_ufuncs(void)
406417
goto error;
407418
}
408419

420+
PyArray_DTypeMeta *p_promoter_dtypes[2][3] = {
421+
{(PyArray_DTypeMeta *)&PandasStringDType, &PyArray_UnicodeDType,
422+
&PyArray_BoolDType},
423+
{&PyArray_UnicodeDType, (PyArray_DTypeMeta *)&PandasStringDType,
424+
&PyArray_BoolDType},
425+
};
426+
427+
if (add_promoter(numpy, "equal", p_promoter_dtypes[0]) < 0) {
428+
goto error;
429+
}
430+
431+
if (add_promoter(numpy, "equal", p_promoter_dtypes[1]) < 0) {
432+
goto error;
433+
}
434+
409435
PyArray_DTypeMeta *isnan_dtypes[] = {(PyArray_DTypeMeta *)&StringDType,
410436
&PyArray_BoolDType};
411437

@@ -416,11 +442,22 @@ init_ufuncs(void)
416442
goto error;
417443
}
418444

445+
PyArray_DTypeMeta *p_isnan_dtypes[] = {
446+
(PyArray_DTypeMeta *)&PandasStringDType, &PyArray_BoolDType};
447+
448+
if (init_ufunc(numpy, "isnan", p_isnan_dtypes,
449+
&string_isnan_resolve_descriptors,
450+
&string_isnan_strided_loop, "string_isnan", 1, 1,
451+
NPY_NO_CASTING, 0) < 0) {
452+
goto error;
453+
}
454+
419455
PyArray_DTypeMeta *minmax_dtypes[] = {
420456
(PyArray_DTypeMeta *)&StringDType,
421457
(PyArray_DTypeMeta *)&StringDType,
422458
(PyArray_DTypeMeta *)&StringDType,
423459
};
460+
424461
if (init_ufunc(numpy, "maximum", minmax_dtypes, NULL,
425462
&maximum_strided_loop, "string_maximum", 2, 1,
426463
NPY_NO_CASTING, 0) < 0) {
@@ -432,16 +469,46 @@ init_ufuncs(void)
432469
goto error;
433470
}
434471

472+
PyArray_DTypeMeta *p_minmax_dtypes[] = {
473+
(PyArray_DTypeMeta *)&PandasStringDType,
474+
(PyArray_DTypeMeta *)&PandasStringDType,
475+
(PyArray_DTypeMeta *)&PandasStringDType,
476+
};
477+
478+
if (init_ufunc(numpy, "maximum", p_minmax_dtypes, NULL,
479+
&maximum_strided_loop, "string_maximum", 2, 1,
480+
NPY_NO_CASTING, 0) < 0) {
481+
goto error;
482+
}
483+
484+
if (init_ufunc(numpy, "minimum", p_minmax_dtypes, NULL,
485+
&minimum_strided_loop, "string_minimum", 2, 1,
486+
NPY_NO_CASTING, 0) < 0) {
487+
goto error;
488+
}
489+
435490
PyArray_DTypeMeta *add_dtypes[] = {
436491
(PyArray_DTypeMeta *)&StringDType,
437492
(PyArray_DTypeMeta *)&StringDType,
438493
(PyArray_DTypeMeta *)&StringDType,
439494
};
495+
440496
if (init_ufunc(numpy, "add", add_dtypes, NULL, &add_strided_loop,
441497
"string_add", 2, 1, NPY_NO_CASTING, 0) < 0) {
442498
goto error;
443499
}
444500

501+
PyArray_DTypeMeta *p_add_dtypes[] = {
502+
(PyArray_DTypeMeta *)&PandasStringDType,
503+
(PyArray_DTypeMeta *)&PandasStringDType,
504+
(PyArray_DTypeMeta *)&PandasStringDType,
505+
};
506+
507+
if (init_ufunc(numpy, "add", p_add_dtypes, NULL, &add_strided_loop,
508+
"string_add", 2, 1, NPY_NO_CASTING, 0) < 0) {
509+
goto error;
510+
}
511+
445512
Py_DECREF(numpy);
446513
return 0;
447514

0 commit comments

Comments
 (0)