Skip to content

Commit cabab7f

Browse files
committed
add StringDType to Unicode_ cast
1 parent bc3f8be commit cabab7f

File tree

3 files changed

+194
-48
lines changed

3 files changed

+194
-48
lines changed

stringdtype/stringdtype/src/casts.c

Lines changed: 185 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,11 @@
33
#include "dtype.h"
44

55
static NPY_CASTING
6-
string_resolve_descriptors(PyObject *NPY_UNUSED(self),
7-
PyArray_DTypeMeta *NPY_UNUSED(dtypes[2]),
8-
PyArray_Descr *given_descrs[2],
9-
PyArray_Descr *loop_descrs[2],
10-
npy_intp *NPY_UNUSED(view_offset))
6+
string_to_string_resolve_descriptors(PyObject *NPY_UNUSED(self),
7+
PyArray_DTypeMeta *NPY_UNUSED(dtypes[2]),
8+
PyArray_Descr *given_descrs[2],
9+
PyArray_Descr *loop_descrs[2],
10+
npy_intp *NPY_UNUSED(view_offset))
1111
{
1212
Py_INCREF(given_descrs[0]);
1313
loop_descrs[0] = given_descrs[0];
@@ -50,7 +50,7 @@ string_to_string(PyArrayMethod_Context *context, char *const data[],
5050
static PyArray_DTypeMeta *s2s_dtypes[2] = {NULL, NULL};
5151

5252
static PyType_Slot s2s_slots[] = {
53-
{NPY_METH_resolve_descriptors, &string_resolve_descriptors},
53+
{NPY_METH_resolve_descriptors, &string_to_string_resolve_descriptors},
5454
{NPY_METH_strided_loop, &string_to_string},
5555
{NPY_METH_unaligned_strided_loop, &string_to_string},
5656
{0, NULL}};
@@ -65,33 +65,57 @@ PyArrayMethod_Spec StringToStringCastSpec = {
6565
.slots = s2s_slots,
6666
};
6767

68+
static NPY_CASTING
69+
unicode_to_string_resolve_descriptors(PyObject *NPY_UNUSED(self),
70+
PyArray_DTypeMeta *NPY_UNUSED(dtypes[2]),
71+
PyArray_Descr *given_descrs[2],
72+
PyArray_Descr *loop_descrs[2],
73+
npy_intp *NPY_UNUSED(view_offset))
74+
{
75+
Py_INCREF(given_descrs[0]);
76+
loop_descrs[0] = given_descrs[0];
77+
78+
if (given_descrs[1] == NULL) {
79+
loop_descrs[1] = (PyArray_Descr *)new_stringdtype_instance();
80+
}
81+
else {
82+
Py_INCREF(given_descrs[1]);
83+
loop_descrs[1] = given_descrs[1];
84+
}
85+
86+
return NPY_SAFE_CASTING;
87+
}
88+
6889
// converts UCS4 code point to 4-byte char*; assumes c is a zero-filled 4 byte
69-
// array returns -1 if the code point is not a valid unicode code point, the
70-
// number of bytes in the in the UTF-8 character on success
90+
// array. Returns -1 if the code point is not a valid unicode code point,
91+
// returns the number of bytes in the UTF-8 character on success
7192
static int
72-
ucs4_to_utf8_char(const Py_UCS4 code, char *in)
93+
ucs4_code_to_utf8_char(const Py_UCS4 code, char *c)
7394
{
7495
if (code <= 0x7F) {
75-
// ASCII
76-
in[0] = (char)code;
96+
// 0zzzzzzz -> 0zzzzzzz
97+
c[0] = (char)code;
7798
return 1;
7899
}
79100
else if (code <= 0x07FF) {
80-
in[0] = (0xc0 | (code >> 6));
81-
in[1] = (0x80 | (code & 0x3f));
101+
// 00000yyy yyzzzzzz -> 110yyyyy 10zzzzzz
102+
c[0] = (0xC0 | (code >> 6));
103+
c[1] = (0x80 | (code & 0x3F));
82104
return 2;
83105
}
84106
else if (code <= 0xFFFF) {
85-
in[0] = (0xe0 | (code >> 12));
86-
in[1] = (0x80 | ((code >> 6) & 0x3f));
87-
in[2] = (0x80 | (code & 0x3f));
107+
// xxxxyyyy yyzzzzzz -> 1110xxxx 10yyyyyy 10zzzzzz
108+
c[0] = (0xe0 | (code >> 12));
109+
c[1] = (0x80 | ((code >> 6) & 0x3f));
110+
c[2] = (0x80 | (code & 0x3f));
88111
return 3;
89112
}
90113
else if (code <= 0x10FFFF) {
91-
in[0] = (0xf0 | (code >> 18));
92-
in[1] = (0x80 | ((code >> 12) & 0x3f));
93-
in[2] = (0x80 | ((code >> 6) & 0x3f));
94-
in[3] = (0x80 | (code & 0x3f));
114+
// 000wwwxx xxxxyyyy yyzzzzzz -> 11110www 10xxxxxx 10yyyyyy 10zzzzzz
115+
c[0] = (0xf0 | (code >> 18));
116+
c[1] = (0x80 | ((code >> 12) & 0x3f));
117+
c[2] = (0x80 | ((code >> 6) & 0x3f));
118+
c[3] = (0x80 | (code & 0x3f));
95119
return 4;
96120
}
97121
return -1;
@@ -106,28 +130,31 @@ unicode_to_string(PyArrayMethod_Context *context, char *const data[],
106130
long in_size = (descrs[0]->elsize) / 4;
107131

108132
npy_intp N = dimensions[0];
109-
char *in = data[0];
133+
Py_UCS4 *in = (Py_UCS4 *)data[0];
110134
char **out = (char **)data[1];
111-
npy_intp in_stride = strides[0];
135+
136+
// 4 bytes per UCS4 character
137+
npy_intp in_stride = strides[0] / 4;
112138
// strides are in bytes but pointer offsets are in pointer widths, so
113139
// divide by the element size (one pointer width) to get the pointer offset
114140
npy_intp out_stride = strides[1] / context->descriptors[1]->elsize;
115141

116142
while (N--) {
117143
// pessimistically allocate 4 bytes per allowed character
118-
char *out_buf = calloc(in_size * 4 + 1, sizeof(char));
144+
// plus one byte for the null terminator
145+
char *out_buf = malloc((in_size * 4 + 1) * sizeof(char));
119146
size_t out_num_bytes = 0;
120147
for (int i = 0; i < in_size; i++) {
121148
// get code point
122-
Py_UCS4 code = ((Py_UCS4 *)in)[i];
149+
Py_UCS4 code = in[i];
123150

124151
if (code == 0) {
125152
break;
126153
}
127154

128155
// convert codepoint to UTF8 bytes
129156
char utf8_c[4] = {0};
130-
size_t num_bytes = ucs4_to_utf8_char(code, utf8_c);
157+
size_t num_bytes = ucs4_code_to_utf8_char(code, utf8_c);
131158
out_num_bytes += num_bytes;
132159

133160
if (num_bytes == -1) {
@@ -159,7 +186,6 @@ unicode_to_string(PyArrayMethod_Context *context, char *const data[],
159186
// set out to the address of the beginning of the string
160187
out[0] = out_buf;
161188

162-
// increment out and in by strides
163189
in += in_stride;
164190
out += out_stride;
165191
}
@@ -168,12 +194,126 @@ unicode_to_string(PyArrayMethod_Context *context, char *const data[],
168194
}
169195

170196
static PyType_Slot u2s_slots[] = {
171-
{NPY_METH_resolve_descriptors, &string_resolve_descriptors},
197+
{NPY_METH_resolve_descriptors, &unicode_to_string_resolve_descriptors},
172198
{NPY_METH_strided_loop, &unicode_to_string},
173199
{0, NULL}};
174200

175201
static char *u2s_name = "cast_Unicode_to_StringDType";
176202

203+
static NPY_CASTING
204+
string_to_unicode_resolve_descriptors(PyObject *NPY_UNUSED(self),
205+
PyArray_DTypeMeta *NPY_UNUSED(dtypes[2]),
206+
PyArray_Descr *given_descrs[2],
207+
PyArray_Descr *loop_descrs[2],
208+
npy_intp *NPY_UNUSED(view_offset))
209+
{
210+
Py_INCREF(given_descrs[0]);
211+
loop_descrs[0] = given_descrs[0];
212+
213+
if (given_descrs[1] == NULL) {
214+
// currently there's no way to determine the correct output
215+
// size, so set an error and bail
216+
PyErr_SetString(
217+
PyExc_TypeError,
218+
"Casting from StringDType to a fixed-width dtype with an "
219+
"unspecified size is not currently supported, specify "
220+
"an explicit size for the output dtype instead.");
221+
return (NPY_CASTING)-1;
222+
}
223+
else {
224+
Py_INCREF(given_descrs[1]);
225+
loop_descrs[1] = given_descrs[1];
226+
}
227+
228+
return NPY_UNSAFE_CASTING;
229+
}
230+
231+
// Given UTF-8 bytes in *c*, sets *code* to the corresponding unicode
232+
// codepoint for the next character, returning the size of the character in
233+
// bytes. Does not do any validation or error checking: assumes *c* is valid
234+
// utf-8
235+
static size_t
236+
utf8_char_to_ucs4_code(unsigned char *c, Py_UCS4 *code)
237+
{
238+
if (c[0] <= 0x7F) {
239+
// 0zzzzzzz -> 0zzzzzzz
240+
*code = (Py_UCS4)(c[0]);
241+
return 1;
242+
}
243+
else if (c[0] <= 0xDF) {
244+
// 110yyyyy 10zzzzzz -> 00000yyy yyzzzzzz
245+
*code = (Py_UCS4)(((c[0] << 6) + c[1]) - ((0xC0 << 6) + 0x80));
246+
return 2;
247+
}
248+
else if (c[0] <= 0xEF) {
249+
// 1110xxxx 10yyyyyy 10zzzzzz -> xxxxyyyy yyzzzzzz
250+
*code = (Py_UCS4)(((c[0] << 12) + (c[1] << 6) + c[2]) -
251+
((0xE0 << 12) + (0x80 << 6) + 0x80));
252+
return 3;
253+
}
254+
else {
255+
// 11110www 10xxxxxx 10yyyyyy 10zzzzzz -> 000wwwxx xxxxyyyy yyzzzzzz
256+
*code = (Py_UCS4)(((c[0] << 18) + (c[1] << 12) + (c[2] << 6) + c[3]) -
257+
((0xF0 << 18) + (0x80 << 12) + (0x80 << 6) + 0x80));
258+
return 4;
259+
}
260+
}
261+
262+
static int
263+
string_to_unicode(PyArrayMethod_Context *context, char *const data[],
264+
npy_intp const dimensions[], npy_intp const strides[],
265+
NpyAuxData *NPY_UNUSED(auxdata))
266+
{
267+
npy_intp N = dimensions[0];
268+
char **in = (char **)data[0];
269+
Py_UCS4 *out = (Py_UCS4 *)data[1];
270+
// strides are in bytes but pointer offsets are in pointer widths, so
271+
// divide by the element size (one pointer width) to get the pointer offset
272+
npy_intp in_stride = strides[0] / context->descriptors[0]->elsize;
273+
// 4 bytes per UCS4 character
274+
npy_intp out_stride = strides[1] / 4;
275+
// max number of 4 byte UCS4 characters that can fit in the output
276+
long max_out_size = (context->descriptors[1]->elsize) / 4;
277+
278+
while (N--) {
279+
unsigned char *this_string = (unsigned char *)*in;
280+
281+
for (int i = 0; i < max_out_size; i++) {
282+
Py_UCS4 code;
283+
284+
// get code point for character this_string is currently pointing
285+
// to
286+
size_t num_bytes = utf8_char_to_ucs4_code(this_string, &code);
287+
288+
// move to next character
289+
this_string += num_bytes;
290+
291+
// set output codepoint
292+
out[i] = code;
293+
294+
// check if this is the null terminator
295+
if (code == 0) {
296+
// fill all remaining characters (if any) with zero
297+
for (int j = i + 1; j < max_out_size; j++) {
298+
out[j] = 0;
299+
}
300+
break;
301+
}
302+
}
303+
in += in_stride;
304+
out += out_stride;
305+
}
306+
307+
return 0;
308+
}
309+
310+
static PyType_Slot s2u_slots[] = {
311+
{NPY_METH_resolve_descriptors, &string_to_unicode_resolve_descriptors},
312+
{NPY_METH_strided_loop, &string_to_unicode},
313+
{0, NULL}};
314+
315+
static char *s2u_name = "cast_StringDType_to_Unicode";
316+
177317
PyArrayMethod_Spec **
178318
get_casts(void)
179319
{
@@ -192,10 +332,26 @@ get_casts(void)
192332
UnicodeToStringCastSpec->dtypes = u2s_dtypes;
193333
UnicodeToStringCastSpec->slots = u2s_slots;
194334

195-
PyArrayMethod_Spec **casts = malloc(3 * sizeof(PyArrayMethod_Spec *));
335+
PyArray_DTypeMeta **s2u_dtypes = malloc(2 * sizeof(PyArray_DTypeMeta *));
336+
s2u_dtypes[0] = NULL;
337+
s2u_dtypes[1] = &PyArray_UnicodeDType;
338+
339+
PyArrayMethod_Spec *StringToUnicodeCastSpec =
340+
malloc(sizeof(PyArrayMethod_Spec));
341+
342+
StringToUnicodeCastSpec->name = s2u_name;
343+
StringToUnicodeCastSpec->nin = 1;
344+
StringToUnicodeCastSpec->nout = 1;
345+
StringToUnicodeCastSpec->casting = NPY_SAFE_CASTING;
346+
StringToUnicodeCastSpec->flags = NPY_METH_NO_FLOATINGPOINT_ERRORS;
347+
StringToUnicodeCastSpec->dtypes = s2u_dtypes;
348+
StringToUnicodeCastSpec->slots = s2u_slots;
349+
350+
PyArrayMethod_Spec **casts = malloc(4 * sizeof(PyArrayMethod_Spec *));
196351
casts[0] = &StringToStringCastSpec;
197352
casts[1] = UnicodeToStringCastSpec;
198-
casts[2] = NULL;
353+
casts[2] = StringToUnicodeCastSpec;
354+
casts[3] = NULL;
199355

200356
return casts;
201357
}

stringdtype/stringdtype/src/dtype.c

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -228,6 +228,8 @@ init_string_dtype(void)
228228

229229
free(StringDType_DTypeSpec.casts[1]->dtypes);
230230
free(StringDType_DTypeSpec.casts[1]);
231+
free(StringDType_DTypeSpec.casts[2]->dtypes);
232+
free(StringDType_DTypeSpec.casts[2]);
231233
free(StringDType_DTypeSpec.casts);
232234

233235
return 0;

stringdtype/tests/test_stringdtype.py

Lines changed: 7 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -62,38 +62,26 @@ def test_bad_scalars(data):
6262
[
6363
["this", "is", "an", "array"],
6464
["€", "", "😊"],
65+
["A¢☃€ 😊", " A☃€¢😊", "☃€😊 A¢", "😊☃A¢ €"],
6566
],
6667
)
67-
def test_cast_to_stringdtype(string_list):
68+
def test_unicode_casts(string_list):
6869
arr = np.array(string_list, dtype=np.unicode_).astype(StringDType())
6970
expected = np.array(string_list, dtype=StringDType())
7071
np.testing.assert_array_equal(arr, expected)
7172

72-
73-
@pytest.mark.xfail(reason="Not yet implemented")
74-
def test_cast_to_unicode_safe(string_list):
7573
arr = np.array(string_list, dtype=StringDType())
7674

7775
np.testing.assert_array_equal(
78-
arr.astype("<U3", casting="safe"), np.array(string_list, dtype="<U3")
76+
arr.astype("U8"), np.array(string_list, dtype="U8")
7977
)
80-
81-
# Safe casting should preserve data
82-
with pytest.raises(TypeError):
83-
arr.astype("<U2", casting="safe")
84-
85-
86-
@pytest.mark.xfail(reason="Not yet implemented")
87-
def test_cast_to_unicode_unsafe(string_list):
88-
arr = np.array(string_list, dtype=StringDType())
89-
78+
np.testing.assert_array_equal(arr.astype("U8").astype(StringDType()), arr)
9079
np.testing.assert_array_equal(
91-
arr.astype("<U3", casting="unsafe"), np.array(string_list, dtype="<U3")
80+
arr.astype("U3"), np.array(string_list, dtype="U3")
9281
)
93-
94-
# Unsafe casting: each element is truncated
9582
np.testing.assert_array_equal(
96-
arr.astype("<U2", casting="unsafe"), np.array(string_list, dtype="<U2")
83+
arr.astype("U3").astype(StringDType()),
84+
np.array([s[:3] for s in string_list], dtype=StringDType()),
9785
)
9886

9987

0 commit comments

Comments
 (0)