Skip to content

Commit 075392f

Browse files
committed
Add cast from string to bool
1 parent db6dcb2 commit 075392f

File tree

3 files changed

+92
-9
lines changed

3 files changed

+92
-9
lines changed

stringdtype/stringdtype/src/casts.c

Lines changed: 71 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -77,8 +77,6 @@ static PyType_Slot s2s_slots[] = {
7777

7878
static char *s2s_name = "cast_StringDType_to_StringDType";
7979

80-
static PyArray_DTypeMeta *s2s_dtypes[] = {NULL, NULL};
81-
8280
// unicode to string
8381

8482
static NPY_CASTING
@@ -374,6 +372,66 @@ static PyType_Slot s2u_slots[] = {
374372

375373
static char *s2u_name = "cast_StringDType_to_Unicode";
376374

375+
// string to bool
376+
377+
static NPY_CASTING
378+
string_to_bool_resolve_descriptors(PyObject *NPY_UNUSED(self),
379+
PyArray_DTypeMeta *NPY_UNUSED(dtypes[2]),
380+
PyArray_Descr *given_descrs[2],
381+
PyArray_Descr *loop_descrs[2],
382+
npy_intp *NPY_UNUSED(view_offset))
383+
{
384+
if (given_descrs[1] == NULL) {
385+
loop_descrs[1] = PyArray_DescrNewFromType(NPY_BOOL);
386+
}
387+
else {
388+
Py_INCREF(given_descrs[1]);
389+
loop_descrs[1] = given_descrs[1];
390+
}
391+
392+
Py_INCREF(given_descrs[0]);
393+
loop_descrs[0] = given_descrs[0];
394+
395+
return NPY_UNSAFE_CASTING;
396+
}
397+
398+
static int
399+
string_to_bool(PyArrayMethod_Context *context, char *const data[],
400+
npy_intp const dimensions[], npy_intp const strides[],
401+
NpyAuxData *NPY_UNUSED(auxdata))
402+
{
403+
npy_intp N = dimensions[0];
404+
char *in = data[0];
405+
char *out = data[1];
406+
407+
npy_intp in_stride = strides[0];
408+
npy_intp out_stride = strides[1];
409+
410+
ss *s = NULL;
411+
412+
while (N--) {
413+
load_string(in, &s);
414+
if (s->len == 0) {
415+
*out = (npy_bool)0;
416+
}
417+
else {
418+
*out = (npy_bool)1;
419+
}
420+
421+
in += in_stride;
422+
out += out_stride;
423+
}
424+
425+
return 0;
426+
}
427+
428+
static PyType_Slot s2b_slots[] = {
429+
{NPY_METH_resolve_descriptors, &string_to_bool_resolve_descriptors},
430+
{NPY_METH_strided_loop, &string_to_bool},
431+
{0, NULL}};
432+
433+
static char *s2b_name = "cast_StringDType_to_Bool";
434+
377435
PyArrayMethod_Spec *
378436
get_cast_spec(const char *name, NPY_CASTING casting,
379437
NPY_ARRAYMETHOD_FLAGS flags, PyArray_DTypeMeta **dtypes,
@@ -406,6 +464,8 @@ get_dtypes(PyArray_DTypeMeta *dt1, PyArray_DTypeMeta *dt2)
406464
PyArrayMethod_Spec **
407465
get_casts(void)
408466
{
467+
PyArray_DTypeMeta **s2s_dtypes = get_dtypes(NULL, NULL);
468+
409469
PyArrayMethod_Spec *StringToStringCastSpec =
410470
get_cast_spec(s2s_name, NPY_NO_CASTING,
411471
NPY_METH_SUPPORTS_UNALIGNED, s2s_dtypes, s2s_slots);
@@ -422,11 +482,18 @@ get_casts(void)
422482
s2u_name, NPY_SAFE_CASTING, NPY_METH_NO_FLOATINGPOINT_ERRORS,
423483
s2u_dtypes, s2u_slots);
424484

425-
PyArrayMethod_Spec **casts = malloc(4 * sizeof(PyArrayMethod_Spec *));
485+
PyArray_DTypeMeta **s2b_dtypes = get_dtypes(NULL, &PyArray_BoolDType);
486+
487+
PyArrayMethod_Spec *StringToBoolCastSpec = get_cast_spec(
488+
s2b_name, NPY_UNSAFE_CASTING, NPY_METH_NO_FLOATINGPOINT_ERRORS,
489+
s2b_dtypes, s2b_slots);
490+
491+
PyArrayMethod_Spec **casts = malloc(5 * sizeof(PyArrayMethod_Spec *));
426492
casts[0] = StringToStringCastSpec;
427493
casts[1] = UnicodeToStringCastSpec;
428494
casts[2] = StringToUnicodeCastSpec;
429-
casts[3] = NULL;
495+
casts[3] = StringToBoolCastSpec;
496+
casts[4] = NULL;
430497

431498
return casts;
432499
}

stringdtype/stringdtype/src/dtype.c

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -409,11 +409,10 @@ init_string_dtype(void)
409409

410410
StringDType.singleton = singleton;
411411

412-
free(StringDType_DTypeSpec.casts[1]->dtypes);
413-
free(StringDType_DTypeSpec.casts[1]);
414-
free(StringDType_DTypeSpec.casts[2]->dtypes);
415-
free(StringDType_DTypeSpec.casts[2]);
416-
free(StringDType_DTypeSpec.casts);
412+
for (int i = 0; i < 4; i++) {
413+
free(StringDType_DTypeSpec.casts[i]->dtypes);
414+
free(StringDType_DTypeSpec.casts[i]);
415+
}
417416

418417
return 0;
419418
}

stringdtype/tests/test_stringdtype.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -242,3 +242,20 @@ def test_arrfuncs_empty(arrfunc, expected):
242242
arr = np.empty(10, dtype=StringDType())
243243
result = arrfunc(arr)
244244
np.testing.assert_array_equal(result, expected, strict=True)
245+
246+
247+
@pytest.mark.parametrize(
    ("string_list", "cast_answer", "any_answer", "all_answer"),
    [
        (["hello", "world"], [True, True], True, True),
        (["", ""], [False, False], False, False),
        (["hello", ""], [True, False], True, False),
        (["", "world"], [False, True], True, False),
    ],
)
def test_bool_cast(string_list, cast_answer, any_answer, all_answer):
    """Check the bool cast of a StringDType array and its any/all reductions.

    Each case supplies the input strings, the expected element-wise result of
    ``astype("bool")`` (empty strings are expected to be False, non-empty
    ones True), and the expected results of ``np.any`` and ``np.all``.
    """
    strings = np.array(string_list, dtype=StringDType())
    as_bool = strings.astype("bool")
    np.testing.assert_array_equal(as_bool, cast_answer)

    assert np.any(strings) == any_answer
    assert np.all(strings) == all_answer

0 commit comments

Comments
 (0)