@@ -456,7 +456,7 @@ ufunc_promoter_internal(PyUFuncObject *ufunc, PyArray_DTypeMeta *op_dtypes[],
456
456
}
457
457
Py_XDECREF (common );
458
458
459
- /* Otherwise, set all input operands to StringDType */
459
+ /* Otherwise, set all input operands to final_dtype */
460
460
for (int i = 0 ; i < ufunc -> nargs ; i ++ ) {
461
461
PyArray_DTypeMeta * tmp = final_dtype ;
462
462
if (signature [i ]) {
@@ -474,21 +474,32 @@ ufunc_promoter_internal(PyUFuncObject *ufunc, PyArray_DTypeMeta *op_dtypes[],
474
474
}
475
475
476
476
static int
477
- string_ufunc_promoter ( PyUFuncObject * ufunc , PyArray_DTypeMeta * op_dtypes [],
478
- PyArray_DTypeMeta * signature [],
479
- PyArray_DTypeMeta * new_op_dtypes [])
477
+ string_object_promoter ( PyObject * ufunc , PyArray_DTypeMeta * op_dtypes [],
478
+ PyArray_DTypeMeta * signature [],
479
+ PyArray_DTypeMeta * new_op_dtypes [])
480
480
{
481
- return ufunc_promoter_internal (ufunc , op_dtypes , signature , new_op_dtypes ,
481
+ return ufunc_promoter_internal ((PyUFuncObject * )ufunc , op_dtypes ,
482
+ signature , new_op_dtypes ,
483
+ (PyArray_DTypeMeta * )& PyArray_ObjectDType );
484
+ }
485
+
486
+ static int
487
+ string_unicode_promoter (PyObject * ufunc , PyArray_DTypeMeta * op_dtypes [],
488
+ PyArray_DTypeMeta * signature [],
489
+ PyArray_DTypeMeta * new_op_dtypes [])
490
+ {
491
+ return ufunc_promoter_internal ((PyUFuncObject * )ufunc , op_dtypes ,
492
+ signature , new_op_dtypes ,
482
493
(PyArray_DTypeMeta * )& StringDType );
483
494
}
484
495
485
496
static int
486
- pandas_string_ufunc_promoter (PyUFuncObject * ufunc ,
487
- PyArray_DTypeMeta * op_dtypes [],
488
- PyArray_DTypeMeta * signature [],
489
- PyArray_DTypeMeta * new_op_dtypes [])
497
+ pandas_string_unicode_promoter (PyObject * ufunc , PyArray_DTypeMeta * op_dtypes [],
498
+ PyArray_DTypeMeta * signature [],
499
+ PyArray_DTypeMeta * new_op_dtypes [])
490
500
{
491
- return ufunc_promoter_internal (ufunc , op_dtypes , signature , new_op_dtypes ,
501
+ return ufunc_promoter_internal ((PyUFuncObject * )ufunc , op_dtypes ,
502
+ signature , new_op_dtypes ,
492
503
(PyArray_DTypeMeta * )& PandasStringDType );
493
504
}
494
505
@@ -538,7 +549,7 @@ init_ufunc(PyObject *numpy, const char *ufunc_name, PyArray_DTypeMeta **dtypes,
538
549
int
539
550
add_promoter (PyObject * numpy , const char * ufunc_name ,
540
551
PyArray_DTypeMeta * ldtype , PyArray_DTypeMeta * rdtype ,
541
- PyArray_DTypeMeta * edtype , int is_pandas )
552
+ PyArray_DTypeMeta * edtype , promoter_function * promoter_impl )
542
553
{
543
554
PyObject * ufunc = PyObject_GetAttrString (numpy , ufunc_name );
544
555
@@ -553,16 +564,8 @@ add_promoter(PyObject *numpy, const char *ufunc_name,
553
564
return -1 ;
554
565
}
555
566
556
- PyObject * promoter_capsule = NULL ;
557
-
558
- if (is_pandas == 0 ) {
559
- promoter_capsule = PyCapsule_New ((void * )& string_ufunc_promoter ,
560
- "numpy._ufunc_promoter" , NULL );
561
- }
562
- else {
563
- promoter_capsule = PyCapsule_New ((void * )& pandas_string_ufunc_promoter ,
564
- "numpy._ufunc_promoter" , NULL );
565
- }
567
+ PyObject * promoter_capsule = PyCapsule_New ((void * )promoter_impl ,
568
+ "numpy._ufunc_promoter" , NULL );
566
569
567
570
if (promoter_capsule == NULL ) {
568
571
Py_DECREF (ufunc );
@@ -592,21 +595,31 @@ init_ufuncs(void)
592
595
return -1 ;
593
596
}
594
597
595
- StringDType_type * * dtype_classes = NULL ;
596
598
int num_dtypes ;
597
599
598
600
if (PANDAS_AVAILABLE ) {
599
- dtype_classes = malloc (sizeof (StringDType_type * ) * 2 );
600
- dtype_classes [0 ] = & StringDType ;
601
- dtype_classes [1 ] = & PandasStringDType ;
602
601
num_dtypes = 2 ;
603
602
}
604
603
else {
605
- dtype_classes = malloc (sizeof (StringDType_type * ) * 1 );
606
- dtype_classes [0 ] = & StringDType ;
607
604
num_dtypes = 1 ;
608
605
}
609
606
607
+ StringDType_type * * dtype_classes =
608
+ malloc (sizeof (StringDType_type * ) * num_dtypes );
609
+ promoter_function * * unicode_promoters =
610
+ malloc (sizeof (promoter_function * ) * num_dtypes );
611
+ dtype_classes [0 ] = & StringDType ;
612
+ unicode_promoters [0 ] = & string_unicode_promoter ;
613
+
614
+ if (PANDAS_AVAILABLE ) {
615
+ dtype_classes [1 ] = & PandasStringDType ;
616
+ unicode_promoters [1 ] = & pandas_string_unicode_promoter ;
617
+ }
618
+
619
+ static char * comparison_ufunc_names [6 ] = {"equal" , "not_equal" ,
620
+ "greater" , "greater_equal" ,
621
+ "less" , "less_equal" };
622
+
610
623
for (int di = 0 ; di < num_dtypes ; di ++ ) {
611
624
PyArray_DTypeMeta * comparison_dtypes [] = {
612
625
(PyArray_DTypeMeta * )dtype_classes [di ],
@@ -654,34 +667,32 @@ init_ufuncs(void)
654
667
goto error ;
655
668
}
656
669
657
- static char * ufunc_names [6 ] = {"equal" , "not_equal" ,
658
- "greater" , "greater_equal" ,
659
- "less" , "less_equal" };
660
-
661
670
for (int i = 0 ; i < 6 ; i ++ ) {
662
- if (add_promoter (numpy , ufunc_names [i ],
671
+ if (add_promoter (numpy , comparison_ufunc_names [i ],
663
672
(PyArray_DTypeMeta * )dtype_classes [di ],
664
673
& PyArray_UnicodeDType , & PyArray_BoolDType ,
665
- 0 ) < 0 ) {
674
+ unicode_promoters [ di ] ) < 0 ) {
666
675
goto error ;
667
676
}
668
677
669
- if (add_promoter (numpy , ufunc_names [i ], & PyArray_UnicodeDType ,
678
+ if (add_promoter (numpy , comparison_ufunc_names [i ],
679
+ & PyArray_UnicodeDType ,
670
680
(PyArray_DTypeMeta * )dtype_classes [di ],
671
- & PyArray_BoolDType , 0 ) < 0 ) {
681
+ & PyArray_BoolDType , unicode_promoters [ di ] ) < 0 ) {
672
682
goto error ;
673
683
}
674
684
675
- if (add_promoter (numpy , ufunc_names [i ], & PyArray_ObjectDType ,
676
- (PyArray_DTypeMeta * )dtype_classes [di ],
677
- & PyArray_BoolDType , 0 ) < 0 ) {
685
+ if (add_promoter (
686
+ numpy , comparison_ufunc_names [i ], & PyArray_ObjectDType ,
687
+ (PyArray_DTypeMeta * )dtype_classes [di ],
688
+ & PyArray_BoolDType , & string_object_promoter ) < 0 ) {
678
689
goto error ;
679
690
}
680
691
681
- if (add_promoter (numpy , ufunc_names [i ],
692
+ if (add_promoter (numpy , comparison_ufunc_names [i ],
682
693
(PyArray_DTypeMeta * )dtype_classes [di ],
683
694
& PyArray_ObjectDType , & PyArray_BoolDType ,
684
- 0 ) < 0 ) {
695
+ & string_object_promoter ) < 0 ) {
685
696
goto error ;
686
697
}
687
698
}
@@ -720,10 +731,36 @@ init_ufuncs(void)
720
731
}
721
732
}
722
733
734
+ // add promoters for all ufuncs so comparison operations mixing StringDType
735
+ // and PandasStringDType work correctly.
736
+
737
+ if (PANDAS_AVAILABLE ) {
738
+ for (int i = 0 ; i < 6 ; i ++ ) {
739
+ if (add_promoter (numpy , comparison_ufunc_names [i ],
740
+ (PyArray_DTypeMeta * )& StringDType ,
741
+ (PyArray_DTypeMeta * )& PandasStringDType ,
742
+ & PyArray_BoolDType ,
743
+ string_unicode_promoter ) < 0 ) {
744
+ goto error ;
745
+ }
746
+
747
+ if (add_promoter (numpy , comparison_ufunc_names [i ],
748
+ (PyArray_DTypeMeta * )& PandasStringDType ,
749
+ (PyArray_DTypeMeta * )& StringDType ,
750
+ & PyArray_BoolDType ,
751
+ string_unicode_promoter ) < 0 ) {
752
+ goto error ;
753
+ }
754
+ }
755
+ }
756
+ free (dtype_classes );
757
+ free (unicode_promoters );
723
758
Py_DECREF (numpy );
724
759
return 0 ;
725
760
726
761
error :
762
+ free (dtype_classes );
763
+ free (unicode_promoters );
727
764
Py_DECREF (numpy );
728
765
return -1 ;
729
766
}
0 commit comments