Skip to content

Commit 3685895

Browse files
committed
promote string to object in comparisons
1 parent e6257e9 commit 3685895

File tree

3 files changed

+92
-172
lines changed

3 files changed

+92
-172
lines changed

stringdtype/stringdtype/src/casts.c

Lines changed: 6 additions & 131 deletions
Original file line numberDiff line numberDiff line change
@@ -449,124 +449,6 @@ static PyType_Slot s2b_slots[] = {
449449

450450
static char *s2b_name = "cast_StringDType_to_Bool";
451451

452-
// object to string
453-
454-
typedef struct {
455-
NpyAuxData base;
456-
PyArray_Descr *descr;
457-
int move_references;
458-
} _object_to_string_auxdata;
459-
460-
static void
461-
_object_to_string_auxdata_free(NpyAuxData *auxdata)
462-
{
463-
_object_to_string_auxdata *data = (_object_to_string_auxdata *)auxdata;
464-
Py_DECREF(data->descr);
465-
PyMem_Free(data);
466-
}
467-
468-
static NpyAuxData *
469-
_object_to_string_auxdata_clone(NpyAuxData *data)
470-
{
471-
_object_to_string_auxdata *res = PyMem_Malloc(sizeof(*res));
472-
if (res == NULL) {
473-
return NULL;
474-
}
475-
memcpy(res, data, sizeof(*res));
476-
Py_INCREF(res->descr);
477-
return (NpyAuxData *)res;
478-
}
479-
480-
static int
481-
object_to_string_strided_loop(PyArrayMethod_Context *NPY_UNUSED(context),
482-
char *const *args, const npy_intp *dimensions,
483-
const npy_intp *strides, NpyAuxData *auxdata)
484-
{
485-
npy_intp N = dimensions[0];
486-
char *src = args[0], *dst = args[1];
487-
npy_intp src_stride = strides[0], dst_stride = strides[1];
488-
_object_to_string_auxdata *data = (_object_to_string_auxdata *)auxdata;
489-
490-
PyObject *src_ref;
491-
492-
while (N > 0) {
493-
memcpy(&src_ref, src, sizeof(src_ref));
494-
if (stringdtype_setitem((StringDTypeObject *)(data->descr),
495-
src_ref ? src_ref : Py_None,
496-
(void *)dst) < 0) {
497-
return -1;
498-
}
499-
500-
if (data->move_references && src_ref != NULL) {
501-
Py_DECREF(src_ref);
502-
memset(src, 0, sizeof(src_ref));
503-
}
504-
505-
N--;
506-
dst += dst_stride;
507-
src += src_stride;
508-
}
509-
return 0;
510-
}
511-
512-
NPY_NO_EXPORT int
513-
object_to_string_get_loop(PyArrayMethod_Context *context,
514-
int NPY_UNUSED(aligned), int move_references,
515-
const npy_intp *NPY_UNUSED(strides),
516-
PyArrayMethod_StridedLoop **out_loop,
517-
NpyAuxData **out_transferdata,
518-
NPY_ARRAYMETHOD_FLAGS *flags)
519-
{
520-
*flags = NPY_METH_REQUIRES_PYAPI;
521-
522-
/* NOTE: auxdata is only really necessary to flag `move_references` */
523-
_object_to_string_auxdata *data = PyMem_Malloc(sizeof(*data));
524-
if (data == NULL) {
525-
return -1;
526-
}
527-
data->base.free = &_object_to_string_auxdata_free;
528-
data->base.clone = &_object_to_string_auxdata_clone;
529-
530-
Py_INCREF(context->descriptors[1]);
531-
data->descr = context->descriptors[1];
532-
data->move_references = move_references;
533-
*out_transferdata = (NpyAuxData *)data;
534-
*out_loop = &object_to_string_strided_loop;
535-
return 0;
536-
}
537-
538-
static NPY_CASTING
539-
object_to_string_resolve_descriptors(PyArrayMethodObject *NPY_UNUSED(self),
540-
PyArray_DTypeMeta *dtypes[2],
541-
PyArray_Descr *given_descrs[2],
542-
PyArray_Descr *loop_descrs[2],
543-
npy_intp *NPY_UNUSED(view_offset))
544-
{
545-
if (given_descrs[1] == NULL) {
546-
loop_descrs[1] = (PyArray_Descr *)new_stringdtype_instance(
547-
(PyTypeObject *)dtypes[1]);
548-
if (loop_descrs[1] == NULL) {
549-
return -1;
550-
}
551-
}
552-
else {
553-
Py_INCREF(given_descrs[1]);
554-
loop_descrs[1] = given_descrs[1];
555-
}
556-
557-
Py_INCREF(given_descrs[0]);
558-
loop_descrs[0] = given_descrs[0];
559-
560-
return NPY_SAFE_CASTING;
561-
}
562-
563-
static PyType_Slot o2s_slots[] = {
564-
{NPY_METH_resolve_descriptors, &object_to_string_resolve_descriptors},
565-
{_NPY_METH_get_loop, &object_to_string_get_loop},
566-
{0, NULL}};
567-
568-
static char *o2s_name = "cast_object_to_StringDType";
569-
570452
PyArrayMethod_Spec *
571453
get_cast_spec(const char *name, NPY_CASTING casting,
572454
NPY_ARRAYMETHOD_FLAGS flags, PyArray_DTypeMeta **dtypes,
@@ -619,10 +501,10 @@ get_casts(PyArray_DTypeMeta *this, PyArray_DTypeMeta *other)
619501

620502
int is_pandas = (this == (PyArray_DTypeMeta *)&PandasStringDType);
621503

622-
int num_casts = 6;
504+
int num_casts = 5;
623505

624506
if (is_pandas) {
625-
num_casts = 8;
507+
num_casts = 7;
626508

627509
PyArray_DTypeMeta **t2o_dtypes = get_dtypes(this, other);
628510

@@ -655,12 +537,6 @@ get_casts(PyArray_DTypeMeta *this, PyArray_DTypeMeta *other)
655537
s2b_name, NPY_UNSAFE_CASTING, NPY_METH_NO_FLOATINGPOINT_ERRORS,
656538
s2b_dtypes, s2b_slots);
657539

658-
PyArray_DTypeMeta **o2s_dtypes = get_dtypes(&PyArray_ObjectDType, this);
659-
660-
PyArrayMethod_Spec *ObjectToStringCastSpec =
661-
get_cast_spec(o2s_name, NPY_SAFE_CASTING, NPY_METH_REQUIRES_PYAPI,
662-
o2s_dtypes, o2s_slots);
663-
664540
PyArrayMethod_Spec **casts = NULL;
665541

666542
casts = malloc(num_casts * sizeof(PyArrayMethod_Spec *));
@@ -669,14 +545,13 @@ get_casts(PyArray_DTypeMeta *this, PyArray_DTypeMeta *other)
669545
casts[1] = UnicodeToStringCastSpec;
670546
casts[2] = StringToUnicodeCastSpec;
671547
casts[3] = StringToBoolCastSpec;
672-
casts[4] = ObjectToStringCastSpec;
673548
if (is_pandas) {
674-
casts[5] = ThisToOtherCastSpec;
675-
casts[6] = OtherToThisCastSpec;
676-
casts[7] = NULL;
549+
casts[4] = ThisToOtherCastSpec;
550+
casts[5] = OtherToThisCastSpec;
551+
casts[6] = NULL;
677552
}
678553
else {
679-
casts[5] = NULL;
554+
casts[4] = NULL;
680555
}
681556

682557
return casts;

stringdtype/stringdtype/src/umath.c

Lines changed: 77 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -456,7 +456,7 @@ ufunc_promoter_internal(PyUFuncObject *ufunc, PyArray_DTypeMeta *op_dtypes[],
456456
}
457457
Py_XDECREF(common);
458458

459-
/* Otherwise, set all input operands to StringDType */
459+
/* Otherwise, set all input operands to final_dtype */
460460
for (int i = 0; i < ufunc->nargs; i++) {
461461
PyArray_DTypeMeta *tmp = final_dtype;
462462
if (signature[i]) {
@@ -474,21 +474,32 @@ ufunc_promoter_internal(PyUFuncObject *ufunc, PyArray_DTypeMeta *op_dtypes[],
474474
}
475475

476476
static int
477-
string_ufunc_promoter(PyUFuncObject *ufunc, PyArray_DTypeMeta *op_dtypes[],
478-
PyArray_DTypeMeta *signature[],
479-
PyArray_DTypeMeta *new_op_dtypes[])
477+
string_object_promoter(PyObject *ufunc, PyArray_DTypeMeta *op_dtypes[],
478+
PyArray_DTypeMeta *signature[],
479+
PyArray_DTypeMeta *new_op_dtypes[])
480480
{
481-
return ufunc_promoter_internal(ufunc, op_dtypes, signature, new_op_dtypes,
481+
return ufunc_promoter_internal((PyUFuncObject *)ufunc, op_dtypes,
482+
signature, new_op_dtypes,
483+
(PyArray_DTypeMeta *)&PyArray_ObjectDType);
484+
}
485+
486+
static int
487+
string_unicode_promoter(PyObject *ufunc, PyArray_DTypeMeta *op_dtypes[],
488+
PyArray_DTypeMeta *signature[],
489+
PyArray_DTypeMeta *new_op_dtypes[])
490+
{
491+
return ufunc_promoter_internal((PyUFuncObject *)ufunc, op_dtypes,
492+
signature, new_op_dtypes,
482493
(PyArray_DTypeMeta *)&StringDType);
483494
}
484495

485496
static int
486-
pandas_string_ufunc_promoter(PyUFuncObject *ufunc,
487-
PyArray_DTypeMeta *op_dtypes[],
488-
PyArray_DTypeMeta *signature[],
489-
PyArray_DTypeMeta *new_op_dtypes[])
497+
pandas_string_unicode_promoter(PyObject *ufunc, PyArray_DTypeMeta *op_dtypes[],
498+
PyArray_DTypeMeta *signature[],
499+
PyArray_DTypeMeta *new_op_dtypes[])
490500
{
491-
return ufunc_promoter_internal(ufunc, op_dtypes, signature, new_op_dtypes,
501+
return ufunc_promoter_internal((PyUFuncObject *)ufunc, op_dtypes,
502+
signature, new_op_dtypes,
492503
(PyArray_DTypeMeta *)&PandasStringDType);
493504
}
494505

@@ -538,7 +549,7 @@ init_ufunc(PyObject *numpy, const char *ufunc_name, PyArray_DTypeMeta **dtypes,
538549
int
539550
add_promoter(PyObject *numpy, const char *ufunc_name,
540551
PyArray_DTypeMeta *ldtype, PyArray_DTypeMeta *rdtype,
541-
PyArray_DTypeMeta *edtype, int is_pandas)
552+
PyArray_DTypeMeta *edtype, promoter_function *promoter_impl)
542553
{
543554
PyObject *ufunc = PyObject_GetAttrString(numpy, ufunc_name);
544555

@@ -553,16 +564,8 @@ add_promoter(PyObject *numpy, const char *ufunc_name,
553564
return -1;
554565
}
555566

556-
PyObject *promoter_capsule = NULL;
557-
558-
if (is_pandas == 0) {
559-
promoter_capsule = PyCapsule_New((void *)&string_ufunc_promoter,
560-
"numpy._ufunc_promoter", NULL);
561-
}
562-
else {
563-
promoter_capsule = PyCapsule_New((void *)&pandas_string_ufunc_promoter,
564-
"numpy._ufunc_promoter", NULL);
565-
}
567+
PyObject *promoter_capsule = PyCapsule_New((void *)promoter_impl,
568+
"numpy._ufunc_promoter", NULL);
566569

567570
if (promoter_capsule == NULL) {
568571
Py_DECREF(ufunc);
@@ -592,21 +595,31 @@ init_ufuncs(void)
592595
return -1;
593596
}
594597

595-
StringDType_type **dtype_classes = NULL;
596598
int num_dtypes;
597599

598600
if (PANDAS_AVAILABLE) {
599-
dtype_classes = malloc(sizeof(StringDType_type *) * 2);
600-
dtype_classes[0] = &StringDType;
601-
dtype_classes[1] = &PandasStringDType;
602601
num_dtypes = 2;
603602
}
604603
else {
605-
dtype_classes = malloc(sizeof(StringDType_type *) * 1);
606-
dtype_classes[0] = &StringDType;
607604
num_dtypes = 1;
608605
}
609606

607+
StringDType_type **dtype_classes =
608+
malloc(sizeof(StringDType_type *) * num_dtypes);
609+
promoter_function **unicode_promoters =
610+
malloc(sizeof(promoter_function *) * num_dtypes);
611+
dtype_classes[0] = &StringDType;
612+
unicode_promoters[0] = &string_unicode_promoter;
613+
614+
if (PANDAS_AVAILABLE) {
615+
dtype_classes[1] = &PandasStringDType;
616+
unicode_promoters[1] = &pandas_string_unicode_promoter;
617+
}
618+
619+
static char *comparison_ufunc_names[6] = {"equal", "not_equal",
620+
"greater", "greater_equal",
621+
"less", "less_equal"};
622+
610623
for (int di = 0; di < num_dtypes; di++) {
611624
PyArray_DTypeMeta *comparison_dtypes[] = {
612625
(PyArray_DTypeMeta *)dtype_classes[di],
@@ -654,34 +667,32 @@ init_ufuncs(void)
654667
goto error;
655668
}
656669

657-
static char *ufunc_names[6] = {"equal", "not_equal",
658-
"greater", "greater_equal",
659-
"less", "less_equal"};
660-
661670
for (int i = 0; i < 6; i++) {
662-
if (add_promoter(numpy, ufunc_names[i],
671+
if (add_promoter(numpy, comparison_ufunc_names[i],
663672
(PyArray_DTypeMeta *)dtype_classes[di],
664673
&PyArray_UnicodeDType, &PyArray_BoolDType,
665-
0) < 0) {
674+
unicode_promoters[di]) < 0) {
666675
goto error;
667676
}
668677

669-
if (add_promoter(numpy, ufunc_names[i], &PyArray_UnicodeDType,
678+
if (add_promoter(numpy, comparison_ufunc_names[i],
679+
&PyArray_UnicodeDType,
670680
(PyArray_DTypeMeta *)dtype_classes[di],
671-
&PyArray_BoolDType, 0) < 0) {
681+
&PyArray_BoolDType, unicode_promoters[di]) < 0) {
672682
goto error;
673683
}
674684

675-
if (add_promoter(numpy, ufunc_names[i], &PyArray_ObjectDType,
676-
(PyArray_DTypeMeta *)dtype_classes[di],
677-
&PyArray_BoolDType, 0) < 0) {
685+
if (add_promoter(
686+
numpy, comparison_ufunc_names[i], &PyArray_ObjectDType,
687+
(PyArray_DTypeMeta *)dtype_classes[di],
688+
&PyArray_BoolDType, &string_object_promoter) < 0) {
678689
goto error;
679690
}
680691

681-
if (add_promoter(numpy, ufunc_names[i],
692+
if (add_promoter(numpy, comparison_ufunc_names[i],
682693
(PyArray_DTypeMeta *)dtype_classes[di],
683694
&PyArray_ObjectDType, &PyArray_BoolDType,
684-
0) < 0) {
695+
&string_object_promoter) < 0) {
685696
goto error;
686697
}
687698
}
@@ -720,10 +731,36 @@ init_ufuncs(void)
720731
}
721732
}
722733

734+
// add promoters for all comparison ufuncs so operations mixing StringDType
735+
// and PandasStringDType work correctly.
736+
737+
if (PANDAS_AVAILABLE) {
738+
for (int i = 0; i < 6; i++) {
739+
if (add_promoter(numpy, comparison_ufunc_names[i],
740+
(PyArray_DTypeMeta *)&StringDType,
741+
(PyArray_DTypeMeta *)&PandasStringDType,
742+
&PyArray_BoolDType,
743+
string_unicode_promoter) < 0) {
744+
goto error;
745+
}
746+
747+
if (add_promoter(numpy, comparison_ufunc_names[i],
748+
(PyArray_DTypeMeta *)&PandasStringDType,
749+
(PyArray_DTypeMeta *)&StringDType,
750+
&PyArray_BoolDType,
751+
string_unicode_promoter) < 0) {
752+
goto error;
753+
}
754+
}
755+
}
756+
free(dtype_classes);
757+
free(unicode_promoters);
723758
Py_DECREF(numpy);
724759
return 0;
725760

726761
error:
762+
free(dtype_classes);
763+
free(unicode_promoters);
727764
Py_DECREF(numpy);
728765
return -1;
729766
}

0 commit comments

Comments
 (0)