Skip to content

Commit bd59d33

Browse files
committed
respond to review comments
1 parent cabab7f commit bd59d33

File tree

2 files changed

+84
-42
lines changed

2 files changed

+84
-42
lines changed

.pre-commit-config.yaml

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -102,8 +102,3 @@ repos:
102102
rev: 22.12.0
103103
hooks:
104104
- id: black
105-
# It is recommended to specify the latest version of Python
106-
# supported by your project here, or alternatively use
107-
# pre-commit's default_language_version, see
108-
# https://pre-commit.com/#top_level-default_language_version
109-
language_version: python3.9

stringdtype/stringdtype/src/casts.c

Lines changed: 84 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -72,24 +72,77 @@ unicode_to_string_resolve_descriptors(PyObject *NPY_UNUSED(self),
7272
PyArray_Descr *loop_descrs[2],
7373
npy_intp *NPY_UNUSED(view_offset))
7474
{
75-
Py_INCREF(given_descrs[0]);
76-
loop_descrs[0] = given_descrs[0];
77-
7875
if (given_descrs[1] == NULL) {
79-
loop_descrs[1] = (PyArray_Descr *)new_stringdtype_instance();
76+
StringDTypeObject *new = new_stringdtype_instance();
77+
if (new == NULL) {
78+
return (NPY_CASTING)-1;
79+
}
80+
loop_descrs[1] = (PyArray_Descr *)new;
8081
}
8182
else {
8283
Py_INCREF(given_descrs[1]);
8384
loop_descrs[1] = given_descrs[1];
8485
}
8586

87+
Py_INCREF(given_descrs[0]);
88+
loop_descrs[0] = given_descrs[0];
89+
8690
return NPY_SAFE_CASTING;
8791
}
8892

89-
// converts UCS4 code point to 4-byte char* assumes in is a zero-filled 4 byte
90-
// array returns -1 if the code point is not a valid unicode code point,
91-
// returns the number of bytes in the UTF-8 character on success
93+
// Find the number of bytes, *utf8_bytes*, needed to store the string
94+
// represented by *codepoints* in UTF-8. The array of *codepoints* is
95+
// *max_length* long, but may be padded with null codepoints. *num_codepoints*
96+
// is the number of codepoints that are not trailing null codepoints. Returns
97+
// 0 on success and -1 when an invalid code point is found.
9298
static int
99+
utf8_size(Py_UCS4 *codepoints, long max_length, size_t *num_codepoints,
100+
size_t *utf8_bytes)
101+
{
102+
size_t ucs4len = max_length;
103+
104+
while (ucs4len > 0 && codepoints[ucs4len - 1] == 0) {
105+
ucs4len--;
106+
}
107+
// ucs4len is now the number of codepoints that aren't trailing nulls.
108+
109+
size_t num_bytes = 0;
110+
111+
for (int i = 0; i < ucs4len; i++) {
112+
Py_UCS4 code = codepoints[i];
113+
114+
if (code <= 0x7F) {
115+
num_bytes += 1;
116+
}
117+
else if (code <= 0x07FF) {
118+
num_bytes += 2;
119+
}
120+
else if (code <= 0xFFFF) {
121+
if ((code >= 0xD800) && (code <= 0xDFFF)) {
122+
// surrogates are invalid UCS4 code points
123+
return -1;
124+
}
125+
num_bytes += 3;
126+
}
127+
else if (code <= 0x10FFFF) {
128+
num_bytes += 4;
129+
}
130+
else {
131+
// codepoint is outside the valid unicode range
132+
return -1;
133+
}
134+
}
135+
136+
*num_codepoints = ucs4len;
137+
*utf8_bytes = num_bytes;
138+
139+
return 0;
140+
}
141+
142+
// Converts UCS4 code point *code* to 4-byte character array *c*. Assumes *c*
143+
// is a zero-filled 4 byte array and *code* is a valid codepoint and does not
144+
// do any error checking! Returns the number of bytes in the UTF-8 character.
145+
static size_t
93146
ucs4_code_to_utf8_char(const Py_UCS4 code, char *c)
94147
{
95148
if (code <= 0x7F) {
@@ -110,15 +163,14 @@ ucs4_code_to_utf8_char(const Py_UCS4 code, char *c)
110163
c[2] = (0x80 | (code & 0x3f));
111164
return 3;
112165
}
113-
else if (code <= 0x10FFFF) {
166+
else {
114167
// 000wwwxx xxxxyyyy yyzzzzzz -> 11110www 10xxxxxx 10yyyyyy 10zzzzzz
115168
c[0] = (0xf0 | (code >> 18));
116169
c[1] = (0x80 | ((code >> 12) & 0x3f));
117170
c[2] = (0x80 | ((code >> 6) & 0x3f));
118171
c[3] = (0x80 | (code & 0x3f));
119172
return 4;
120173
}
121-
return -1;
122174
}
123175

124176
static int
@@ -127,7 +179,7 @@ unicode_to_string(PyArrayMethod_Context *context, char *const data[],
127179
NpyAuxData *NPY_UNUSED(auxdata))
128180
{
129181
PyArray_Descr **descrs = context->descriptors;
130-
long in_size = (descrs[0]->elsize) / 4;
182+
long max_in_size = (descrs[0]->elsize) / 4;
131183

132184
npy_intp N = dimensions[0];
133185
Py_UCS4 *in = (Py_UCS4 *)data[0];
@@ -140,32 +192,30 @@ unicode_to_string(PyArrayMethod_Context *context, char *const data[],
140192
npy_intp out_stride = strides[1] / context->descriptors[1]->elsize;
141193

142194
while (N--) {
143-
// pessimistically allocate 4 bytes per allowed character
144-
// plus one byte for the null terminator
145-
char *out_buf = malloc((in_size * 4 + 1) * sizeof(char));
146195
size_t out_num_bytes = 0;
147-
for (int i = 0; i < in_size; i++) {
196+
size_t num_codepoints = 0;
197+
if (utf8_size(in, max_in_size, &num_codepoints, &out_num_bytes) ==
198+
-1) {
199+
// invalid codepoint found so acquire GIL, set error, return
200+
PyGILState_STATE gstate;
201+
gstate = PyGILState_Ensure();
202+
PyErr_SetString(PyExc_TypeError,
203+
"Invalid unicode code point found");
204+
PyGILState_Release(gstate);
205+
return -1;
206+
}
207+
// one extra byte for null terminator
208+
char *out_buf = malloc((out_num_bytes + 1) * sizeof(char));
209+
for (int i = 0; i < num_codepoints; i++) {
148210
// get code point
149211
Py_UCS4 code = in[i];
150212

151-
if (code == 0) {
152-
break;
153-
}
154-
155-
// convert codepoint to UTF8 bytes
213+
// will be filled with UTF-8 bytes
156214
char utf8_c[4] = {0};
215+
216+
// we already checked for invalid code points above,
217+
// so no need to do error checking here
157218
size_t num_bytes = ucs4_code_to_utf8_char(code, utf8_c);
158-
out_num_bytes += num_bytes;
159-
160-
if (num_bytes == -1) {
161-
// acquire GIL, set error, return
162-
PyGILState_STATE gstate;
163-
gstate = PyGILState_Ensure();
164-
PyErr_SetString(PyExc_TypeError,
165-
"Invalid unicode code point found");
166-
PyGILState_Release(gstate);
167-
return -1;
168-
}
169219

170220
// copy utf8_c into out_buf
171221
strncpy(out_buf, utf8_c, num_bytes);
@@ -180,9 +230,6 @@ unicode_to_string(PyArrayMethod_Context *context, char *const data[],
180230
// pad string with null character
181231
out_buf[out_num_bytes] = '\0';
182232

183-
// resize out_buf now that we know the real size
184-
out_buf = realloc(out_buf, out_num_bytes + 1);
185-
186233
// set out to the address of the beginning of the string
187234
out[0] = out_buf;
188235

@@ -207,9 +254,6 @@ string_to_unicode_resolve_descriptors(PyObject *NPY_UNUSED(self),
207254
PyArray_Descr *loop_descrs[2],
208255
npy_intp *NPY_UNUSED(view_offset))
209256
{
210-
Py_INCREF(given_descrs[0]);
211-
loop_descrs[0] = given_descrs[0];
212-
213257
if (given_descrs[1] == NULL) {
214258
// currently there's no way to determine the correct output
215259
// size, so set an error and bail
@@ -225,10 +269,13 @@ string_to_unicode_resolve_descriptors(PyObject *NPY_UNUSED(self),
225269
loop_descrs[1] = given_descrs[1];
226270
}
227271

272+
Py_INCREF(given_descrs[0]);
273+
loop_descrs[0] = given_descrs[0];
274+
228275
return NPY_UNSAFE_CASTING;
229276
}
230277

231-
// Given UTF-8 bytes in *c*, sets *codepoint* to the corresponding unicode
278+
// Given UTF-8 bytes in *c*, sets *code* to the corresponding unicode
232279
// codepoint for the next character, returning the size of the character in
233280
// bytes. Does not do any validation or error checking: assumes *c* is valid
234281
// utf-8

0 commit comments

Comments
 (0)