Skip to content

Commit bc3f8be

Browse files
committed
add unicode_ to StringDType cast
1 parent 7c17d57 commit bc3f8be

File tree

3 files changed

+174
-39
lines changed

3 files changed

+174
-39
lines changed

stringdtype/stringdtype/src/casts.c

Lines changed: 136 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,19 @@
11
#include "casts.h"
22

3+
#include "dtype.h"
4+
35
static NPY_CASTING
4-
string_to_string_resolve_descriptors(PyObject *NPY_UNUSED(self),
5-
PyArray_DTypeMeta *NPY_UNUSED(dtypes[2]),
6-
PyArray_Descr *given_descrs[2],
7-
PyArray_Descr *loop_descrs[2],
8-
npy_intp *NPY_UNUSED(view_offset))
6+
string_resolve_descriptors(PyObject *NPY_UNUSED(self),
7+
PyArray_DTypeMeta *NPY_UNUSED(dtypes[2]),
8+
PyArray_Descr *given_descrs[2],
9+
PyArray_Descr *loop_descrs[2],
10+
npy_intp *NPY_UNUSED(view_offset))
911
{
1012
Py_INCREF(given_descrs[0]);
1113
loop_descrs[0] = given_descrs[0];
1214

1315
if (given_descrs[1] == NULL) {
14-
Py_INCREF(given_descrs[0]);
15-
loop_descrs[1] = given_descrs[0];
16+
loop_descrs[1] = (PyArray_Descr *)new_stringdtype_instance();
1617
}
1718
else {
1819
Py_INCREF(given_descrs[1]);
@@ -49,7 +50,7 @@ string_to_string(PyArrayMethod_Context *context, char *const data[],
4950
static PyArray_DTypeMeta *s2s_dtypes[2] = {NULL, NULL};
5051

5152
static PyType_Slot s2s_slots[] = {
52-
{NPY_METH_resolve_descriptors, &string_to_string_resolve_descriptors},
53+
{NPY_METH_resolve_descriptors, &string_resolve_descriptors},
5354
{NPY_METH_strided_loop, &string_to_string},
5455
{NPY_METH_unaligned_strided_loop, &string_to_string},
5556
{0, NULL}};
@@ -64,12 +65,137 @@ PyArrayMethod_Spec StringToStringCastSpec = {
6465
.slots = s2s_slots,
6566
};
6667

68+
// converts UCS4 code point to 4-byte char* assumes in is a zero-filled 4 byte
69+
// array returns -1 if the code point is not a valid unicode code point, the
70+
// number of bytes in the in the UTF-8 character on success
71+
static int
72+
ucs4_to_utf8_char(const Py_UCS4 code, char *in)
73+
{
74+
if (code <= 0x7F) {
75+
// ASCII
76+
in[0] = (char)code;
77+
return 1;
78+
}
79+
else if (code <= 0x07FF) {
80+
in[0] = (0xc0 | (code >> 6));
81+
in[1] = (0x80 | (code & 0x3f));
82+
return 2;
83+
}
84+
else if (code <= 0xFFFF) {
85+
in[0] = (0xe0 | (code >> 12));
86+
in[1] = (0x80 | ((code >> 6) & 0x3f));
87+
in[2] = (0x80 | (code & 0x3f));
88+
return 3;
89+
}
90+
else if (code <= 0x10FFFF) {
91+
in[0] = (0xf0 | (code >> 18));
92+
in[1] = (0x80 | ((code >> 12) & 0x3f));
93+
in[2] = (0x80 | ((code >> 6) & 0x3f));
94+
in[3] = (0x80 | (code & 0x3f));
95+
return 4;
96+
}
97+
return -1;
98+
}
99+
100+
static int
101+
unicode_to_string(PyArrayMethod_Context *context, char *const data[],
102+
npy_intp const dimensions[], npy_intp const strides[],
103+
NpyAuxData *NPY_UNUSED(auxdata))
104+
{
105+
PyArray_Descr **descrs = context->descriptors;
106+
long in_size = (descrs[0]->elsize) / 4;
107+
108+
npy_intp N = dimensions[0];
109+
char *in = data[0];
110+
char **out = (char **)data[1];
111+
npy_intp in_stride = strides[0];
112+
// strides are in bytes but pointer offsets are in pointer widths, so
113+
// divide by the element size (one pointer width) to get the pointer offset
114+
npy_intp out_stride = strides[1] / context->descriptors[1]->elsize;
115+
116+
while (N--) {
117+
// pessimistically allocate 4 bytes per allowed character
118+
char *out_buf = calloc(in_size * 4 + 1, sizeof(char));
119+
size_t out_num_bytes = 0;
120+
for (int i = 0; i < in_size; i++) {
121+
// get code point
122+
Py_UCS4 code = ((Py_UCS4 *)in)[i];
123+
124+
if (code == 0) {
125+
break;
126+
}
127+
128+
// convert codepoint to UTF8 bytes
129+
char utf8_c[4] = {0};
130+
size_t num_bytes = ucs4_to_utf8_char(code, utf8_c);
131+
out_num_bytes += num_bytes;
132+
133+
if (num_bytes == -1) {
134+
// acquire GIL, set error, return
135+
PyGILState_STATE gstate;
136+
gstate = PyGILState_Ensure();
137+
PyErr_SetString(PyExc_TypeError,
138+
"Invalid unicode code point found");
139+
PyGILState_Release(gstate);
140+
return -1;
141+
}
142+
143+
// copy utf8_c into out_buf
144+
strncpy(out_buf, utf8_c, num_bytes);
145+
146+
// increment out_buf by the size of the character
147+
out_buf += num_bytes;
148+
}
149+
150+
// reset out_buf to the beginning of the string
151+
out_buf -= out_num_bytes;
152+
153+
// pad string with null character
154+
out_buf[out_num_bytes] = '\0';
155+
156+
// resize out_buf now that we know the real size
157+
out_buf = realloc(out_buf, out_num_bytes + 1);
158+
159+
// set out to the address of the beginning of the string
160+
out[0] = out_buf;
161+
162+
// increment out and in by strides
163+
in += in_stride;
164+
out += out_stride;
165+
}
166+
167+
return 0;
168+
}
169+
170+
static PyType_Slot u2s_slots[] = {
171+
{NPY_METH_resolve_descriptors, &string_resolve_descriptors},
172+
{NPY_METH_strided_loop, &unicode_to_string},
173+
{0, NULL}};
174+
175+
static char *u2s_name = "cast_Unicode_to_StringDType";
176+
67177
PyArrayMethod_Spec **
68178
get_casts(void)
69179
{
70-
PyArrayMethod_Spec **casts = malloc(2 * sizeof(PyArrayMethod_Spec *));
180+
PyArray_DTypeMeta **u2s_dtypes = malloc(2 * sizeof(PyArray_DTypeMeta *));
181+
u2s_dtypes[0] = &PyArray_UnicodeDType;
182+
u2s_dtypes[1] = NULL;
183+
184+
PyArrayMethod_Spec *UnicodeToStringCastSpec =
185+
malloc(sizeof(PyArrayMethod_Spec));
186+
187+
UnicodeToStringCastSpec->name = u2s_name;
188+
UnicodeToStringCastSpec->nin = 1;
189+
UnicodeToStringCastSpec->nout = 1;
190+
UnicodeToStringCastSpec->casting = NPY_SAFE_CASTING;
191+
UnicodeToStringCastSpec->flags = NPY_METH_NO_FLOATINGPOINT_ERRORS;
192+
UnicodeToStringCastSpec->dtypes = u2s_dtypes;
193+
UnicodeToStringCastSpec->slots = u2s_slots;
194+
195+
PyArrayMethod_Spec **casts = malloc(3 * sizeof(PyArrayMethod_Spec *));
71196
casts[0] = &StringToStringCastSpec;
72-
casts[1] = NULL;
197+
casts[1] = UnicodeToStringCastSpec;
198+
casts[2] = NULL;
73199

74200
return casts;
75201
}

stringdtype/stringdtype/src/dtype.c

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -226,5 +226,9 @@ init_string_dtype(void)
226226

227227
StringDType.singleton = singleton;
228228

229+
free(StringDType_DTypeSpec.casts[1]->dtypes);
230+
free(StringDType_DTypeSpec.casts[1]);
231+
free(StringDType_DTypeSpec.casts);
232+
229233
return 0;
230234
}

stringdtype/tests/test_stringdtype.py

Lines changed: 34 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -6,94 +6,99 @@
66

77
@pytest.fixture
88
def string_list():
9-
return ['abc', 'def', 'ghi']
9+
return ["abc", "def", "ghi"]
1010

1111

1212
def test_scalar_creation():
13-
assert str(StringScalar('abc', StringDType())) == 'abc'
13+
assert str(StringScalar("abc", StringDType())) == "abc"
1414

1515

1616
def test_dtype_creation():
17-
assert str(StringDType()) == 'StringDType'
17+
assert str(StringDType()) == "StringDType"
1818

1919

2020
@pytest.mark.parametrize(
21-
'data', [
22-
['abc', 'def', 'ghi'],
21+
"data",
22+
[
23+
["abc", "def", "ghi"],
2324
["🤣", "📵", "😰"],
2425
["🚜", "🙃", "😾"],
2526
["😹", "🚠", "🚌"],
26-
]
27+
],
2728
)
2829
def test_array_creation_utf8(data):
2930
arr = np.array(data, dtype=StringDType())
30-
assert repr(arr) == f'array({str(data)}, dtype=StringDType)'
31+
assert repr(arr) == f"array({str(data)}, dtype=StringDType)"
3132

3233

3334
def test_array_creation_scalars(string_list):
3435
dtype = StringDType()
3536
arr = np.array(
3637
[
37-
StringScalar('abc', dtype=dtype),
38-
StringScalar('def', dtype=dtype),
39-
StringScalar('ghi', dtype=dtype),
38+
StringScalar("abc", dtype=dtype),
39+
StringScalar("def", dtype=dtype),
40+
StringScalar("ghi", dtype=dtype),
4041
]
4142
)
4243
assert repr(arr) == repr(np.array(string_list, dtype=StringDType()))
4344

4445

4546
@pytest.mark.parametrize(
46-
'data', [
47-
[1, 2, 3],
48-
[None, None, None],
49-
[b'abc', b'def', b'ghi'],
50-
[object, object, object],
51-
]
47+
"data",
48+
[
49+
[1, 2, 3],
50+
[None, None, None],
51+
[b"abc", b"def", b"ghi"],
52+
[object, object, object],
53+
],
5254
)
5355
def test_bad_scalars(data):
5456
with pytest.raises(TypeError):
5557
np.array(data, dtype=StringDType())
5658

5759

58-
@pytest.mark.xfail(reason='Not yet implemented')
60+
@pytest.mark.parametrize(
61+
("string_list"),
62+
[
63+
["this", "is", "an", "array"],
64+
["€", "", "😊"],
65+
],
66+
)
5967
def test_cast_to_stringdtype(string_list):
60-
arr = np.array(string_list, dtype='<U3').astype(StringDType())
68+
arr = np.array(string_list, dtype=np.unicode_).astype(StringDType())
6169
expected = np.array(string_list, dtype=StringDType())
6270
np.testing.assert_array_equal(arr, expected)
6371

6472

65-
@pytest.mark.xfail(reason='Not yet implemented')
73+
@pytest.mark.xfail(reason="Not yet implemented")
6674
def test_cast_to_unicode_safe(string_list):
6775
arr = np.array(string_list, dtype=StringDType())
6876

6977
np.testing.assert_array_equal(
70-
arr.astype('<U3', casting='safe'),
71-
np.array(string_list, dtype='<U3')
78+
arr.astype("<U3", casting="safe"), np.array(string_list, dtype="<U3")
7279
)
7380

7481
# Safe casting should preserve data
7582
with pytest.raises(TypeError):
76-
arr.astype('<U2', casting='safe')
83+
arr.astype("<U2", casting="safe")
7784

7885

79-
@pytest.mark.xfail(reason='Not yet implemented')
86+
@pytest.mark.xfail(reason="Not yet implemented")
8087
def test_cast_to_unicode_unsafe(string_list):
8188
arr = np.array(string_list, dtype=StringDType())
8289

8390
np.testing.assert_array_equal(
84-
arr.astype('<U3', casting='unsafe'),
85-
np.array(string_list, dtype='<U3')
91+
arr.astype("<U3", casting="unsafe"), np.array(string_list, dtype="<U3")
8692
)
8793

8894
# Unsafe casting: each element is truncated
8995
np.testing.assert_array_equal(
90-
arr.astype('<U2', casting='unsafe'),
91-
np.array(string_list, dtype='<U2')
96+
arr.astype("<U2", casting="unsafe"), np.array(string_list, dtype="<U2")
9297
)
9398

9499

95100
def test_insert_scalar(string_list):
96101
dtype = StringDType()
97102
arr = np.array(string_list, dtype=dtype)
98-
arr[1] = StringScalar('what', dtype=dtype)
99-
assert repr(arr) == repr(np.array(['abc', 'what', 'ghi'], dtype=dtype))
103+
arr[1] = StringScalar("what", dtype=dtype)
104+
assert repr(arr) == repr(np.array(["abc", "what", "ghi"], dtype=dtype))

0 commit comments

Comments
 (0)