1
1
#include "casts.h"
2
2
3
+ #include "dtype.h"
4
+
3
5
static NPY_CASTING
4
6
string_to_string_resolve_descriptors (PyObject * NPY_UNUSED (self ),
5
7
PyArray_DTypeMeta * NPY_UNUSED (dtypes [2 ]),
@@ -11,8 +13,7 @@ string_to_string_resolve_descriptors(PyObject *NPY_UNUSED(self),
11
13
loop_descrs [0 ] = given_descrs [0 ];
12
14
13
15
if (given_descrs [1 ] == NULL ) {
14
- Py_INCREF (given_descrs [0 ]);
15
- loop_descrs [1 ] = given_descrs [0 ];
16
+ loop_descrs [1 ] = (PyArray_Descr * )new_stringdtype_instance ();
16
17
}
17
18
else {
18
19
Py_INCREF (given_descrs [1 ]);
@@ -64,12 +65,340 @@ PyArrayMethod_Spec StringToStringCastSpec = {
64
65
.slots = s2s_slots ,
65
66
};
66
67
68
+ static NPY_CASTING
69
+ unicode_to_string_resolve_descriptors (PyObject * NPY_UNUSED (self ),
70
+ PyArray_DTypeMeta * NPY_UNUSED (dtypes [2 ]),
71
+ PyArray_Descr * given_descrs [2 ],
72
+ PyArray_Descr * loop_descrs [2 ],
73
+ npy_intp * NPY_UNUSED (view_offset ))
74
+ {
75
+ if (given_descrs [1 ] == NULL ) {
76
+ StringDTypeObject * new = new_stringdtype_instance ();
77
+ if (new == NULL ) {
78
+ return (NPY_CASTING )- 1 ;
79
+ }
80
+ loop_descrs [1 ] = (PyArray_Descr * )new ;
81
+ }
82
+ else {
83
+ Py_INCREF (given_descrs [1 ]);
84
+ loop_descrs [1 ] = given_descrs [1 ];
85
+ }
86
+
87
+ Py_INCREF (given_descrs [0 ]);
88
+ loop_descrs [0 ] = given_descrs [0 ];
89
+
90
+ return NPY_SAFE_CASTING ;
91
+ }
92
+
93
+ // Find the number of bytes, *utf8_bytes*, needed to store the string
94
+ // represented by *codepoints* in UTF-8. The array of *codepoints* is
95
+ // *max_length* long, but may be padded with null codepoints. *num_codepoints*
96
+ // is the number of codepoints that are not trailing null codepoints. Returns
97
+ // 0 on success and -1 when an invalid code point is found.
98
+ static int
99
+ utf8_size (Py_UCS4 * codepoints , long max_length , size_t * num_codepoints ,
100
+ size_t * utf8_bytes )
101
+ {
102
+ size_t ucs4len = max_length ;
103
+
104
+ while (ucs4len > 0 && codepoints [ucs4len - 1 ] == 0 ) {
105
+ ucs4len -- ;
106
+ }
107
+ // ucs4len is now the number of codepoints that aren't trailing nulls.
108
+
109
+ size_t num_bytes = 0 ;
110
+
111
+ for (int i = 0 ; i < ucs4len ; i ++ ) {
112
+ Py_UCS4 code = codepoints [i ];
113
+
114
+ if (code <= 0x7F ) {
115
+ num_bytes += 1 ;
116
+ }
117
+ else if (code <= 0x07FF ) {
118
+ num_bytes += 2 ;
119
+ }
120
+ else if (code <= 0xFFFF ) {
121
+ if ((code >= 0xD800 ) && (code <= 0xDFFF )) {
122
+ // surrogates are invalid UCS4 code points
123
+ return -1 ;
124
+ }
125
+ num_bytes += 3 ;
126
+ }
127
+ else if (code <= 0x10FFFF ) {
128
+ num_bytes += 4 ;
129
+ }
130
+ else {
131
+ // codepoint is outside the valid unicode range
132
+ return -1 ;
133
+ }
134
+ }
135
+
136
+ * num_codepoints = ucs4len ;
137
+ * utf8_bytes = num_bytes ;
138
+
139
+ return 0 ;
140
+ }
141
+
142
+ // Converts UCS4 code point *code* to 4-byte character array *c*. Assumes *c*
143
+ // is a zero-filled 4 byte array and *code* is a valid codepoint and does not
144
+ // do any error checking! Returns the number of bytes in the UTF-8 character.
145
+ static size_t
146
+ ucs4_code_to_utf8_char (const Py_UCS4 code , char * c )
147
+ {
148
+ if (code <= 0x7F ) {
149
+ // 0zzzzzzz -> 0zzzzzzz
150
+ c [0 ] = (char )code ;
151
+ return 1 ;
152
+ }
153
+ else if (code <= 0x07FF ) {
154
+ // 00000yyy yyzzzzzz -> 110yyyyy 10zzzzzz
155
+ c [0 ] = (0xC0 | (code >> 6 ));
156
+ c [1 ] = (0x80 | (code & 0x3F ));
157
+ return 2 ;
158
+ }
159
+ else if (code <= 0xFFFF ) {
160
+ // xxxxyyyy yyzzzzzz -> 110yyyyy 10zzzzzz
161
+ c [0 ] = (0xe0 | (code >> 12 ));
162
+ c [1 ] = (0x80 | ((code >> 6 ) & 0x3f ));
163
+ c [2 ] = (0x80 | (code & 0x3f ));
164
+ return 3 ;
165
+ }
166
+ else {
167
+ // 00wwwxx xxxxyyyy yyzzzzzz -> 11110www 10xxxxxx 10yyyyyy 10zzzzzz
168
+ c [0 ] = (0xf0 | (code >> 18 ));
169
+ c [1 ] = (0x80 | ((code >> 12 ) & 0x3f ));
170
+ c [2 ] = (0x80 | ((code >> 6 ) & 0x3f ));
171
+ c [3 ] = (0x80 | (code & 0x3f ));
172
+ return 4 ;
173
+ }
174
+ }
175
+
176
+ static int
177
+ unicode_to_string (PyArrayMethod_Context * context , char * const data [],
178
+ npy_intp const dimensions [], npy_intp const strides [],
179
+ NpyAuxData * NPY_UNUSED (auxdata ))
180
+ {
181
+ PyArray_Descr * * descrs = context -> descriptors ;
182
+ long max_in_size = (descrs [0 ]-> elsize ) / 4 ;
183
+
184
+ npy_intp N = dimensions [0 ];
185
+ Py_UCS4 * in = (Py_UCS4 * )data [0 ];
186
+ char * * out = (char * * )data [1 ];
187
+
188
+ // 4 bytes per UCS4 character
189
+ npy_intp in_stride = strides [0 ] / 4 ;
190
+ // strides are in bytes but pointer offsets are in pointer widths, so
191
+ // divide by the element size (one pointer width) to get the pointer offset
192
+ npy_intp out_stride = strides [1 ] / context -> descriptors [1 ]-> elsize ;
193
+
194
+ while (N -- ) {
195
+ size_t out_num_bytes = 0 ;
196
+ size_t num_codepoints = 0 ;
197
+ if (utf8_size (in , max_in_size , & num_codepoints , & out_num_bytes ) ==
198
+ -1 ) {
199
+ // invalid codepoint found so acquire GIL, set error, return
200
+ PyGILState_STATE gstate ;
201
+ gstate = PyGILState_Ensure ();
202
+ PyErr_SetString (PyExc_TypeError ,
203
+ "Invalid unicode code point found" );
204
+ PyGILState_Release (gstate );
205
+ return -1 ;
206
+ }
207
+ // one extra byte for null terminator
208
+ char * out_buf = malloc ((out_num_bytes + 1 ) * sizeof (char ));
209
+ for (int i = 0 ; i < num_codepoints ; i ++ ) {
210
+ // get code point
211
+ Py_UCS4 code = in [i ];
212
+
213
+ // will be filled with UTF-8 bytes
214
+ char utf8_c [4 ] = {0 };
215
+
216
+ // we already checked for invalid code points above,
217
+ // so no need to do error checking here
218
+ size_t num_bytes = ucs4_code_to_utf8_char (code , utf8_c );
219
+
220
+ // copy utf8_c into out_buf
221
+ strncpy (out_buf , utf8_c , num_bytes );
222
+
223
+ // increment out_buf by the size of the character
224
+ out_buf += num_bytes ;
225
+ }
226
+
227
+ // reset out_buf to the beginning of the string
228
+ out_buf -= out_num_bytes ;
229
+
230
+ // pad string with null character
231
+ out_buf [out_num_bytes ] = '\0' ;
232
+
233
+ // set out to the address of the beginning of the string
234
+ out [0 ] = out_buf ;
235
+
236
+ in += in_stride ;
237
+ out += out_stride ;
238
+ }
239
+
240
+ return 0 ;
241
+ }
242
+
243
+ static PyType_Slot u2s_slots [] = {
244
+ {NPY_METH_resolve_descriptors , & unicode_to_string_resolve_descriptors },
245
+ {NPY_METH_strided_loop , & unicode_to_string },
246
+ {0 , NULL }};
247
+
248
+ static char * u2s_name = "cast_Unicode_to_StringDType" ;
249
+
250
+ static NPY_CASTING
251
+ string_to_unicode_resolve_descriptors (PyObject * NPY_UNUSED (self ),
252
+ PyArray_DTypeMeta * NPY_UNUSED (dtypes [2 ]),
253
+ PyArray_Descr * given_descrs [2 ],
254
+ PyArray_Descr * loop_descrs [2 ],
255
+ npy_intp * NPY_UNUSED (view_offset ))
256
+ {
257
+ if (given_descrs [1 ] == NULL ) {
258
+ // currently there's no way to determine the correct output
259
+ // size, so set an error and bail
260
+ PyErr_SetString (
261
+ PyExc_TypeError ,
262
+ "Casting from StringDType to a fixed-width dtype with an "
263
+ "unspecified size is not currently supported, specify "
264
+ "an explicit size for the output dtype instead." );
265
+ return (NPY_CASTING )- 1 ;
266
+ }
267
+ else {
268
+ Py_INCREF (given_descrs [1 ]);
269
+ loop_descrs [1 ] = given_descrs [1 ];
270
+ }
271
+
272
+ Py_INCREF (given_descrs [0 ]);
273
+ loop_descrs [0 ] = given_descrs [0 ];
274
+
275
+ return NPY_UNSAFE_CASTING ;
276
+ }
277
+
278
+ // Given UTF-8 bytes in *c*, sets *code* to the corresponding unicode
279
+ // codepoint for the next character, returning the size of the character in
280
+ // bytes. Does not do any validation or error checking: assumes *c* is valid
281
+ // utf-8
282
+ static size_t
283
+ utf8_char_to_ucs4_code (unsigned char * c , Py_UCS4 * code )
284
+ {
285
+ if (c [0 ] <= 0x7F ) {
286
+ // 0zzzzzzz -> 0zzzzzzz
287
+ * code = (Py_UCS4 )(c [0 ]);
288
+ return 1 ;
289
+ }
290
+ else if (c [0 ] <= 0xDF ) {
291
+ // 110yyyyy 10zzzzzz -> 00000yyy yyzzzzzz
292
+ * code = (Py_UCS4 )(((c [0 ] << 6 ) + c [1 ]) - ((0xC0 << 6 ) + 0x80 ));
293
+ return 2 ;
294
+ }
295
+ else if (c [0 ] <= 0xEF ) {
296
+ // 1110xxxx 10yyyyyy 10zzzzzz -> xxxxyyyy yyzzzzzz
297
+ * code = (Py_UCS4 )(((c [0 ] << 12 ) + (c [1 ] << 6 ) + c [2 ]) -
298
+ ((0xE0 << 12 ) + (0x80 << 6 ) + 0x80 ));
299
+ return 3 ;
300
+ }
301
+ else {
302
+ // 11110www 10xxxxxx 10yyyyyy 10zzzzzz -> 000wwwxx xxxxyyyy yyzzzzzz
303
+ * code = (Py_UCS4 )(((c [0 ] << 18 ) + (c [1 ] << 12 ) + (c [2 ] << 6 ) + c [3 ]) -
304
+ ((0xF0 << 18 ) + (0x80 << 12 ) + (0x80 << 6 ) + 0x80 ));
305
+ return 4 ;
306
+ }
307
+ }
308
+
309
+ static int
310
+ string_to_unicode (PyArrayMethod_Context * context , char * const data [],
311
+ npy_intp const dimensions [], npy_intp const strides [],
312
+ NpyAuxData * NPY_UNUSED (auxdata ))
313
+ {
314
+ npy_intp N = dimensions [0 ];
315
+ char * * in = (char * * )data [0 ];
316
+ Py_UCS4 * out = (Py_UCS4 * )data [1 ];
317
+ // strides are in bytes but pointer offsets are in pointer widths, so
318
+ // divide by the element size (one pointer width) to get the pointer offset
319
+ npy_intp in_stride = strides [0 ] / context -> descriptors [0 ]-> elsize ;
320
+ // 4 bytes per UCS4 character
321
+ npy_intp out_stride = strides [1 ] / 4 ;
322
+ // max number of 4 byte UCS4 characters that can fit in the output
323
+ long max_out_size = (context -> descriptors [1 ]-> elsize ) / 4 ;
324
+
325
+ while (N -- ) {
326
+ unsigned char * this_string = (unsigned char * )* in ;
327
+
328
+ for (int i = 0 ; i < max_out_size ; i ++ ) {
329
+ Py_UCS4 code ;
330
+
331
+ // get code point for character this_string is currently pointing
332
+ // too
333
+ size_t num_bytes = utf8_char_to_ucs4_code (this_string , & code );
334
+
335
+ // move to next character
336
+ this_string += num_bytes ;
337
+
338
+ // set output codepoint
339
+ out [i ] = code ;
340
+
341
+ // check if this is the null terminator
342
+ if (code == 0 ) {
343
+ // fill all remaining characters (if any) with zero
344
+ for (int j = i + 1 ; j < max_out_size ; j ++ ) {
345
+ out [j ] = 0 ;
346
+ }
347
+ break ;
348
+ }
349
+ }
350
+ in += in_stride ;
351
+ out += out_stride ;
352
+ }
353
+
354
+ return 0 ;
355
+ }
356
+
357
+ static PyType_Slot s2u_slots [] = {
358
+ {NPY_METH_resolve_descriptors , & string_to_unicode_resolve_descriptors },
359
+ {NPY_METH_strided_loop , & string_to_unicode },
360
+ {0 , NULL }};
361
+
362
+ static char * s2u_name = "cast_StringDType_to_Unicode" ;
363
+
67
364
PyArrayMethod_Spec * *
68
365
get_casts (void )
69
366
{
70
- PyArrayMethod_Spec * * casts = malloc (2 * sizeof (PyArrayMethod_Spec * ));
367
+ PyArray_DTypeMeta * * u2s_dtypes = malloc (2 * sizeof (PyArray_DTypeMeta * ));
368
+ u2s_dtypes [0 ] = & PyArray_UnicodeDType ;
369
+ u2s_dtypes [1 ] = NULL ;
370
+
371
+ PyArrayMethod_Spec * UnicodeToStringCastSpec =
372
+ malloc (sizeof (PyArrayMethod_Spec ));
373
+
374
+ UnicodeToStringCastSpec -> name = u2s_name ;
375
+ UnicodeToStringCastSpec -> nin = 1 ;
376
+ UnicodeToStringCastSpec -> nout = 1 ;
377
+ UnicodeToStringCastSpec -> casting = NPY_SAFE_CASTING ;
378
+ UnicodeToStringCastSpec -> flags = NPY_METH_NO_FLOATINGPOINT_ERRORS ;
379
+ UnicodeToStringCastSpec -> dtypes = u2s_dtypes ;
380
+ UnicodeToStringCastSpec -> slots = u2s_slots ;
381
+
382
+ PyArray_DTypeMeta * * s2u_dtypes = malloc (2 * sizeof (PyArray_DTypeMeta * ));
383
+ s2u_dtypes [0 ] = NULL ;
384
+ s2u_dtypes [1 ] = & PyArray_UnicodeDType ;
385
+
386
+ PyArrayMethod_Spec * StringToUnicodeCastSpec =
387
+ malloc (sizeof (PyArrayMethod_Spec ));
388
+
389
+ StringToUnicodeCastSpec -> name = s2u_name ;
390
+ StringToUnicodeCastSpec -> nin = 1 ;
391
+ StringToUnicodeCastSpec -> nout = 1 ;
392
+ StringToUnicodeCastSpec -> casting = NPY_SAFE_CASTING ;
393
+ StringToUnicodeCastSpec -> flags = NPY_METH_NO_FLOATINGPOINT_ERRORS ;
394
+ StringToUnicodeCastSpec -> dtypes = s2u_dtypes ;
395
+ StringToUnicodeCastSpec -> slots = s2u_slots ;
396
+
397
+ PyArrayMethod_Spec * * casts = malloc (4 * sizeof (PyArrayMethod_Spec * ));
71
398
casts [0 ] = & StringToStringCastSpec ;
72
- casts [1 ] = NULL ;
399
+ casts [1 ] = UnicodeToStringCastSpec ;
400
+ casts [2 ] = StringToUnicodeCastSpec ;
401
+ casts [3 ] = NULL ;
73
402
74
403
return casts ;
75
404
}
0 commit comments