Skip to content

Commit a5e9efc

Browse files
authored
Merge pull request #46 from ngoldbaum/cast-to-bool
Add a cast from string to bool
2 parents 0bd2328 + 82cea28 commit a5e9efc

File tree

3 files changed

+145
-49
lines changed

3 files changed

+145
-49
lines changed

stringdtype/stringdtype/src/casts.c

Lines changed: 124 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@ gil_error(PyObject *type, const char *msg)
1212
PyGILState_Release(gstate);
1313
}
1414

15+
// string to string
16+
1517
static NPY_CASTING
1618
string_to_string_resolve_descriptors(PyObject *NPY_UNUSED(self),
1719
PyArray_DTypeMeta *NPY_UNUSED(dtypes[2]),
@@ -67,23 +69,15 @@ string_to_string(PyArrayMethod_Context *NPY_UNUSED(context),
6769
return 0;
6870
}
6971

70-
static PyArray_DTypeMeta *s2s_dtypes[2] = {NULL, NULL};
71-
7272
static PyType_Slot s2s_slots[] = {
7373
{NPY_METH_resolve_descriptors, &string_to_string_resolve_descriptors},
7474
{NPY_METH_strided_loop, &string_to_string},
7575
{NPY_METH_unaligned_strided_loop, &string_to_string},
7676
{0, NULL}};
7777

78-
PyArrayMethod_Spec StringToStringCastSpec = {
79-
.name = "cast_StringDType_to_StringDType",
80-
.nin = 1,
81-
.nout = 1,
82-
.casting = NPY_NO_CASTING,
83-
.flags = NPY_METH_SUPPORTS_UNALIGNED,
84-
.dtypes = s2s_dtypes,
85-
.slots = s2s_slots,
86-
};
78+
static char *s2s_name = "cast_StringDType_to_StringDType";
79+
80+
// unicode to string
8781

8882
static NPY_CASTING
8983
unicode_to_string_resolve_descriptors(PyObject *NPY_UNUSED(self),
@@ -261,6 +255,8 @@ static PyType_Slot u2s_slots[] = {
261255

262256
static char *u2s_name = "cast_Unicode_to_StringDType";
263257

258+
// string to unicode
259+
264260
static NPY_CASTING
265261
string_to_unicode_resolve_descriptors(PyObject *NPY_UNUSED(self),
266262
PyArray_DTypeMeta *NPY_UNUSED(dtypes[2]),
@@ -376,44 +372,128 @@ static PyType_Slot s2u_slots[] = {
376372

377373
static char *s2u_name = "cast_StringDType_to_Unicode";
378374

375+
// string to bool
376+
377+
static NPY_CASTING
378+
string_to_bool_resolve_descriptors(PyObject *NPY_UNUSED(self),
379+
PyArray_DTypeMeta *NPY_UNUSED(dtypes[2]),
380+
PyArray_Descr *given_descrs[2],
381+
PyArray_Descr *loop_descrs[2],
382+
npy_intp *NPY_UNUSED(view_offset))
383+
{
384+
if (given_descrs[1] == NULL) {
385+
loop_descrs[1] = PyArray_DescrNewFromType(NPY_BOOL);
386+
}
387+
else {
388+
Py_INCREF(given_descrs[1]);
389+
loop_descrs[1] = given_descrs[1];
390+
}
391+
392+
Py_INCREF(given_descrs[0]);
393+
loop_descrs[0] = given_descrs[0];
394+
395+
return NPY_UNSAFE_CASTING;
396+
}
397+
398+
static int
399+
string_to_bool(PyArrayMethod_Context *context, char *const data[],
400+
npy_intp const dimensions[], npy_intp const strides[],
401+
NpyAuxData *NPY_UNUSED(auxdata))
402+
{
403+
npy_intp N = dimensions[0];
404+
char *in = data[0];
405+
char *out = data[1];
406+
407+
npy_intp in_stride = strides[0];
408+
npy_intp out_stride = strides[1];
409+
410+
ss *s = NULL;
411+
412+
while (N--) {
413+
load_string(in, &s);
414+
if (s->len == 0) {
415+
*out = (npy_bool)0;
416+
}
417+
else {
418+
*out = (npy_bool)1;
419+
}
420+
421+
in += in_stride;
422+
out += out_stride;
423+
}
424+
425+
return 0;
426+
}
427+
428+
static PyType_Slot s2b_slots[] = {
429+
{NPY_METH_resolve_descriptors, &string_to_bool_resolve_descriptors},
430+
{NPY_METH_strided_loop, &string_to_bool},
431+
{0, NULL}};
432+
433+
static char *s2b_name = "cast_StringDType_to_Bool";
434+
435+
PyArrayMethod_Spec *
436+
get_cast_spec(const char *name, NPY_CASTING casting,
437+
NPY_ARRAYMETHOD_FLAGS flags, PyArray_DTypeMeta **dtypes,
438+
PyType_Slot *slots)
439+
{
440+
PyArrayMethod_Spec *ret = malloc(sizeof(PyArrayMethod_Spec));
441+
442+
ret->name = name;
443+
ret->nin = 1;
444+
ret->nout = 1;
445+
ret->casting = casting;
446+
ret->flags = flags;
447+
ret->dtypes = dtypes;
448+
ret->slots = slots;
449+
450+
return ret;
451+
}
452+
453+
// Allocate a two-element array holding the (from, to) DTypes of a cast.
//
// dt1 and dt2 are stored as given (either may be NULL). The array is
// allocated with malloc and must be released by the caller with free().
//
// Returns the new array, or NULL if the allocation fails; callers must check
// for NULL before using the result.
PyArray_DTypeMeta **
get_dtypes(PyArray_DTypeMeta *dt1, PyArray_DTypeMeta *dt2)
{
    PyArray_DTypeMeta **ret = malloc(2 * sizeof *ret);

    if (ret == NULL) {
        // Avoid writing through a NULL pointer when out of memory.
        return NULL;
    }

    ret[0] = dt1;
    ret[1] = dt2;

    return ret;
}
463+
379464
PyArrayMethod_Spec **
380465
get_casts(void)
381466
{
382-
PyArray_DTypeMeta **u2s_dtypes = malloc(2 * sizeof(PyArray_DTypeMeta *));
383-
u2s_dtypes[0] = &PyArray_UnicodeDType;
384-
u2s_dtypes[1] = NULL;
385-
386-
PyArrayMethod_Spec *UnicodeToStringCastSpec =
387-
malloc(sizeof(PyArrayMethod_Spec));
388-
389-
UnicodeToStringCastSpec->name = u2s_name;
390-
UnicodeToStringCastSpec->nin = 1;
391-
UnicodeToStringCastSpec->nout = 1;
392-
UnicodeToStringCastSpec->casting = NPY_SAFE_CASTING;
393-
UnicodeToStringCastSpec->flags = NPY_METH_NO_FLOATINGPOINT_ERRORS;
394-
UnicodeToStringCastSpec->dtypes = u2s_dtypes;
395-
UnicodeToStringCastSpec->slots = u2s_slots;
396-
397-
PyArray_DTypeMeta **s2u_dtypes = malloc(2 * sizeof(PyArray_DTypeMeta *));
398-
s2u_dtypes[0] = NULL;
399-
s2u_dtypes[1] = &PyArray_UnicodeDType;
400-
401-
PyArrayMethod_Spec *StringToUnicodeCastSpec =
402-
malloc(sizeof(PyArrayMethod_Spec));
403-
404-
StringToUnicodeCastSpec->name = s2u_name;
405-
StringToUnicodeCastSpec->nin = 1;
406-
StringToUnicodeCastSpec->nout = 1;
407-
StringToUnicodeCastSpec->casting = NPY_SAFE_CASTING;
408-
StringToUnicodeCastSpec->flags = NPY_METH_NO_FLOATINGPOINT_ERRORS;
409-
StringToUnicodeCastSpec->dtypes = s2u_dtypes;
410-
StringToUnicodeCastSpec->slots = s2u_slots;
411-
412-
PyArrayMethod_Spec **casts = malloc(4 * sizeof(PyArrayMethod_Spec *));
413-
casts[0] = &StringToStringCastSpec;
467+
PyArray_DTypeMeta **s2s_dtypes = get_dtypes(NULL, NULL);
468+
469+
PyArrayMethod_Spec *StringToStringCastSpec =
470+
get_cast_spec(s2s_name, NPY_NO_CASTING,
471+
NPY_METH_SUPPORTS_UNALIGNED, s2s_dtypes, s2s_slots);
472+
473+
PyArray_DTypeMeta **u2s_dtypes = get_dtypes(&PyArray_UnicodeDType, NULL);
474+
475+
PyArrayMethod_Spec *UnicodeToStringCastSpec = get_cast_spec(
476+
u2s_name, NPY_SAFE_CASTING, NPY_METH_NO_FLOATINGPOINT_ERRORS,
477+
u2s_dtypes, u2s_slots);
478+
479+
PyArray_DTypeMeta **s2u_dtypes = get_dtypes(NULL, &PyArray_UnicodeDType);
480+
481+
PyArrayMethod_Spec *StringToUnicodeCastSpec = get_cast_spec(
482+
s2u_name, NPY_SAFE_CASTING, NPY_METH_NO_FLOATINGPOINT_ERRORS,
483+
s2u_dtypes, s2u_slots);
484+
485+
PyArray_DTypeMeta **s2b_dtypes = get_dtypes(NULL, &PyArray_BoolDType);
486+
487+
PyArrayMethod_Spec *StringToBoolCastSpec = get_cast_spec(
488+
s2b_name, NPY_UNSAFE_CASTING, NPY_METH_NO_FLOATINGPOINT_ERRORS,
489+
s2b_dtypes, s2b_slots);
490+
491+
PyArrayMethod_Spec **casts = malloc(5 * sizeof(PyArrayMethod_Spec *));
492+
casts[0] = StringToStringCastSpec;
414493
casts[1] = UnicodeToStringCastSpec;
415494
casts[2] = StringToUnicodeCastSpec;
416-
casts[3] = NULL;
495+
casts[3] = StringToBoolCastSpec;
496+
casts[4] = NULL;
417497

418498
return casts;
419499
}

stringdtype/stringdtype/src/dtype.c

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -409,11 +409,10 @@ init_string_dtype(void)
409409

410410
StringDType.singleton = singleton;
411411

412-
free(StringDType_DTypeSpec.casts[1]->dtypes);
413-
free(StringDType_DTypeSpec.casts[1]);
414-
free(StringDType_DTypeSpec.casts[2]->dtypes);
415-
free(StringDType_DTypeSpec.casts[2]);
416-
free(StringDType_DTypeSpec.casts);
412+
for (int i = 0; casts[i] != NULL; i++) {
413+
free(casts[i]->dtypes);
414+
free(casts[i]);
415+
}
417416

418417
return 0;
419418
}

stringdtype/tests/test_stringdtype.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -242,3 +242,20 @@ def test_arrfuncs_empty(arrfunc, expected):
242242
arr = np.empty(10, dtype=StringDType())
243243
result = arrfunc(arr)
244244
np.testing.assert_array_equal(result, expected, strict=True)
245+
246+
247+
@pytest.mark.parametrize(
    ("string_list", "cast_answer", "any_answer", "all_answer"),
    [
        (["hello", "world"], [True, True], True, True),
        (["", ""], [False, False], False, False),
        (["hello", ""], [True, False], True, False),
        (["", "world"], [False, True], True, False),
    ],
)
def test_bool_cast(string_list, cast_answer, any_answer, all_answer):
    """Check the bool cast of a StringDType array and its np.any/np.all results.

    In every parametrized case, empty strings are expected to map to False
    and non-empty strings to True.
    """
    strings = np.array(string_list, dtype=StringDType())

    # Element-wise cast to bool.
    as_bool = strings.astype("bool")
    np.testing.assert_array_equal(as_bool, cast_answer)

    # Reductions applied directly to the string array.
    assert np.any(strings) == any_answer
    assert np.all(strings) == all_answer

0 commit comments

Comments
 (0)