Skip to content

Commit 5668e12

Browse files
committed
Add a safe cast from object to string
1 parent 1ce658b commit 5668e12

File tree

3 files changed

+135
-7
lines changed

3 files changed

+135
-7
lines changed

stringdtype/stringdtype/src/casts.c

Lines changed: 131 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -449,6 +449,124 @@ static PyType_Slot s2b_slots[] = {
449449

450450
static char *s2b_name = "cast_StringDType_to_Bool";
451451

452+
// object to string
453+
454+
typedef struct {
455+
NpyAuxData base;
456+
PyArray_Descr *descr;
457+
int move_references;
458+
} _object_to_string_auxdata;
459+
460+
static void
461+
_object_to_string_auxdata_free(NpyAuxData *auxdata)
462+
{
463+
_object_to_string_auxdata *data = (_object_to_string_auxdata *)auxdata;
464+
Py_DECREF(data->descr);
465+
PyMem_Free(data);
466+
}
467+
468+
/*
 * `clone` callback installed on _object_to_string_auxdata.
 *
 * Produces a field-for-field copy of `data` and takes a new reference on the
 * descriptor so that the copy can be freed independently of the original.
 * Returns NULL when the allocation fails.
 */
static NpyAuxData *
_object_to_string_auxdata_clone(NpyAuxData *data)
{
    const _object_to_string_auxdata *original =
            (const _object_to_string_auxdata *)data;
    _object_to_string_auxdata *copy =
            PyMem_Malloc(sizeof(_object_to_string_auxdata));

    if (copy != NULL) {
        *copy = *original;
        /* The copy owns its own reference to the descriptor. */
        Py_INCREF(copy->descr);
    }
    return (NpyAuxData *)copy;
}
479+
480+
static int
481+
object_to_string_strided_loop(PyArrayMethod_Context *NPY_UNUSED(context),
482+
char *const *args, const npy_intp *dimensions,
483+
const npy_intp *strides, NpyAuxData *auxdata)
484+
{
485+
npy_intp N = dimensions[0];
486+
char *src = args[0], *dst = args[1];
487+
npy_intp src_stride = strides[0], dst_stride = strides[1];
488+
_object_to_string_auxdata *data = (_object_to_string_auxdata *)auxdata;
489+
490+
PyObject *src_ref;
491+
492+
while (N > 0) {
493+
memcpy(&src_ref, src, sizeof(src_ref));
494+
if (stringdtype_setitem((StringDTypeObject *)(data->descr),
495+
src_ref ? src_ref : Py_None,
496+
(void *)dst) < 0) {
497+
return -1;
498+
}
499+
500+
if (data->move_references && src_ref != NULL) {
501+
Py_DECREF(src_ref);
502+
memset(src, 0, sizeof(src_ref));
503+
}
504+
505+
N--;
506+
dst += dst_stride;
507+
src += src_stride;
508+
}
509+
return 0;
510+
}
511+
512+
NPY_NO_EXPORT int
513+
object_to_string_get_loop(PyArrayMethod_Context *context,
514+
int NPY_UNUSED(aligned), int move_references,
515+
const npy_intp *NPY_UNUSED(strides),
516+
PyArrayMethod_StridedLoop **out_loop,
517+
NpyAuxData **out_transferdata,
518+
NPY_ARRAYMETHOD_FLAGS *flags)
519+
{
520+
*flags = NPY_METH_REQUIRES_PYAPI;
521+
522+
/* NOTE: auxdata is only really necessary to flag `move_references` */
523+
_object_to_string_auxdata *data = PyMem_Malloc(sizeof(*data));
524+
if (data == NULL) {
525+
return -1;
526+
}
527+
data->base.free = &_object_to_string_auxdata_free;
528+
data->base.clone = &_object_to_string_auxdata_clone;
529+
530+
Py_INCREF(context->descriptors[1]);
531+
data->descr = context->descriptors[1];
532+
data->move_references = move_references;
533+
*out_transferdata = (NpyAuxData *)data;
534+
*out_loop = &object_to_string_strided_loop;
535+
return 0;
536+
}
537+
538+
static NPY_CASTING
539+
object_to_string_resolve_descriptors(PyArrayMethodObject *NPY_UNUSED(self),
540+
PyArray_DTypeMeta *dtypes[2],
541+
PyArray_Descr *given_descrs[2],
542+
PyArray_Descr *loop_descrs[2],
543+
npy_intp *NPY_UNUSED(view_offset))
544+
{
545+
if (given_descrs[1] == NULL) {
546+
loop_descrs[1] = (PyArray_Descr *)new_stringdtype_instance(
547+
(PyTypeObject *)dtypes[1]);
548+
if (loop_descrs[1] == NULL) {
549+
return -1;
550+
}
551+
}
552+
else {
553+
Py_INCREF(given_descrs[1]);
554+
loop_descrs[1] = given_descrs[1];
555+
}
556+
557+
Py_INCREF(given_descrs[0]);
558+
loop_descrs[0] = given_descrs[0];
559+
560+
return NPY_SAFE_CASTING;
561+
}
562+
563+
static PyType_Slot o2s_slots[] = {
564+
{NPY_METH_resolve_descriptors, &object_to_string_resolve_descriptors},
565+
{_NPY_METH_get_loop, &object_to_string_get_loop},
566+
{0, NULL}};
567+
568+
static char *o2s_name = "cast_object_to_StringDType";
569+
452570
PyArrayMethod_Spec *
453571
get_cast_spec(const char *name, NPY_CASTING casting,
454572
NPY_ARRAYMETHOD_FLAGS flags, PyArray_DTypeMeta **dtypes,
@@ -501,10 +619,10 @@ get_casts(PyArray_DTypeMeta *this, PyArray_DTypeMeta *other)
501619

502620
int is_pandas = (this == (PyArray_DTypeMeta *)&PandasStringDType);
503621

504-
int num_casts = 5;
622+
int num_casts = 6;
505623

506624
if (is_pandas) {
507-
num_casts = 7;
625+
num_casts = 8;
508626

509627
PyArray_DTypeMeta **t2o_dtypes = get_dtypes(this, other);
510628

@@ -537,6 +655,12 @@ get_casts(PyArray_DTypeMeta *this, PyArray_DTypeMeta *other)
537655
s2b_name, NPY_UNSAFE_CASTING, NPY_METH_NO_FLOATINGPOINT_ERRORS,
538656
s2b_dtypes, s2b_slots);
539657

658+
PyArray_DTypeMeta **o2s_dtypes = get_dtypes(&PyArray_ObjectDType, this);
659+
660+
PyArrayMethod_Spec *ObjectToStringCastSpec =
661+
get_cast_spec(o2s_name, NPY_SAFE_CASTING, NPY_METH_REQUIRES_PYAPI,
662+
o2s_dtypes, o2s_slots);
663+
540664
PyArrayMethod_Spec **casts = NULL;
541665

542666
casts = malloc(num_casts * sizeof(PyArrayMethod_Spec *));
@@ -545,13 +669,14 @@ get_casts(PyArray_DTypeMeta *this, PyArray_DTypeMeta *other)
545669
casts[1] = UnicodeToStringCastSpec;
546670
casts[2] = StringToUnicodeCastSpec;
547671
casts[3] = StringToBoolCastSpec;
672+
casts[4] = ObjectToStringCastSpec;
548673
if (is_pandas) {
549-
casts[4] = ThisToOtherCastSpec;
550-
casts[5] = OtherToThisCastSpec;
551-
casts[6] = NULL;
674+
casts[5] = ThisToOtherCastSpec;
675+
casts[6] = OtherToThisCastSpec;
676+
casts[7] = NULL;
552677
}
553678
else {
554-
casts[4] = NULL;
679+
casts[5] = NULL;
555680
}
556681

557682
return casts;

stringdtype/stringdtype/src/dtype.c

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@ string_discover_descriptor_from_pyobject(PyTypeObject *cls, PyObject *obj)
130130

131131
// Take a python object `obj` and insert it into the array of dtype `descr` at
132132
// the position given by dataptr.
133-
static int
133+
int
134134
stringdtype_setitem(StringDTypeObject *descr, PyObject *obj, char **dataptr)
135135
{
136136
// borrow reference

stringdtype/stringdtype/src/dtype.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,9 @@ compare(void *, void *, void *);
4242
int
4343
init_string_na_object(PyObject *mod);
4444

45+
int
46+
stringdtype_setitem(StringDTypeObject *descr, PyObject *obj, char **dataptr);
47+
4548
// from dtypemeta.h, not public in numpy
4649
#define NPY_DTYPE(descr) ((PyArray_DTypeMeta *)Py_TYPE(descr))
4750

0 commit comments

Comments
 (0)