Skip to content

Commit b5b7ca2

Browse files
authored
Merge pull request #89 from ngoldbaum/self-cast-resolve-descriptors
make NA behavior for cast to string and cast to unicode match
2 parents 3240d93 + 591e1e1 commit b5b7ca2

File tree

2 files changed

+67
-4
lines changed

2 files changed

+67
-4
lines changed

stringdtype/stringdtype/src/casts.c

Lines changed: 29 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -53,16 +53,32 @@ string_to_string_resolve_descriptors(PyObject *NPY_UNUSED(self),
5353
Py_INCREF(given_descrs[0]);
5454
loop_descrs[0] = given_descrs[0];
5555

56+
StringDTypeObject *descr0 = (StringDTypeObject *)loop_descrs[0];
57+
StringDTypeObject *descr1 = (StringDTypeObject *)loop_descrs[1];
58+
59+
if ((descr0->na_object != NULL) && (descr1->na_object == NULL)) {
60+
// cast from a dtype with an NA to one without, so it's a lossy
61+
// unsafe cast
62+
return NPY_UNSAFE_CASTING;
63+
}
64+
5665
*view_offset = 0;
5766

5867
return NPY_NO_CASTING;
5968
}
6069

6170
static int
62-
string_to_string(PyArrayMethod_Context *NPY_UNUSED(context),
63-
char *const data[], npy_intp const dimensions[],
64-
npy_intp const strides[], NpyAuxData *NPY_UNUSED(auxdata))
71+
string_to_string(PyArrayMethod_Context *context, char *const data[],
72+
npy_intp const dimensions[], npy_intp const strides[],
73+
NpyAuxData *NPY_UNUSED(auxdata))
6574
{
75+
StringDTypeObject *in_descr =
76+
((StringDTypeObject *)context->descriptors[0]);
77+
StringDTypeObject *out_descr =
78+
((StringDTypeObject *)context->descriptors[1]);
79+
int in_hasnull = in_descr->na_object != NULL;
80+
int out_hasnull = out_descr->na_object != NULL;
81+
const npy_static_string *in_na_name = &in_descr->na_name;
6682
npy_intp N = dimensions[0];
6783
char *in = data[0];
6884
char *out = data[1];
@@ -74,7 +90,16 @@ string_to_string(PyArrayMethod_Context *NPY_UNUSED(context),
7490
npy_packed_static_string *os = (npy_packed_static_string *)out;
7591
if (in != out) {
7692
npy_string_free(os);
77-
if (npy_string_dup(s, os) < 0) {
93+
if (in_hasnull && !out_hasnull && npy_string_isnull(s)) {
94+
// lossy but this is an unsafe cast so this is OK
95+
if (npy_string_newsize(in_na_name->buf, in_na_name->size, os) <
96+
0) {
97+
gil_error(PyExc_MemoryError,
98+
"Failed to allocate string in string to string "
99+
"cast.");
100+
}
101+
}
102+
else if (npy_string_dup(s, os) < 0) {
78103
gil_error(PyExc_MemoryError, "npy_string_dup failed");
79104
return -1;
80105
}

stringdtype/tests/test_stringdtype.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,16 @@ def dtype(na_object, coerce):
4848
return StringDType(coerce=coerce)
4949

5050

51+
# Second copy of the dtype fixture, so that cast tests can take a cartesian product over source and destination dtypes
52+
@pytest.fixture()
def dtype2(na_object, coerce):
    """Build a second StringDType instance for tests that cast between dtypes.

    ``na_object`` is forwarded to ``StringDType`` unless it equals the string
    ``"unset"``, in which case the dtype is constructed with ``coerce`` only.
    """
    # NOTE(review): this fixture requests the same ``na_object`` and ``coerce``
    # fixtures as ``dtype``. If those are parametrized fixtures, pytest resolves
    # each of them once per test, so ``dtype2`` would always be built from the
    # same values as ``dtype`` -- confirm whether separate fixtures are needed.
    kwargs = {"coerce": coerce}
    # The identity check has to come first: per the original author's note,
    # ``!=`` with pd_NA returns pd_NA rather than a plain boolean.
    if na_object is pd_NA or na_object != "unset":
        kwargs["na_object"] = na_object
    return StringDType(**kwargs)
59+
60+
5161
def test_dtype_creation():
5262
hashes = set()
5363
dt = StringDType()
@@ -136,6 +146,34 @@ def test_scalars_string_conversion(data, dtype):
136146
np.array(data, dtype=dtype)
137147

138148

149+
@pytest.mark.parametrize(
    ("strings"),
    [
        ["this", "is", "an", "array"],
        ["€", "", "😊"],
        ["A¢☃€ 😊", " A☃€¢😊", "☃€😊 A¢", "😊☃A¢ €"],
    ],
)
def test_self_casts(dtype, dtype2, strings):
    """Check casts from one StringDType instance (``dtype``) to another (``dtype2``).

    When the source dtype exposes an ``na_object``, that object is appended to
    the input so the handling of a missing value is exercised too:

    * source has an NA object, destination does not: the missing value comes
      out as ``str(dtype.na_object)``, a ``casting="safe"`` cast raises
      ``TypeError`` and a ``casting="unsafe"`` cast is accepted;
    * both have an NA object: the missing value comes out as the
      destination's ``na_object`` and a ``casting="safe"`` cast is accepted;
    * otherwise: a ``casting="safe"`` cast is accepted.

    Every element that is not the appended NA entry must survive the cast
    unchanged.
    """
    src_has_na = hasattr(dtype, "na_object")
    dst_has_na = hasattr(dtype2, "na_object")

    if src_has_na:
        strings = strings + [dtype.na_object]
    arr = np.array(strings, dtype=dtype)
    newarr = arr.astype(dtype2)

    if src_has_na and not dst_has_na:
        assert newarr[-1] == str(dtype.na_object)
        with pytest.raises(TypeError):
            arr.astype(dtype2, casting="safe")
        arr.astype(dtype2, casting="unsafe")
    elif src_has_na and dst_has_na:
        assert newarr[-1] is dtype2.na_object
        arr.astype(dtype2, casting="safe")
    else:
        arr.astype(dtype2, casting="safe")

    if src_has_na:
        # The last entry is the appended NA object, which was checked above;
        # compare only the real strings.
        np.testing.assert_array_equal(arr[:-1], newarr[:-1])
    else:
        # No NA entry was appended, so the last element is a real string and
        # must be compared as well (slicing off [-1] here would skip it).
        np.testing.assert_array_equal(arr, newarr)
175+
176+
139177
@pytest.mark.parametrize(
140178
("strings"),
141179
[

0 commit comments

Comments
 (0)