Skip to content

Commit 283baab

Browse files
committed
Add a bool to string cast
1 parent b9c9cac commit 283baab

File tree

2 files changed

+104
-7
lines changed

2 files changed

+104
-7
lines changed

stringdtype/stringdtype/src/casts.c

Lines changed: 87 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -449,6 +449,80 @@ static PyType_Slot s2b_slots[] = {
449449

450450
static char *s2b_name = "cast_StringDType_to_Bool";
451451

452+
// bool to string
453+
454+
static NPY_CASTING
455+
bool_to_string_resolve_descriptors(PyObject *NPY_UNUSED(self),
456+
PyArray_DTypeMeta *dtypes[2],
457+
PyArray_Descr *given_descrs[2],
458+
PyArray_Descr *loop_descrs[2],
459+
npy_intp *NPY_UNUSED(view_offset))
460+
{
461+
if (given_descrs[1] == NULL) {
462+
PyArray_Descr *new = (PyArray_Descr *)new_stringdtype_instance(
463+
(PyTypeObject *)dtypes[1]);
464+
if (new == NULL) {
465+
return (NPY_CASTING)-1;
466+
}
467+
loop_descrs[1] = new;
468+
}
469+
else {
470+
Py_INCREF(given_descrs[1]);
471+
loop_descrs[1] = given_descrs[1];
472+
}
473+
474+
Py_INCREF(given_descrs[0]);
475+
loop_descrs[0] = given_descrs[0];
476+
477+
return NPY_SAFE_CASTING;
478+
}
479+
480+
static int
481+
bool_to_string(PyArrayMethod_Context *NPY_UNUSED(context), char *const data[],
482+
npy_intp const dimensions[], npy_intp const strides[],
483+
NpyAuxData *NPY_UNUSED(auxdata))
484+
{
485+
npy_intp N = dimensions[0];
486+
char *in = data[0];
487+
char *out = data[1];
488+
489+
npy_intp in_stride = strides[0];
490+
npy_intp out_stride = strides[1];
491+
492+
while (N--) {
493+
ss *out_ss = (ss *)out;
494+
ssfree(out_ss);
495+
if ((npy_bool)(*in) == 1) {
496+
if (ssnewlen("True", 4, out_ss) < 0) {
497+
gil_error(PyExc_MemoryError, "ssnewlen failed");
498+
return -1;
499+
}
500+
}
501+
else if ((npy_bool)(*in) == 0) {
502+
if (ssnewlen("False", 5, out_ss) < 0) {
503+
gil_error(PyExc_MemoryError, "ssnewlen failed");
504+
return -1;
505+
}
506+
}
507+
else {
508+
gil_error(PyExc_RuntimeError,
509+
"invalid value encountered in bool to string cast");
510+
return -1;
511+
}
512+
in += in_stride;
513+
out += out_stride;
514+
}
515+
516+
return 0;
517+
}
518+
519+
static PyType_Slot b2s_slots[] = {
520+
{NPY_METH_resolve_descriptors, &bool_to_string_resolve_descriptors},
521+
{NPY_METH_strided_loop, &bool_to_string},
522+
{0, NULL}};
523+
524+
static char *b2s_name = "cast_Bool_to_StringDType";
525+
452526
PyArrayMethod_Spec *
453527
get_cast_spec(const char *name, NPY_CASTING casting,
454528
NPY_ARRAYMETHOD_FLAGS flags, PyArray_DTypeMeta **dtypes,
@@ -501,10 +575,10 @@ get_casts(PyArray_DTypeMeta *this, PyArray_DTypeMeta *other)
501575

502576
int is_pandas = (this == (PyArray_DTypeMeta *)&PandasStringDType);
503577

504-
int num_casts = 5;
578+
int num_casts = 6;
505579

506580
if (is_pandas) {
507-
num_casts = 7;
581+
num_casts += 2;
508582

509583
PyArray_DTypeMeta **t2o_dtypes = get_dtypes(this, other);
510584

@@ -537,6 +611,12 @@ get_casts(PyArray_DTypeMeta *this, PyArray_DTypeMeta *other)
537611
s2b_name, NPY_UNSAFE_CASTING, NPY_METH_NO_FLOATINGPOINT_ERRORS,
538612
s2b_dtypes, s2b_slots);
539613

614+
PyArray_DTypeMeta **b2s_dtypes = get_dtypes(&PyArray_BoolDType, this);
615+
616+
PyArrayMethod_Spec *BoolToStringCastSpec = get_cast_spec(
617+
b2s_name, NPY_SAFE_CASTING, NPY_METH_NO_FLOATINGPOINT_ERRORS,
618+
b2s_dtypes, b2s_slots);
619+
540620
PyArrayMethod_Spec **casts = NULL;
541621

542622
casts = malloc(num_casts * sizeof(PyArrayMethod_Spec *));
@@ -545,13 +625,14 @@ get_casts(PyArray_DTypeMeta *this, PyArray_DTypeMeta *other)
545625
casts[1] = UnicodeToStringCastSpec;
546626
casts[2] = StringToUnicodeCastSpec;
547627
casts[3] = StringToBoolCastSpec;
628+
casts[4] = BoolToStringCastSpec;
548629
if (is_pandas) {
549-
casts[4] = ThisToOtherCastSpec;
550-
casts[5] = OtherToThisCastSpec;
551-
casts[6] = NULL;
630+
casts[5] = ThisToOtherCastSpec;
631+
casts[6] = OtherToThisCastSpec;
632+
casts[7] = NULL;
552633
}
553634
else {
554-
casts[4] = NULL;
635+
casts[5] = NULL;
555636
}
556637

557638
return casts;

stringdtype/tests/test_stringdtype.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -350,14 +350,30 @@ def test_arrfuncs_zeros(dtype, arrfunc, expected):
350350
[["", "world"], [False, True], True, False],
351351
],
352352
)
353-
def test_bool_cast(dtype, strings, cast_answer, any_answer, all_answer):
353+
def test_cast_to_bool(dtype, strings, cast_answer, any_answer, all_answer):
354354
sarr = np.array(strings, dtype=dtype)
355355
np.testing.assert_array_equal(sarr.astype("bool"), cast_answer)
356356

357357
assert np.any(sarr) == any_answer
358358
assert np.all(sarr) == all_answer
359359

360360

361+
@pytest.mark.parametrize(
    ("strings", "cast_answer"),
    [
        ([True, True], ["True", "True"]),
        ([False, False], ["False", "False"]),
        ([True, False], ["True", "False"]),
        ([False, True], ["False", "True"]),
    ],
)
def test_cast_from_bool(dtype, strings, cast_answer):
    """Casting a bool array to the string dtype yields "True"/"False".

    Note that ``strings`` holds the boolean input values here; the expected
    string values are in ``cast_answer``.
    """
    bool_values = np.array(strings, dtype=bool)
    converted = bool_values.astype(dtype)
    expected = np.array(cast_answer, dtype=dtype)
    np.testing.assert_array_equal(converted, expected)
375+
376+
361377
def test_take(dtype, string_list):
362378
sarr = np.array(string_list, dtype=dtype)
363379
out = np.empty(len(string_list), dtype=dtype)

0 commit comments

Comments
 (0)