Skip to content

Commit 0bf9c46

Browse files
authored
Merge pull request numpy#26744 from seberg/promote-new-dtypes
2 parents a153fb2 + 1d1c0c0 commit 0bf9c46

File tree

6 files changed

+86
-34
lines changed

6 files changed

+86
-34
lines changed

numpy/_core/src/multiarray/abstractdtypes.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,8 @@ npy_mark_tmp_array_if_pyscalar(
4747
* a custom DType registered, and then we should use that.
4848
* Further, `np.float64` is a double subclass, so must reject it.
4949
*/
50+
// TODO,NOTE: This function should be changed to do exact long checks
51+
// For NumPy 2.1!
5052
if (PyLong_Check(obj)
5153
&& (PyArray_ISINTEGER(arr) || PyArray_ISOBJECT(arr))) {
5254
((PyArrayObject_fields *)arr)->flags |= NPY_ARRAY_WAS_PYTHON_INT;

numpy/_core/src/umath/dispatching.c

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747
#include "common.h"
4848
#include "npy_pycompat.h"
4949

50+
#include "arrayobject.h"
5051
#include "dispatching.h"
5152
#include "dtypemeta.h"
5253
#include "npy_hashtable.h"
@@ -64,7 +65,7 @@ promote_and_get_info_and_ufuncimpl(PyUFuncObject *ufunc,
6465
PyArrayObject *const ops[],
6566
PyArray_DTypeMeta *signature[],
6667
PyArray_DTypeMeta *op_dtypes[],
67-
npy_bool allow_legacy_promotion);
68+
npy_bool legacy_promotion_is_possible);
6869

6970

7071
/**
@@ -759,7 +760,7 @@ promote_and_get_info_and_ufuncimpl(PyUFuncObject *ufunc,
759760
PyArrayObject *const ops[],
760761
PyArray_DTypeMeta *signature[],
761762
PyArray_DTypeMeta *op_dtypes[],
762-
npy_bool allow_legacy_promotion)
763+
npy_bool legacy_promotion_is_possible)
763764
{
764765
/*
765766
* Fetch the dispatching info which consists of the implementation and
@@ -828,7 +829,7 @@ promote_and_get_info_and_ufuncimpl(PyUFuncObject *ufunc,
828829
* However, we need to give the legacy implementation a chance here.
829830
* (it will modify `op_dtypes`).
830831
*/
831-
if (!allow_legacy_promotion || ufunc->type_resolver == NULL ||
832+
if (!legacy_promotion_is_possible || ufunc->type_resolver == NULL ||
832833
(ufunc->ntypes == 0 && ufunc->userloops == NULL)) {
833834
/* Already tried or not a "legacy" ufunc (no loop found, return) */
834835
return NULL;
@@ -935,11 +936,11 @@ promote_and_get_ufuncimpl(PyUFuncObject *ufunc,
935936
PyArray_DTypeMeta *signature[],
936937
PyArray_DTypeMeta *op_dtypes[],
937938
npy_bool force_legacy_promotion,
938-
npy_bool allow_legacy_promotion,
939939
npy_bool promoting_pyscalars,
940940
npy_bool ensure_reduce_compatible)
941941
{
942942
int nin = ufunc->nin, nargs = ufunc->nargs;
943+
npy_bool legacy_promotion_is_possible = NPY_TRUE;
943944

944945
/*
945946
* Get the actual DTypes we operate with by setting op_dtypes[i] from
@@ -964,11 +965,20 @@ promote_and_get_ufuncimpl(PyUFuncObject *ufunc,
964965
*/
965966
Py_CLEAR(op_dtypes[i]);
966967
}
968+
/*
969+
* If the op_dtype ends up being a non-legacy one, then we cannot use
970+
* legacy promotion (unless this is a python scalar).
971+
*/
972+
if (op_dtypes[i] != NULL && !NPY_DT_is_legacy(op_dtypes[i]) && (
973+
signature[i] != NULL || // signature cannot be a pyscalar
974+
!(PyArray_FLAGS(ops[i]) & NPY_ARRAY_WAS_PYTHON_LITERAL))) {
975+
legacy_promotion_is_possible = NPY_FALSE;
976+
}
967977
}
968978

969979
int current_promotion_state = get_npy_promotion_state();
970980

971-
if (force_legacy_promotion
981+
if (force_legacy_promotion && legacy_promotion_is_possible
972982
&& current_promotion_state == NPY_USE_LEGACY_PROMOTION
973983
&& (ufunc->ntypes != 0 || ufunc->userloops != NULL)) {
974984
/*
@@ -986,7 +996,7 @@ promote_and_get_ufuncimpl(PyUFuncObject *ufunc,
986996
/* Pause warnings and always use "new" path */
987997
set_npy_promotion_state(NPY_USE_WEAK_PROMOTION);
988998
PyObject *info = promote_and_get_info_and_ufuncimpl(ufunc,
989-
ops, signature, op_dtypes, allow_legacy_promotion);
999+
ops, signature, op_dtypes, legacy_promotion_is_possible);
9901000
set_npy_promotion_state(current_promotion_state);
9911001

9921002
if (info == NULL) {
@@ -1032,7 +1042,7 @@ promote_and_get_ufuncimpl(PyUFuncObject *ufunc,
10321042
Py_INCREF(signature[0]);
10331043
return promote_and_get_ufuncimpl(ufunc,
10341044
ops, signature, op_dtypes,
1035-
force_legacy_promotion, allow_legacy_promotion,
1045+
force_legacy_promotion,
10361046
promoting_pyscalars, NPY_FALSE);
10371047
}
10381048

numpy/_core/src/umath/dispatching.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@ promote_and_get_ufuncimpl(PyUFuncObject *ufunc,
2222
PyArray_DTypeMeta *signature[],
2323
PyArray_DTypeMeta *op_dtypes[],
2424
npy_bool force_legacy_promotion,
25-
npy_bool allow_legacy_promotion,
2625
npy_bool promote_pyscalars,
2726
npy_bool ensure_reduce_compatible);
2827

numpy/_core/src/umath/stringdtype_ufuncs.cpp

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1028,6 +1028,25 @@ all_strings_promoter(PyObject *NPY_UNUSED(ufunc),
10281028
PyArray_DTypeMeta *const signature[],
10291029
PyArray_DTypeMeta *new_op_dtypes[])
10301030
{
1031+
if ((op_dtypes[0] != &PyArray_StringDType &&
1032+
op_dtypes[1] != &PyArray_StringDType &&
1033+
op_dtypes[2] != &PyArray_StringDType)) {
1034+
/*
1035+
* This promoter was triggered with only unicode arguments, so use
1036+
* unicode. This can happen due to `dtype=` support which sets the
1037+
* output DType/signature.
1038+
*/
1039+
new_op_dtypes[0] = NPY_DT_NewRef(&PyArray_UnicodeDType);
1040+
new_op_dtypes[1] = NPY_DT_NewRef(&PyArray_UnicodeDType);
1041+
new_op_dtypes[2] = NPY_DT_NewRef(&PyArray_UnicodeDType);
1042+
return 0;
1043+
}
1044+
if ((signature[0] == &PyArray_UnicodeDType &&
1045+
signature[1] == &PyArray_UnicodeDType &&
1046+
signature[2] == &PyArray_UnicodeDType)) {
1047+
/* Unicode forced, but didn't override a string input: invalid */
1048+
return -1;
1049+
}
10311050
new_op_dtypes[0] = NPY_DT_NewRef(&PyArray_StringDType);
10321051
new_op_dtypes[1] = NPY_DT_NewRef(&PyArray_StringDType);
10331052
new_op_dtypes[2] = NPY_DT_NewRef(&PyArray_StringDType);
@@ -2532,6 +2551,17 @@ init_stringdtype_ufuncs(PyObject *umath)
25322551
return -1;
25332552
}
25342553

2554+
PyArray_DTypeMeta *out_strings_promoter_dtypes[] = {
2555+
&PyArray_UnicodeDType,
2556+
&PyArray_UnicodeDType,
2557+
&PyArray_StringDType,
2558+
};
2559+
2560+
if (add_promoter(umath, "add", out_strings_promoter_dtypes, 3,
2561+
all_strings_promoter) < 0) {
2562+
return -1;
2563+
}
2564+
25352565
INIT_MULTIPLY(Int64, int64);
25362566
INIT_MULTIPLY(UInt64, uint64);
25372567

numpy/_core/src/umath/ufunc_object.c

Lines changed: 12 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -606,7 +606,7 @@ static int
606606
convert_ufunc_arguments(PyUFuncObject *ufunc,
607607
ufunc_full_args full_args, PyArrayObject *out_op[],
608608
PyArray_DTypeMeta *out_op_DTypes[],
609-
npy_bool *force_legacy_promotion, npy_bool *allow_legacy_promotion,
609+
npy_bool *force_legacy_promotion,
610610
npy_bool *promoting_pyscalars,
611611
PyObject *order_obj, NPY_ORDER *out_order,
612612
PyObject *casting_obj, NPY_CASTING *out_casting,
@@ -622,7 +622,6 @@ convert_ufunc_arguments(PyUFuncObject *ufunc,
622622
/* Convert and fill in input arguments */
623623
npy_bool all_scalar = NPY_TRUE;
624624
npy_bool any_scalar = NPY_FALSE;
625-
*allow_legacy_promotion = NPY_TRUE;
626625
*force_legacy_promotion = NPY_FALSE;
627626
*promoting_pyscalars = NPY_FALSE;
628627
for (int i = 0; i < nin; i++) {
@@ -657,11 +656,6 @@ convert_ufunc_arguments(PyUFuncObject *ufunc,
657656
break;
658657
}
659658

660-
if (!NPY_DT_is_legacy(out_op_DTypes[i])) {
661-
*allow_legacy_promotion = NPY_FALSE;
662-
// TODO: A subclass of int, float, complex could reach here and
663-
// it should not be flagged as "weak" if it does.
664-
}
665659
if (PyArray_NDIM(out_op[i]) == 0) {
666660
any_scalar = NPY_TRUE;
667661
}
@@ -707,7 +701,7 @@ convert_ufunc_arguments(PyUFuncObject *ufunc,
707701
*promoting_pyscalars = NPY_TRUE;
708702
}
709703
}
710-
if (*allow_legacy_promotion && (!all_scalar && any_scalar)) {
704+
if ((!all_scalar && any_scalar)) {
711705
*force_legacy_promotion = should_use_min_scalar(nin, out_op, 0, NULL);
712706
}
713707

@@ -2351,8 +2345,7 @@ reducelike_promote_and_resolve(PyUFuncObject *ufunc,
23512345
}
23522346

23532347
PyArrayMethodObject *ufuncimpl = promote_and_get_ufuncimpl(ufunc,
2354-
ops, signature, operation_DTypes, NPY_FALSE, NPY_TRUE,
2355-
NPY_FALSE, NPY_TRUE);
2348+
ops, signature, operation_DTypes, NPY_FALSE, NPY_FALSE, NPY_TRUE);
23562349
if (evil_ndim_mutating_hack) {
23572350
((PyArrayObject_fields *)out)->nd = 0;
23582351
}
@@ -4433,13 +4426,12 @@ ufunc_generic_fastcall(PyUFuncObject *ufunc,
44334426
npy_bool subok = NPY_TRUE;
44344427
int keepdims = -1; /* We need to know if it was passed */
44354428
npy_bool force_legacy_promotion;
4436-
npy_bool allow_legacy_promotion;
44374429
npy_bool promoting_pyscalars;
44384430
if (convert_ufunc_arguments(ufunc,
44394431
/* extract operand related information: */
44404432
full_args, operands,
44414433
operand_DTypes,
4442-
&force_legacy_promotion, &allow_legacy_promotion,
4434+
&force_legacy_promotion,
44434435
&promoting_pyscalars,
44444436
/* extract general information: */
44454437
order_obj, &order,
@@ -4460,7 +4452,7 @@ ufunc_generic_fastcall(PyUFuncObject *ufunc,
44604452
*/
44614453
PyArrayMethodObject *ufuncimpl = promote_and_get_ufuncimpl(ufunc,
44624454
operands, signature,
4463-
operand_DTypes, force_legacy_promotion, allow_legacy_promotion,
4455+
operand_DTypes, force_legacy_promotion,
44644456
promoting_pyscalars, NPY_FALSE);
44654457
if (ufuncimpl == NULL) {
44664458
goto fail;
@@ -5790,22 +5782,20 @@ ufunc_at(PyUFuncObject *ufunc, PyObject *args)
57905782
operand_DTypes[0] = NPY_DTYPE(PyArray_DESCR(op1_array));
57915783
Py_INCREF(operand_DTypes[0]);
57925784
int force_legacy_promotion = 0;
5793-
int allow_legacy_promotion = NPY_DT_is_legacy(operand_DTypes[0]);
57945785

57955786
if (op2_array != NULL) {
57965787
tmp_operands[1] = op2_array;
57975788
operand_DTypes[1] = NPY_DTYPE(PyArray_DESCR(op2_array));
57985789
Py_INCREF(operand_DTypes[1]);
5799-
allow_legacy_promotion &= NPY_DT_is_legacy(operand_DTypes[1]);
58005790
tmp_operands[2] = tmp_operands[0];
58015791
operand_DTypes[2] = operand_DTypes[0];
58025792
Py_INCREF(operand_DTypes[2]);
58035793

5804-
if (allow_legacy_promotion && ((PyArray_NDIM(op1_array) == 0)
5805-
!= (PyArray_NDIM(op2_array) == 0))) {
5806-
/* both are legacy and only one is 0-D: force legacy */
5807-
force_legacy_promotion = should_use_min_scalar(2, tmp_operands, 0, NULL);
5808-
}
5794+
if ((PyArray_NDIM(op1_array) == 0)
5795+
!= (PyArray_NDIM(op2_array) == 0)) {
5796+
/* both are legacy and only one is 0-D: force legacy */
5797+
force_legacy_promotion = should_use_min_scalar(2, tmp_operands, 0, NULL);
5798+
}
58095799
}
58105800
else {
58115801
tmp_operands[1] = tmp_operands[0];
@@ -5816,7 +5806,7 @@ ufunc_at(PyUFuncObject *ufunc, PyObject *args)
58165806

58175807
ufuncimpl = promote_and_get_ufuncimpl(ufunc, tmp_operands, signature,
58185808
operand_DTypes, force_legacy_promotion,
5819-
allow_legacy_promotion, NPY_FALSE, NPY_FALSE);
5809+
NPY_FALSE, NPY_FALSE);
58205810
if (ufuncimpl == NULL) {
58215811
for (int i = 0; i < 3; i++) {
58225812
Py_XDECREF(signature[i]);
@@ -6058,7 +6048,6 @@ py_resolve_dtypes_generic(PyUFuncObject *ufunc, npy_bool return_context,
60586048
set_npy_promotion_state(NPY_USE_WEAK_PROMOTION);
60596049

60606050
npy_bool promoting_pyscalars = NPY_FALSE;
6061-
npy_bool allow_legacy_promotion = NPY_TRUE;
60626051

60636052
if (_get_fixed_signature(ufunc, NULL, signature_obj, signature) < 0) {
60646053
goto finish;
@@ -6091,9 +6080,6 @@ py_resolve_dtypes_generic(PyUFuncObject *ufunc, npy_bool return_context,
60916080
}
60926081
DTypes[i] = NPY_DTYPE(descr);
60936082
Py_INCREF(DTypes[i]);
6094-
if (!NPY_DT_is_legacy(DTypes[i])) {
6095-
allow_legacy_promotion = NPY_FALSE;
6096-
}
60976083
}
60986084
/* Explicitly allow int, float, and complex for the "weak" types. */
60996085
else if (descr_obj == (PyObject *)&PyLong_Type) {
@@ -6149,7 +6135,7 @@ py_resolve_dtypes_generic(PyUFuncObject *ufunc, npy_bool return_context,
61496135
if (!reduction) {
61506136
ufuncimpl = promote_and_get_ufuncimpl(ufunc,
61516137
dummy_arrays, signature, DTypes, NPY_FALSE,
6152-
allow_legacy_promotion, promoting_pyscalars, NPY_FALSE);
6138+
promoting_pyscalars, NPY_FALSE);
61536139
if (ufuncimpl == NULL) {
61546140
goto finish;
61556141
}

numpy/_core/tests/test_stringdtype.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -828,6 +828,31 @@ def test_add_promoter(string_list):
828828
assert_array_equal(op + arr, lresult)
829829
assert_array_equal(arr + op, rresult)
830830

831+
# The promoter should be able to handle things if users pass `dtype=`
832+
res = np.add("hello", string_list, dtype=StringDType)
833+
assert res.dtype == StringDType()
834+
835+
# The promoter should not kick in if users override the input,
836+
# which means arr is cast, this fails because of the unknown length.
837+
with pytest.raises(TypeError, match="cannot cast dtype"):
838+
np.add(arr, "add", signature=("U", "U", None), casting="unsafe")
839+
840+
# But it must simply reject the following:
841+
with pytest.raises(TypeError, match=".*did not contain a loop"):
842+
np.add(arr, "add", signature=(None, "U", None))
843+
844+
with pytest.raises(TypeError, match=".*did not contain a loop"):
845+
np.add("a", "b", signature=("U", "U", StringDType))
846+
847+
848+
def test_add_no_legacy_promote_with_signature():
849+
# Possibly misplaced, but useful to test with string DType. We check that
850+
# if there is clearly no loop found, a stray `dtype=` doesn't break things
851+
# Regression test for the bad error in gh-26735
852+
# (If legacy promotion is gone, this can be deleted...)
853+
with pytest.raises(TypeError, match=".*did not contain a loop"):
854+
np.add("3", 6, dtype=StringDType)
855+
831856

832857
def test_add_promoter_reduce():
833858
# Exact TypeError could change, but ensure StringDtype doesn't match

0 commit comments

Comments
 (0)