Skip to content

Commit 0dabe4e

Browse files
authored
Merge pull request #24 from numpy/stringdtype-tests
Add tests for string dtype
2 parents 5ad9a75 + 971a01f commit 0dabe4e

File tree

3 files changed

+112
-0
lines changed

3 files changed

+112
-0
lines changed

.github/workflows/ci.yml

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,3 +64,12 @@ jobs:
6464
working-directory: quaddtype
6565
run: |
6666
pytest -vvv --color=yes
67+
- name: Install stringdtype
68+
working-directory: stringdtype
69+
run: |
70+
python -m build --no-isolation --wheel -Cbuilddir=build
71+
find ./dist/*.whl | xargs python -m pip install
72+
- name: Run stringdtype tests
73+
working-directory: stringdtype
74+
run: |
75+
pytest -vvv --color=yes

stringdtype/stringdtype/src/dtype.c

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,10 @@ static int
9393
stringdtype_setitem(StringDTypeObject *descr, PyObject *obj, char **dataptr)
9494
{
9595
PyObject *val_obj = get_value(obj);
96+
if (val_obj == NULL) {
97+
return -1;
98+
}
99+
96100
char *val = NULL;
97101
Py_ssize_t length = 0;
98102
if (PyBytes_AsStringAndSize(val_obj, &val, &length) == -1) {

stringdtype/tests/test_stringdtype.py

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
import numpy as np
2+
import pytest
3+
4+
from stringdtype import StringDType, StringScalar
5+
6+
7+
@pytest.fixture
8+
def string_list():
9+
return ['abc', 'def', 'ghi']
10+
11+
12+
def test_scalar_creation():
13+
assert str(StringScalar('abc', StringDType())) == 'abc'
14+
15+
16+
def test_dtype_creation():
17+
assert str(StringDType()) == 'StringDType'
18+
19+
20+
@pytest.mark.parametrize(
21+
'data', [
22+
['abc', 'def', 'ghi'],
23+
["🤣", "📵", "😰"],
24+
["🚜", "🙃", "😾"],
25+
["😹", "🚠", "🚌"],
26+
]
27+
)
28+
def test_array_creation_utf8(data):
29+
arr = np.array(data, dtype=StringDType())
30+
assert repr(arr) == f'array({str(data)}, dtype=StringDType)'
31+
32+
33+
def test_array_creation_scalars(string_list):
34+
dtype = StringDType()
35+
arr = np.array(
36+
[
37+
StringScalar('abc', dtype=dtype),
38+
StringScalar('def', dtype=dtype),
39+
StringScalar('ghi', dtype=dtype),
40+
]
41+
)
42+
assert repr(arr) == repr(np.array(string_list, dtype=StringDType()))
43+
44+
45+
@pytest.mark.parametrize(
46+
'data', [
47+
[1, 2, 3],
48+
[None, None, None],
49+
[b'abc', b'def', b'ghi'],
50+
[object, object, object],
51+
]
52+
)
53+
def test_bad_scalars(data):
54+
with pytest.raises(TypeError):
55+
np.array(data, dtype=StringDType())
56+
57+
58+
@pytest.mark.xfail(reason='Not yet implemented')
59+
def test_cast_to_stringdtype(string_list):
60+
arr = np.array(string_list, dtype='<U3').astype(StringDType())
61+
expected = np.array(string_list, dtype=StringDType())
62+
np.testing.assert_array_equal(arr, expected)
63+
64+
65+
@pytest.mark.xfail(reason='Not yet implemented')
66+
def test_cast_to_unicode_safe(string_list):
67+
arr = np.array(string_list, dtype=StringDType())
68+
69+
np.testing.assert_array_equal(
70+
arr.astype('<U3', casting='safe'),
71+
np.array(string_list, dtype='<U3')
72+
)
73+
74+
# Safe casting should preserve data
75+
with pytest.raises(TypeError):
76+
arr.astype('<U2', casting='safe')
77+
78+
79+
@pytest.mark.xfail(reason='Not yet implemented')
80+
def test_cast_to_unicode_unsafe(string_list):
81+
arr = np.array(string_list, dtype=StringDType())
82+
83+
np.testing.assert_array_equal(
84+
arr.astype('<U3', casting='unsafe'),
85+
np.array(string_list, dtype='<U3')
86+
)
87+
88+
# Unsafe casting: each element is truncated
89+
np.testing.assert_array_equal(
90+
arr.astype('<U2', casting='unsafe'),
91+
np.array(string_list, dtype='<U2')
92+
)
93+
94+
95+
def test_insert_scalar(string_list):
96+
dtype = StringDType()
97+
arr = np.array(string_list, dtype=dtype)
98+
arr[1] = StringScalar('what', dtype=dtype)
99+
assert repr(arr) == repr(np.array(['abc', 'what', 'ghi'], dtype=dtype))

0 commit comments

Comments
 (0)