Skip to content

Commit 865c973

Browse files
committed
fixes for StringDType
1 parent 3af3502 commit 865c973

File tree

5 files changed

+96
-32
lines changed

5 files changed

+96
-32
lines changed

asciidtype/asciidtype/scalar.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,3 +6,11 @@ def __new__(cls, value, dtype):
66
instance = super().__new__(cls, value)
77
instance.dtype = dtype
88
return instance
9+
10+
def partition(self, sep):
11+
ret = super().partition(sep)
12+
return (str(ret[0]), str(ret[1]), str(ret[2]))
13+
14+
def rpartition(self, sep):
15+
ret = super().rpartition(sep)
16+
return (str(ret[0]), str(ret[1]), str(ret[2]))

asciidtype/asciidtype/src/dtype.c

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -88,15 +88,31 @@ static PyArray_Descr *
8888
ascii_discover_descriptor_from_pyobject(PyArray_DTypeMeta *NPY_UNUSED(cls),
8989
PyObject *obj)
9090
{
91-
if (Py_TYPE(obj) != ASCIIScalar_Type) {
91+
PyTypeObject *obj_type = Py_TYPE(obj);
92+
PyArray_Descr *ret = NULL;
93+
if (obj_type != ASCIIScalar_Type) {
94+
if (PyUnicode_Check(obj)) {
95+
if (!PyUnicode_IS_ASCII(obj)) {
96+
PyErr_SetString(
97+
PyExc_TypeError,
98+
"Can only store strings or bytes convertible to ASCII "
99+
"in a ASCIIDType array.");
100+
return NULL;
101+
}
102+
ret = (PyArray_Descr *)new_asciidtype_instance(
103+
(long)PyUnicode_GetLength(obj));
104+
}
105+
// could do bytes too if we want
92106
PyErr_SetString(PyExc_TypeError,
93-
"Can only store ASCIIScalar in a ASCIIDType array.");
107+
"Can only store strings or bytes convertible to ASCII "
108+
"in a ASCIIDType array.");
94109
return NULL;
95110
}
96-
97-
PyArray_Descr *ret = (PyArray_Descr *)PyObject_GetAttrString(obj, "dtype");
98-
if (ret == NULL) {
99-
return NULL;
111+
else {
112+
ret = (PyArray_Descr *)PyObject_GetAttrString(obj, "dtype");
113+
if (ret == NULL) {
114+
return NULL;
115+
}
100116
}
101117
return ret;
102118
}

stringdtype/README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
1-
# A dtype that stores ASCII data
1+
# A dtype that stores pointers to strings
22

3-
This is a simple proof-of-concept dtype using the (as of late 2022) experimental
3+
This is a simple proof-of-concept dtype using the (as of early 2023) experimental
44
[new dtype
55
implementation](https://numpy.org/neps/nep-0041-improved-dtype-support.html) in
66
NumPy.

stringdtype/stringdtype/scalar.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,16 @@
11
"""A scalar type needed by the dtype machinery."""
22

33

4-
class StringScalar:
5-
def __init__(self, value):
6-
self.value = value
4+
class StringScalar(str):
5+
def __new__(cls, value, dtype):
6+
instance = super().__new__(cls, value)
7+
instance.dtype = dtype
8+
return instance
79

8-
def __str__(self):
9-
return str(self.value)
10+
def partition(self, sep):
11+
ret = super().partition(sep)
12+
return (str(ret[0]), str(ret[1]), str(ret[2]))
1013

11-
def __repr__(self):
12-
return repr(self.value)
14+
def rpartition(self, sep):
15+
ret = super().rpartition(sep)
16+
return (str(ret[0]), str(ret[1]), str(ret[2]))

stringdtype/stringdtype/src/dtype.c

Lines changed: 53 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -21,20 +21,16 @@ new_stringdtype_instance(void)
2121
return new;
2222
}
2323

24-
//
25-
// This is used to determine the correct dtype to return when operations mix
26-
// dtypes (I think?). For now just return the first one.
27-
//
24+
/*
25+
* This is used to determine the correct dtype to return when dealing
26+
* with a mix of different dtypes (for example when creating an array
27+
* from a list of scalars). Since StringDType doesn't have any parameters,
28+
* we can safely always return the first one.
29+
*/
2830
static StringDTypeObject *
2931
common_instance(StringDTypeObject *dtype1, StringDTypeObject *dtype2)
3032
{
31-
if (!PyObject_RichCompareBool((PyObject *)dtype1, (PyObject *)dtype2,
32-
Py_EQ)) {
33-
PyErr_SetString(
34-
PyExc_RuntimeError,
35-
"common_instance called on unequal StringDType instances");
36-
return NULL;
37-
}
33+
Py_INCREF(dtype1);
3834
return dtype1;
3935
}
4036

@@ -66,34 +62,65 @@ string_discover_descriptor_from_pyobject(PyArray_DTypeMeta *NPY_UNUSED(cls),
6662
return ret;
6763
}
6864

65+
static PyObject *
66+
get_value(PyObject *scalar)
67+
{
68+
PyObject *ret_bytes = NULL;
69+
PyTypeObject *scalar_type = Py_TYPE(scalar);
70+
// FIXME: handle bytes too
71+
if ((scalar_type == &PyUnicode_Type) ||
72+
(scalar_type == StringScalar_Type)) {
73+
// attempt to encode as UTF8
74+
ret_bytes = PyUnicode_AsUTF8String(scalar);
75+
if (ret_bytes == NULL) {
76+
PyErr_SetString(
77+
PyExc_TypeError,
78+
"Can only store UTF8 text in a StringDType array.");
79+
return NULL;
80+
}
81+
}
82+
else {
83+
PyErr_SetString(PyExc_TypeError,
84+
"Can only store String text in a StringDType array.");
85+
return NULL;
86+
}
87+
return ret_bytes;
88+
}
89+
6990
// Take a Python object `obj` and insert it into the array of dtype `descr` at
7091
// the position given by dataptr.
7192
static int
7293
stringdtype_setitem(StringDTypeObject *descr, PyObject *obj, char **dataptr)
7394
{
74-
char *val = PyBytes_AsString(obj);
75-
if (val == NULL) {
95+
PyObject *val_obj = get_value(obj);
96+
char *val = NULL;
97+
Py_ssize_t length = 0;
98+
if (PyBytes_AsStringAndSize(val_obj, &val, &length) == -1) {
7699
return -1;
77100
}
78101

79-
*dataptr = malloc(sizeof(char) * strlen(val));
80-
strcpy(*dataptr, val);
102+
*dataptr = malloc(sizeof(char) * length + 1);
103+
strncpy(*dataptr, val, length + 1);
104+
Py_DECREF(val_obj);
81105
return 0;
82106
}
83107

84108
static PyObject *
85109
stringdtype_getitem(StringDTypeObject *descr, char **dataptr)
86110
{
87-
PyObject *val_obj = PyBytes_FromString(*dataptr);
111+
PyObject *val_obj = PyUnicode_FromString(*dataptr);
112+
88113
if (val_obj == NULL) {
89114
return NULL;
90115
}
91116

92117
PyObject *res = PyObject_CallFunctionObjArgs((PyObject *)StringScalar_Type,
93-
val_obj, NULL);
118+
val_obj, descr, NULL);
119+
94120
if (res == NULL) {
95121
return NULL;
96122
}
123+
97124
Py_DECREF(val_obj);
98125

99126
return res;
@@ -119,6 +146,15 @@ static PyType_Slot StringDType_Slots[] = {
119146
static PyObject *
120147
stringdtype_new(PyTypeObject *NPY_UNUSED(cls), PyObject *args, PyObject *kwds)
121148
{
149+
static char *kwargs_strs[] = {"size", NULL};
150+
151+
long size = 0;
152+
153+
if (!PyArg_ParseTupleAndKeywords(args, kwds, "|l:StringDType", kwargs_strs,
154+
&size)) {
155+
return NULL;
156+
}
157+
122158
return (PyObject *)new_stringdtype_instance();
123159
}
124160

0 commit comments

Comments
 (0)