Skip to content

Commit 3126b97

Browse files
authored
Merge pull request numpy#27636 from ngoldbaum/fix-stringdtype-promoters
BUG: fixes for StringDType/unicode promoters
2 parents cbd0e41 + 7bc49e9 commit 3126b97

File tree

4 files changed: 218 additions and 80 deletions.
Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,3 @@
1+
* Fixed a number of issues around promotion for string ufuncs with StringDType
2+
arguments. Mixing StringDType and the fixed-width DTypes using the string
3+
ufuncs should now generate much more uniform results.

numpy/_core/src/umath/stringdtype_ufuncs.cpp

Lines changed: 109 additions & 62 deletions
Original file line number | Diff line number | Diff line change
@@ -1598,6 +1598,20 @@ string_expandtabs_strided_loop(PyArrayMethod_Context *context,
15981598
return -1;
15991599
}
16001600

1601+
static int
1602+
string_center_ljust_rjust_promoter(
1603+
PyObject *NPY_UNUSED(ufunc),
1604+
PyArray_DTypeMeta *const op_dtypes[],
1605+
PyArray_DTypeMeta *const signature[],
1606+
PyArray_DTypeMeta *new_op_dtypes[])
1607+
{
1608+
new_op_dtypes[0] = NPY_DT_NewRef(&PyArray_StringDType);
1609+
new_op_dtypes[1] = NPY_DT_NewRef(&PyArray_Int64DType);
1610+
new_op_dtypes[2] = NPY_DT_NewRef(&PyArray_StringDType);
1611+
new_op_dtypes[3] = NPY_DT_NewRef(&PyArray_StringDType);
1612+
return 0;
1613+
}
1614+
16011615
static NPY_CASTING
16021616
center_ljust_rjust_resolve_descriptors(
16031617
struct PyArrayMethodObject_tag *NPY_UNUSED(method),
@@ -2595,10 +2609,17 @@ init_stringdtype_ufuncs(PyObject *umath)
25952609
"find", "rfind", "index", "rindex", "count",
25962610
};
25972611

2598-
PyArray_DTypeMeta *findlike_promoter_dtypes[] = {
2599-
&PyArray_StringDType, &PyArray_UnicodeDType,
2600-
&PyArray_IntAbstractDType, &PyArray_IntAbstractDType,
2601-
&PyArray_DefaultIntDType,
2612+
PyArray_DTypeMeta *findlike_promoter_dtypes[2][5] = {
2613+
{
2614+
&PyArray_StringDType, &PyArray_UnicodeDType,
2615+
&PyArray_IntAbstractDType, &PyArray_IntAbstractDType,
2616+
&PyArray_IntAbstractDType,
2617+
},
2618+
{
2619+
&PyArray_UnicodeDType, &PyArray_StringDType,
2620+
&PyArray_IntAbstractDType, &PyArray_IntAbstractDType,
2621+
&PyArray_IntAbstractDType,
2622+
},
26022623
};
26032624

26042625
find_like_function *findlike_functions[] = {
@@ -2618,11 +2639,12 @@ init_stringdtype_ufuncs(PyObject *umath)
26182639
return -1;
26192640
}
26202641

2621-
2622-
if (add_promoter(umath, findlike_names[i],
2623-
findlike_promoter_dtypes,
2624-
5, string_findlike_promoter) < 0) {
2625-
return -1;
2642+
for (int j=0; j<2; j++) {
2643+
if (add_promoter(umath, findlike_names[i],
2644+
findlike_promoter_dtypes[j],
2645+
5, string_findlike_promoter) < 0) {
2646+
return -1;
2647+
}
26262648
}
26272649
}
26282650

@@ -2636,10 +2658,17 @@ init_stringdtype_ufuncs(PyObject *umath)
26362658
"startswith", "endswith",
26372659
};
26382660

2639-
PyArray_DTypeMeta *startswith_endswith_promoter_dtypes[] = {
2640-
&PyArray_StringDType, &PyArray_UnicodeDType,
2641-
&PyArray_IntAbstractDType, &PyArray_IntAbstractDType,
2642-
&PyArray_BoolDType,
2661+
PyArray_DTypeMeta *startswith_endswith_promoter_dtypes[2][5] = {
2662+
{
2663+
&PyArray_StringDType, &PyArray_UnicodeDType,
2664+
&PyArray_IntAbstractDType, &PyArray_IntAbstractDType,
2665+
&PyArray_BoolDType,
2666+
},
2667+
{
2668+
&PyArray_UnicodeDType, &PyArray_StringDType,
2669+
&PyArray_IntAbstractDType, &PyArray_IntAbstractDType,
2670+
&PyArray_BoolDType,
2671+
},
26432672
};
26442673

26452674
static STARTPOSITION startswith_endswith_startposition[] = {
@@ -2656,11 +2685,12 @@ init_stringdtype_ufuncs(PyObject *umath)
26562685
return -1;
26572686
}
26582687

2659-
2660-
if (add_promoter(umath, startswith_endswith_names[i],
2661-
startswith_endswith_promoter_dtypes,
2662-
5, string_startswith_endswith_promoter) < 0) {
2663-
return -1;
2688+
for (int j=0; j<2; j++) {
2689+
if (add_promoter(umath, startswith_endswith_names[i],
2690+
startswith_endswith_promoter_dtypes[j],
2691+
5, string_startswith_endswith_promoter) < 0) {
2692+
return -1;
2693+
}
26642694
}
26652695
}
26662696

@@ -2732,24 +2762,38 @@ init_stringdtype_ufuncs(PyObject *umath)
27322762
return -1;
27332763
}
27342764

2735-
PyArray_DTypeMeta *replace_promoter_pyint_dtypes[] = {
2736-
&PyArray_StringDType, &PyArray_UnicodeDType, &PyArray_UnicodeDType,
2737-
&PyArray_IntAbstractDType, &PyArray_StringDType,
2738-
};
2739-
2740-
if (add_promoter(umath, "_replace", replace_promoter_pyint_dtypes, 5,
2741-
string_replace_promoter) < 0) {
2742-
return -1;
2743-
}
2744-
2745-
PyArray_DTypeMeta *replace_promoter_int64_dtypes[] = {
2746-
&PyArray_StringDType, &PyArray_UnicodeDType, &PyArray_UnicodeDType,
2747-
&PyArray_Int64DType, &PyArray_StringDType,
2765+
PyArray_DTypeMeta *replace_promoter_unicode_dtypes[6][5] = {
2766+
{
2767+
&PyArray_StringDType, &PyArray_UnicodeDType, &PyArray_UnicodeDType,
2768+
&PyArray_IntAbstractDType, &PyArray_StringDType,
2769+
},
2770+
{
2771+
&PyArray_UnicodeDType, &PyArray_StringDType, &PyArray_UnicodeDType,
2772+
&PyArray_IntAbstractDType, &PyArray_StringDType,
2773+
},
2774+
{
2775+
&PyArray_UnicodeDType, &PyArray_UnicodeDType, &PyArray_StringDType,
2776+
&PyArray_IntAbstractDType, &PyArray_StringDType,
2777+
},
2778+
{
2779+
&PyArray_StringDType, &PyArray_StringDType, &PyArray_UnicodeDType,
2780+
&PyArray_IntAbstractDType, &PyArray_StringDType,
2781+
},
2782+
{
2783+
&PyArray_StringDType, &PyArray_UnicodeDType, &PyArray_StringDType,
2784+
&PyArray_IntAbstractDType, &PyArray_StringDType,
2785+
},
2786+
{
2787+
&PyArray_UnicodeDType, &PyArray_StringDType, &PyArray_StringDType,
2788+
&PyArray_IntAbstractDType, &PyArray_StringDType,
2789+
},
27482790
};
27492791

2750-
if (add_promoter(umath, "_replace", replace_promoter_int64_dtypes, 5,
2751-
string_replace_promoter) < 0) {
2752-
return -1;
2792+
for (int j=0; j<6; j++) {
2793+
if (add_promoter(umath, "_replace", replace_promoter_unicode_dtypes[j], 5,
2794+
string_replace_promoter) < 0) {
2795+
return -1;
2796+
}
27532797
}
27542798

27552799
PyArray_DTypeMeta *expandtabs_dtypes[] = {
@@ -2767,9 +2811,9 @@ init_stringdtype_ufuncs(PyObject *umath)
27672811
}
27682812

27692813
PyArray_DTypeMeta *expandtabs_promoter_dtypes[] = {
2770-
&PyArray_StringDType,
2771-
(PyArray_DTypeMeta *)Py_None,
2772-
&PyArray_StringDType
2814+
&PyArray_StringDType,
2815+
&PyArray_IntAbstractDType,
2816+
&PyArray_StringDType
27732817
};
27742818

27752819
if (add_promoter(umath, "_expandtabs", expandtabs_promoter_dtypes,
@@ -2801,30 +2845,33 @@ init_stringdtype_ufuncs(PyObject *umath)
28012845
return -1;
28022846
}
28032847

2804-
PyArray_DTypeMeta *int_promoter_dtypes[] = {
2805-
&PyArray_StringDType,
2806-
(PyArray_DTypeMeta *)Py_None,
2807-
&PyArray_StringDType,
2808-
&PyArray_StringDType,
2809-
};
2810-
2811-
if (add_promoter(umath, center_ljust_rjust_names[i],
2812-
int_promoter_dtypes, 4,
2813-
string_multiply_promoter) < 0) {
2814-
return -1;
2815-
}
2816-
2817-
PyArray_DTypeMeta *unicode_promoter_dtypes[] = {
2818-
&PyArray_StringDType,
2819-
(PyArray_DTypeMeta *)Py_None,
2820-
&PyArray_UnicodeDType,
2821-
&PyArray_StringDType,
2848+
PyArray_DTypeMeta *promoter_dtypes[3][4] = {
2849+
{
2850+
&PyArray_StringDType,
2851+
&PyArray_IntAbstractDType,
2852+
&PyArray_StringDType,
2853+
&PyArray_StringDType,
2854+
},
2855+
{
2856+
&PyArray_StringDType,
2857+
&PyArray_IntAbstractDType,
2858+
&PyArray_UnicodeDType,
2859+
&PyArray_StringDType,
2860+
},
2861+
{
2862+
&PyArray_UnicodeDType,
2863+
&PyArray_IntAbstractDType,
2864+
&PyArray_StringDType,
2865+
&PyArray_StringDType,
2866+
},
28222867
};
28232868

2824-
if (add_promoter(umath, center_ljust_rjust_names[i],
2825-
unicode_promoter_dtypes, 4,
2826-
string_multiply_promoter) < 0) {
2827-
return -1;
2869+
for (int j=0; j<3; j++) {
2870+
if (add_promoter(umath, center_ljust_rjust_names[i],
2871+
promoter_dtypes[j], 4,
2872+
string_center_ljust_rjust_promoter) < 0) {
2873+
return -1;
2874+
}
28282875
}
28292876
}
28302877

@@ -2840,13 +2887,13 @@ init_stringdtype_ufuncs(PyObject *umath)
28402887
return -1;
28412888
}
28422889

2843-
PyArray_DTypeMeta *int_promoter_dtypes[] = {
2890+
PyArray_DTypeMeta *zfill_promoter_dtypes[] = {
28442891
&PyArray_StringDType,
2845-
(PyArray_DTypeMeta *)Py_None,
2892+
&PyArray_IntAbstractDType,
28462893
&PyArray_StringDType,
28472894
};
28482895

2849-
if (add_promoter(umath, "_zfill", int_promoter_dtypes, 3,
2896+
if (add_promoter(umath, "_zfill", zfill_promoter_dtypes, 3,
28502897
string_multiply_promoter) < 0) {
28512898
return -1;
28522899
}

0 commit comments

Comments
 (0)