Skip to content

Commit 5659778

Browse files
authored
Merge pull request #26 from ngoldbaum/char-tests
Add test for functions in np.char
2 parents 07c5087 + b5d51dd commit 5659778

File tree

3 files changed

+140
-7
lines changed

3 files changed

+140
-7
lines changed

stringdtype/stringdtype/src/casts.c

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,20 +7,26 @@ string_to_string_resolve_descriptors(PyObject *NPY_UNUSED(self),
77
PyArray_DTypeMeta *NPY_UNUSED(dtypes[2]),
88
PyArray_Descr *given_descrs[2],
99
PyArray_Descr *loop_descrs[2],
10-
npy_intp *NPY_UNUSED(view_offset))
10+
npy_intp *view_offset)
1111
{
12-
Py_INCREF(given_descrs[0]);
13-
loop_descrs[0] = given_descrs[0];
14-
1512
if (given_descrs[1] == NULL) {
16-
loop_descrs[1] = (PyArray_Descr *)new_stringdtype_instance();
13+
StringDTypeObject *new = new_stringdtype_instance();
14+
if (new == NULL) {
15+
return (NPY_CASTING)-1;
16+
}
17+
loop_descrs[1] = (PyArray_Descr *)new;
1718
}
1819
else {
1920
Py_INCREF(given_descrs[1]);
2021
loop_descrs[1] = given_descrs[1];
2122
}
2223

23-
return NPY_SAFE_CASTING;
24+
Py_INCREF(given_descrs[0]);
25+
loop_descrs[0] = given_descrs[0];
26+
27+
*view_offset = 0;
28+
29+
return NPY_NO_CASTING;
2430
}
2531

2632
static int
@@ -59,7 +65,7 @@ PyArrayMethod_Spec StringToStringCastSpec = {
5965
.name = "cast_StringDType_to_StringDType",
6066
.nin = 1,
6167
.nout = 1,
62-
.casting = NPY_UNSAFE_CASTING,
68+
.casting = NPY_NO_CASTING,
6369
.flags = NPY_METH_SUPPORTS_UNALIGNED,
6470
.dtypes = s2s_dtypes,
6571
.slots = s2s_slots,

stringdtype/tests/test_char.py

Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
import numpy as np
2+
import pytest
3+
from numpy.testing import assert_array_equal
4+
5+
from stringdtype import StringDType
6+
7+
TEST_DATA = ["hello", "Ae¢☃€ 😊", "entry\nwith\nnewlines", "entry\twith\ttabs"]
8+
9+
10+
@pytest.fixture
11+
def string_array():
12+
return np.array(TEST_DATA, dtype=StringDType())
13+
14+
15+
@pytest.fixture
16+
def unicode_array():
17+
return np.array(TEST_DATA, dtype=np.unicode_)
18+
19+
20+
UNARY_FUNCTIONS = [
21+
"str_len",
22+
"capitalize",
23+
"expandtabs",
24+
"isalnum",
25+
"isalpha",
26+
"isdigit",
27+
"islower",
28+
"isspace",
29+
"istitle",
30+
"isupper",
31+
"lower",
32+
"splitlines",
33+
"swapcase",
34+
"title",
35+
"upper",
36+
"isnumeric",
37+
"isdecimal",
38+
]
39+
40+
41+
@pytest.mark.parametrize("function_name", UNARY_FUNCTIONS)
42+
def test_unary(string_array, unicode_array, function_name):
43+
func = getattr(np.char, function_name)
44+
sres = func(string_array)
45+
ures = func(unicode_array)
46+
if sres.dtype == StringDType():
47+
ures = ures.astype(StringDType())
48+
assert_array_equal(sres, ures)
49+
50+
51+
# None means that the argument is a string array
52+
BINARY_FUNCTIONS = [
53+
("add", (None, None)),
54+
("multiply", (None, 2)),
55+
("mod", ("format: %s", None)),
56+
("center", (None, 25)),
57+
("count", (None, "A")),
58+
("encode", (None, "UTF-8")),
59+
("endswith", (None, "lo")),
60+
("find", (None, "A")),
61+
("index", (None, "e")),
62+
("join", ("-", None)),
63+
("ljust", (None, 12)),
64+
("partition", (None, "A")),
65+
("replace", (None, "A", "B")),
66+
("rfind", (None, "A")),
67+
("rindex", (None, "e")),
68+
("rjust", (None, 12)),
69+
("rpartition", (None, "A")),
70+
("split", (None, "A")),
71+
("startswith", (None, "A")),
72+
("zfill", (None, 12)),
73+
]
74+
75+
76+
@pytest.mark.parametrize("function_name, args", BINARY_FUNCTIONS)
77+
def test_binary(string_array, unicode_array, function_name, args):
78+
func = getattr(np.char, function_name)
79+
if args == (None, None):
80+
sres = func(string_array, string_array)
81+
ures = func(unicode_array, unicode_array)
82+
elif args[0] is None:
83+
sres = func(string_array, *args[1:])
84+
ures = func(string_array, *args[1:])
85+
elif args[1] is None:
86+
sres = func(args[0], string_array)
87+
ures = func(args[0], string_array)
88+
else:
89+
# shouldn't ever happen
90+
raise RuntimeError
91+
if sres.dtype == StringDType():
92+
ures = ures.astype(StringDType())
93+
assert_array_equal(sres, ures)
94+
95+
96+
def test_strip(string_array, unicode_array):
97+
rjs = np.char.rjust(string_array, 25)
98+
rju = np.char.rjust(unicode_array, 25)
99+
100+
ljs = np.char.ljust(string_array, 25)
101+
lju = np.char.ljust(unicode_array, 25)
102+
103+
assert_array_equal(
104+
np.char.lstrip(rjs),
105+
np.char.lstrip(rju).astype(StringDType()),
106+
)
107+
108+
assert_array_equal(
109+
np.char.rstrip(ljs),
110+
np.char.rstrip(lju).astype(StringDType()),
111+
)
112+
113+
assert_array_equal(
114+
np.char.strip(ljs),
115+
np.char.strip(lju).astype(StringDType()),
116+
)
117+
118+
assert_array_equal(
119+
np.char.strip(rjs),
120+
np.char.strip(rju).astype(StringDType()),
121+
)

stringdtype/tests/test_stringdtype.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,12 @@ def test_dtype_creation():
1717
assert str(StringDType()) == "StringDType"
1818

1919

20+
def test_dtype_equality():
21+
assert StringDType() == StringDType()
22+
assert StringDType() != np.dtype("U")
23+
assert StringDType() != np.dtype("U8")
24+
25+
2026
@pytest.mark.parametrize(
2127
"data",
2228
[

0 commit comments

Comments
 (0)