@@ -21,20 +21,16 @@ new_stringdtype_instance(void)
21
21
return new ;
22
22
}
23
23
24
- //
25
- // This is used to determine the correct dtype to return when operations mix
26
- // dtypes (I think?). For now just return the first one.
27
- //
24
+ /*
25
+ * This is used to determine the correct dtype to return when dealing
26
+ * with a mix of different dtypes (for example when creating an array
27
+ * from a list of scalars). Since StringDType doesn't have any parameters,
28
+ * we can safely always return the first one.
29
+ */
28
30
static StringDTypeObject *
29
31
common_instance (StringDTypeObject * dtype1 , StringDTypeObject * dtype2 )
30
32
{
31
- if (!PyObject_RichCompareBool ((PyObject * )dtype1 , (PyObject * )dtype2 ,
32
- Py_EQ )) {
33
- PyErr_SetString (
34
- PyExc_RuntimeError ,
35
- "common_instance called on unequal StringDType instances" );
36
- return NULL ;
37
- }
33
+ Py_INCREF (dtype1 );
38
34
return dtype1 ;
39
35
}
40
36
@@ -66,34 +62,65 @@ string_discover_descriptor_from_pyobject(PyArray_DTypeMeta *NPY_UNUSED(cls),
66
62
return ret ;
67
63
}
68
64
65
+ static PyObject *
66
+ get_value (PyObject * scalar )
67
+ {
68
+ PyObject * ret_bytes = NULL ;
69
+ PyTypeObject * scalar_type = Py_TYPE (scalar );
70
+ // FIXME: handle bytes too
71
+ if ((scalar_type == & PyUnicode_Type ) ||
72
+ (scalar_type == StringScalar_Type )) {
73
+ // attempt to decode as UTF8
74
+ ret_bytes = PyUnicode_AsUTF8String (scalar );
75
+ if (ret_bytes == NULL ) {
76
+ PyErr_SetString (
77
+ PyExc_TypeError ,
78
+ "Can only store UTF8 text in a StringDType array." );
79
+ return NULL ;
80
+ }
81
+ }
82
+ else {
83
+ PyErr_SetString (PyExc_TypeError ,
84
+ "Can only store String text in a StringDType array." );
85
+ return NULL ;
86
+ }
87
+ return ret_bytes ;
88
+ }
89
+
69
90
// Take a python object `obj` and insert it into the array of dtype `descr` at
70
91
// the position given by dataptr.
71
92
static int
72
93
stringdtype_setitem (StringDTypeObject * descr , PyObject * obj , char * * dataptr )
73
94
{
74
- char * val = PyBytes_AsString (obj );
75
- if (val == NULL ) {
95
+ PyObject * val_obj = get_value (obj );
96
+ char * val = NULL ;
97
+ Py_ssize_t length = 0 ;
98
+ if (PyBytes_AsStringAndSize (val_obj , & val , & length ) == -1 ) {
76
99
return -1 ;
77
100
}
78
101
79
- * dataptr = malloc (sizeof (char ) * strlen (val ));
80
- strcpy (* dataptr , val );
102
+ * dataptr = malloc (sizeof (char ) * length + 1 );
103
+ strncpy (* dataptr , val , length + 1 );
104
+ Py_DECREF (val_obj );
81
105
return 0 ;
82
106
}
83
107
84
108
static PyObject *
85
109
stringdtype_getitem (StringDTypeObject * descr , char * * dataptr )
86
110
{
87
- PyObject * val_obj = PyBytes_FromString (* dataptr );
111
+ PyObject * val_obj = PyUnicode_FromString (* dataptr );
112
+
88
113
if (val_obj == NULL ) {
89
114
return NULL ;
90
115
}
91
116
92
117
PyObject * res = PyObject_CallFunctionObjArgs ((PyObject * )StringScalar_Type ,
93
- val_obj , NULL );
118
+ val_obj , descr , NULL );
119
+
94
120
if (res == NULL ) {
95
121
return NULL ;
96
122
}
123
+
97
124
Py_DECREF (val_obj );
98
125
99
126
return res ;
@@ -119,6 +146,15 @@ static PyType_Slot StringDType_Slots[] = {
119
146
static PyObject *
120
147
stringdtype_new (PyTypeObject * NPY_UNUSED (cls ), PyObject * args , PyObject * kwds )
121
148
{
149
+ static char * kwargs_strs [] = {"size" , NULL };
150
+
151
+ long size = 0 ;
152
+
153
+ if (!PyArg_ParseTupleAndKeywords (args , kwds , "|l:StringDType" , kwargs_strs ,
154
+ & size )) {
155
+ return NULL ;
156
+ }
157
+
122
158
return (PyObject * )new_stringdtype_instance ();
123
159
}
124
160
0 commit comments