Skip to content

Commit 07c5087

Browse files
authored
Merge pull request #25 from ngoldbaum/unicode-string-casts
Unicode <-> string casts and the equal ufunc
2 parents 0dabe4e + bd59d33 commit 07c5087

File tree

7 files changed

+411
-195
lines changed

7 files changed

+411
-195
lines changed

.pre-commit-config.yaml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,3 +98,7 @@ repos:
9898
- id: isort
9999
name: isort (pyi)
100100
types: [pyi]
101+
- repo: https://github.com/psf/black
102+
rev: 22.12.0
103+
hooks:
104+
- id: black

stringdtype/meson.build

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,8 @@ srcs = [
2626
'stringdtype/src/casts.h',
2727
'stringdtype/src/dtype.c',
2828
'stringdtype/src/main.c',
29-
# 'stringdtype/src/umath.c',
30-
# 'stringdtype/src/umath.h',
29+
'stringdtype/src/umath.c',
30+
'stringdtype/src/umath.h',
3131
]
3232

3333
py.install_sources(

stringdtype/stringdtype/src/casts.c

Lines changed: 333 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
#include "casts.h"
22

3+
#include "dtype.h"
4+
35
static NPY_CASTING
46
string_to_string_resolve_descriptors(PyObject *NPY_UNUSED(self),
57
PyArray_DTypeMeta *NPY_UNUSED(dtypes[2]),
@@ -11,8 +13,7 @@ string_to_string_resolve_descriptors(PyObject *NPY_UNUSED(self),
1113
loop_descrs[0] = given_descrs[0];
1214

1315
if (given_descrs[1] == NULL) {
14-
Py_INCREF(given_descrs[0]);
15-
loop_descrs[1] = given_descrs[0];
16+
loop_descrs[1] = (PyArray_Descr *)new_stringdtype_instance();
1617
}
1718
else {
1819
Py_INCREF(given_descrs[1]);
@@ -64,12 +65,340 @@ PyArrayMethod_Spec StringToStringCastSpec = {
6465
.slots = s2s_slots,
6566
};
6667

68+
static NPY_CASTING
69+
unicode_to_string_resolve_descriptors(PyObject *NPY_UNUSED(self),
70+
PyArray_DTypeMeta *NPY_UNUSED(dtypes[2]),
71+
PyArray_Descr *given_descrs[2],
72+
PyArray_Descr *loop_descrs[2],
73+
npy_intp *NPY_UNUSED(view_offset))
74+
{
75+
if (given_descrs[1] == NULL) {
76+
StringDTypeObject *new = new_stringdtype_instance();
77+
if (new == NULL) {
78+
return (NPY_CASTING)-1;
79+
}
80+
loop_descrs[1] = (PyArray_Descr *)new;
81+
}
82+
else {
83+
Py_INCREF(given_descrs[1]);
84+
loop_descrs[1] = given_descrs[1];
85+
}
86+
87+
Py_INCREF(given_descrs[0]);
88+
loop_descrs[0] = given_descrs[0];
89+
90+
return NPY_SAFE_CASTING;
91+
}
92+
93+
// Find the number of bytes, *utf8_bytes*, needed to store the string
94+
// represented by *codepoints* in UTF-8. The array of *codepoints* is
95+
// *max_length* long, but may be padded with null codepoints. *num_codepoints*
96+
// is the number of codepoints that are not trailing null codepoints. Returns
97+
// 0 on success and -1 when an invalid code point is found.
98+
static int
99+
utf8_size(Py_UCS4 *codepoints, long max_length, size_t *num_codepoints,
100+
size_t *utf8_bytes)
101+
{
102+
size_t ucs4len = max_length;
103+
104+
while (ucs4len > 0 && codepoints[ucs4len - 1] == 0) {
105+
ucs4len--;
106+
}
107+
// ucs4len is now the number of codepoints that aren't trailing nulls.
108+
109+
size_t num_bytes = 0;
110+
111+
for (int i = 0; i < ucs4len; i++) {
112+
Py_UCS4 code = codepoints[i];
113+
114+
if (code <= 0x7F) {
115+
num_bytes += 1;
116+
}
117+
else if (code <= 0x07FF) {
118+
num_bytes += 2;
119+
}
120+
else if (code <= 0xFFFF) {
121+
if ((code >= 0xD800) && (code <= 0xDFFF)) {
122+
// surrogates are invalid UCS4 code points
123+
return -1;
124+
}
125+
num_bytes += 3;
126+
}
127+
else if (code <= 0x10FFFF) {
128+
num_bytes += 4;
129+
}
130+
else {
131+
// codepoint is outside the valid unicode range
132+
return -1;
133+
}
134+
}
135+
136+
*num_codepoints = ucs4len;
137+
*utf8_bytes = num_bytes;
138+
139+
return 0;
140+
}
141+
142+
// Converts UCS4 code point *code* to 4-byte character array *c*. Assumes *c*
143+
// is a zero-filled 4 byte array and *code* is a valid codepoint and does not
144+
// do any error checking! Returns the number of bytes in the UTF-8 character.
145+
static size_t
146+
ucs4_code_to_utf8_char(const Py_UCS4 code, char *c)
147+
{
148+
if (code <= 0x7F) {
149+
// 0zzzzzzz -> 0zzzzzzz
150+
c[0] = (char)code;
151+
return 1;
152+
}
153+
else if (code <= 0x07FF) {
154+
// 00000yyy yyzzzzzz -> 110yyyyy 10zzzzzz
155+
c[0] = (0xC0 | (code >> 6));
156+
c[1] = (0x80 | (code & 0x3F));
157+
return 2;
158+
}
159+
else if (code <= 0xFFFF) {
160+
// xxxxyyyy yyzzzzzz -> 110yyyyy 10zzzzzz
161+
c[0] = (0xe0 | (code >> 12));
162+
c[1] = (0x80 | ((code >> 6) & 0x3f));
163+
c[2] = (0x80 | (code & 0x3f));
164+
return 3;
165+
}
166+
else {
167+
// 00wwwxx xxxxyyyy yyzzzzzz -> 11110www 10xxxxxx 10yyyyyy 10zzzzzz
168+
c[0] = (0xf0 | (code >> 18));
169+
c[1] = (0x80 | ((code >> 12) & 0x3f));
170+
c[2] = (0x80 | ((code >> 6) & 0x3f));
171+
c[3] = (0x80 | (code & 0x3f));
172+
return 4;
173+
}
174+
}
175+
176+
static int
177+
unicode_to_string(PyArrayMethod_Context *context, char *const data[],
178+
npy_intp const dimensions[], npy_intp const strides[],
179+
NpyAuxData *NPY_UNUSED(auxdata))
180+
{
181+
PyArray_Descr **descrs = context->descriptors;
182+
long max_in_size = (descrs[0]->elsize) / 4;
183+
184+
npy_intp N = dimensions[0];
185+
Py_UCS4 *in = (Py_UCS4 *)data[0];
186+
char **out = (char **)data[1];
187+
188+
// 4 bytes per UCS4 character
189+
npy_intp in_stride = strides[0] / 4;
190+
// strides are in bytes but pointer offsets are in pointer widths, so
191+
// divide by the element size (one pointer width) to get the pointer offset
192+
npy_intp out_stride = strides[1] / context->descriptors[1]->elsize;
193+
194+
while (N--) {
195+
size_t out_num_bytes = 0;
196+
size_t num_codepoints = 0;
197+
if (utf8_size(in, max_in_size, &num_codepoints, &out_num_bytes) ==
198+
-1) {
199+
// invalid codepoint found so acquire GIL, set error, return
200+
PyGILState_STATE gstate;
201+
gstate = PyGILState_Ensure();
202+
PyErr_SetString(PyExc_TypeError,
203+
"Invalid unicode code point found");
204+
PyGILState_Release(gstate);
205+
return -1;
206+
}
207+
// one extra byte for null terminator
208+
char *out_buf = malloc((out_num_bytes + 1) * sizeof(char));
209+
for (int i = 0; i < num_codepoints; i++) {
210+
// get code point
211+
Py_UCS4 code = in[i];
212+
213+
// will be filled with UTF-8 bytes
214+
char utf8_c[4] = {0};
215+
216+
// we already checked for invalid code points above,
217+
// so no need to do error checking here
218+
size_t num_bytes = ucs4_code_to_utf8_char(code, utf8_c);
219+
220+
// copy utf8_c into out_buf
221+
strncpy(out_buf, utf8_c, num_bytes);
222+
223+
// increment out_buf by the size of the character
224+
out_buf += num_bytes;
225+
}
226+
227+
// reset out_buf to the beginning of the string
228+
out_buf -= out_num_bytes;
229+
230+
// pad string with null character
231+
out_buf[out_num_bytes] = '\0';
232+
233+
// set out to the address of the beginning of the string
234+
out[0] = out_buf;
235+
236+
in += in_stride;
237+
out += out_stride;
238+
}
239+
240+
return 0;
241+
}
242+
243+
static PyType_Slot u2s_slots[] = {
244+
{NPY_METH_resolve_descriptors, &unicode_to_string_resolve_descriptors},
245+
{NPY_METH_strided_loop, &unicode_to_string},
246+
{0, NULL}};
247+
248+
static char *u2s_name = "cast_Unicode_to_StringDType";
249+
250+
static NPY_CASTING
251+
string_to_unicode_resolve_descriptors(PyObject *NPY_UNUSED(self),
252+
PyArray_DTypeMeta *NPY_UNUSED(dtypes[2]),
253+
PyArray_Descr *given_descrs[2],
254+
PyArray_Descr *loop_descrs[2],
255+
npy_intp *NPY_UNUSED(view_offset))
256+
{
257+
if (given_descrs[1] == NULL) {
258+
// currently there's no way to determine the correct output
259+
// size, so set an error and bail
260+
PyErr_SetString(
261+
PyExc_TypeError,
262+
"Casting from StringDType to a fixed-width dtype with an "
263+
"unspecified size is not currently supported, specify "
264+
"an explicit size for the output dtype instead.");
265+
return (NPY_CASTING)-1;
266+
}
267+
else {
268+
Py_INCREF(given_descrs[1]);
269+
loop_descrs[1] = given_descrs[1];
270+
}
271+
272+
Py_INCREF(given_descrs[0]);
273+
loop_descrs[0] = given_descrs[0];
274+
275+
return NPY_UNSAFE_CASTING;
276+
}
277+
278+
// Given UTF-8 bytes in *c*, sets *code* to the corresponding unicode
279+
// codepoint for the next character, returning the size of the character in
280+
// bytes. Does not do any validation or error checking: assumes *c* is valid
281+
// utf-8
282+
static size_t
283+
utf8_char_to_ucs4_code(unsigned char *c, Py_UCS4 *code)
284+
{
285+
if (c[0] <= 0x7F) {
286+
// 0zzzzzzz -> 0zzzzzzz
287+
*code = (Py_UCS4)(c[0]);
288+
return 1;
289+
}
290+
else if (c[0] <= 0xDF) {
291+
// 110yyyyy 10zzzzzz -> 00000yyy yyzzzzzz
292+
*code = (Py_UCS4)(((c[0] << 6) + c[1]) - ((0xC0 << 6) + 0x80));
293+
return 2;
294+
}
295+
else if (c[0] <= 0xEF) {
296+
// 1110xxxx 10yyyyyy 10zzzzzz -> xxxxyyyy yyzzzzzz
297+
*code = (Py_UCS4)(((c[0] << 12) + (c[1] << 6) + c[2]) -
298+
((0xE0 << 12) + (0x80 << 6) + 0x80));
299+
return 3;
300+
}
301+
else {
302+
// 11110www 10xxxxxx 10yyyyyy 10zzzzzz -> 000wwwxx xxxxyyyy yyzzzzzz
303+
*code = (Py_UCS4)(((c[0] << 18) + (c[1] << 12) + (c[2] << 6) + c[3]) -
304+
((0xF0 << 18) + (0x80 << 12) + (0x80 << 6) + 0x80));
305+
return 4;
306+
}
307+
}
308+
309+
static int
310+
string_to_unicode(PyArrayMethod_Context *context, char *const data[],
311+
npy_intp const dimensions[], npy_intp const strides[],
312+
NpyAuxData *NPY_UNUSED(auxdata))
313+
{
314+
npy_intp N = dimensions[0];
315+
char **in = (char **)data[0];
316+
Py_UCS4 *out = (Py_UCS4 *)data[1];
317+
// strides are in bytes but pointer offsets are in pointer widths, so
318+
// divide by the element size (one pointer width) to get the pointer offset
319+
npy_intp in_stride = strides[0] / context->descriptors[0]->elsize;
320+
// 4 bytes per UCS4 character
321+
npy_intp out_stride = strides[1] / 4;
322+
// max number of 4 byte UCS4 characters that can fit in the output
323+
long max_out_size = (context->descriptors[1]->elsize) / 4;
324+
325+
while (N--) {
326+
unsigned char *this_string = (unsigned char *)*in;
327+
328+
for (int i = 0; i < max_out_size; i++) {
329+
Py_UCS4 code;
330+
331+
// get code point for character this_string is currently pointing
332+
// too
333+
size_t num_bytes = utf8_char_to_ucs4_code(this_string, &code);
334+
335+
// move to next character
336+
this_string += num_bytes;
337+
338+
// set output codepoint
339+
out[i] = code;
340+
341+
// check if this is the null terminator
342+
if (code == 0) {
343+
// fill all remaining characters (if any) with zero
344+
for (int j = i + 1; j < max_out_size; j++) {
345+
out[j] = 0;
346+
}
347+
break;
348+
}
349+
}
350+
in += in_stride;
351+
out += out_stride;
352+
}
353+
354+
return 0;
355+
}
356+
357+
static PyType_Slot s2u_slots[] = {
358+
{NPY_METH_resolve_descriptors, &string_to_unicode_resolve_descriptors},
359+
{NPY_METH_strided_loop, &string_to_unicode},
360+
{0, NULL}};
361+
362+
static char *s2u_name = "cast_StringDType_to_Unicode";
363+
67364
PyArrayMethod_Spec **
68365
get_casts(void)
69366
{
70-
PyArrayMethod_Spec **casts = malloc(2 * sizeof(PyArrayMethod_Spec *));
367+
PyArray_DTypeMeta **u2s_dtypes = malloc(2 * sizeof(PyArray_DTypeMeta *));
368+
u2s_dtypes[0] = &PyArray_UnicodeDType;
369+
u2s_dtypes[1] = NULL;
370+
371+
PyArrayMethod_Spec *UnicodeToStringCastSpec =
372+
malloc(sizeof(PyArrayMethod_Spec));
373+
374+
UnicodeToStringCastSpec->name = u2s_name;
375+
UnicodeToStringCastSpec->nin = 1;
376+
UnicodeToStringCastSpec->nout = 1;
377+
UnicodeToStringCastSpec->casting = NPY_SAFE_CASTING;
378+
UnicodeToStringCastSpec->flags = NPY_METH_NO_FLOATINGPOINT_ERRORS;
379+
UnicodeToStringCastSpec->dtypes = u2s_dtypes;
380+
UnicodeToStringCastSpec->slots = u2s_slots;
381+
382+
PyArray_DTypeMeta **s2u_dtypes = malloc(2 * sizeof(PyArray_DTypeMeta *));
383+
s2u_dtypes[0] = NULL;
384+
s2u_dtypes[1] = &PyArray_UnicodeDType;
385+
386+
PyArrayMethod_Spec *StringToUnicodeCastSpec =
387+
malloc(sizeof(PyArrayMethod_Spec));
388+
389+
StringToUnicodeCastSpec->name = s2u_name;
390+
StringToUnicodeCastSpec->nin = 1;
391+
StringToUnicodeCastSpec->nout = 1;
392+
StringToUnicodeCastSpec->casting = NPY_SAFE_CASTING;
393+
StringToUnicodeCastSpec->flags = NPY_METH_NO_FLOATINGPOINT_ERRORS;
394+
StringToUnicodeCastSpec->dtypes = s2u_dtypes;
395+
StringToUnicodeCastSpec->slots = s2u_slots;
396+
397+
PyArrayMethod_Spec **casts = malloc(4 * sizeof(PyArrayMethod_Spec *));
71398
casts[0] = &StringToStringCastSpec;
72-
casts[1] = NULL;
399+
casts[1] = UnicodeToStringCastSpec;
400+
casts[2] = StringToUnicodeCastSpec;
401+
casts[3] = NULL;
73402

74403
return casts;
75404
}

stringdtype/stringdtype/src/dtype.c

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -226,5 +226,11 @@ init_string_dtype(void)
226226

227227
StringDType.singleton = singleton;
228228

229+
free(StringDType_DTypeSpec.casts[1]->dtypes);
230+
free(StringDType_DTypeSpec.casts[1]);
231+
free(StringDType_DTypeSpec.casts[2]->dtypes);
232+
free(StringDType_DTypeSpec.casts[2]);
233+
free(StringDType_DTypeSpec.casts);
234+
229235
return 0;
230236
}

0 commit comments

Comments
 (0)