Skip to content

Commit 9312d5e

Browse files
authored
Merge pull request numpy#27164 from seberg/einsum-argparse
MAINT: use npy_argparse for einsum
2 parents 32a2304 + bbf0ff4 commit 9312d5e

File tree

1 file changed

+51
-96
lines changed

1 file changed

+51
-96
lines changed

numpy/_core/src/multiarray/multiarraymodule.c

Lines changed: 51 additions & 96 deletions
Original file line numberDiff line numberDiff line change
@@ -2704,13 +2704,13 @@ array_vdot(PyObject *NPY_UNUSED(dummy), PyObject *const *args, Py_ssize_t len_ar
27042704
}
27052705

27062706
static int
2707-
einsum_sub_op_from_str(PyObject *args, PyObject **str_obj, char **subscripts,
2708-
PyArrayObject **op)
2707+
einsum_sub_op_from_str(
2708+
Py_ssize_t nargs, PyObject *const *args,
2709+
PyObject **str_obj, char **subscripts, PyArrayObject **op)
27092710
{
2710-
int i, nop;
2711+
Py_ssize_t nop = nargs - 1;
27112712
PyObject *subscripts_str;
27122713

2713-
nop = PyTuple_GET_SIZE(args) - 1;
27142714
if (nop <= 0) {
27152715
PyErr_SetString(PyExc_ValueError,
27162716
"must specify the einstein sum subscripts string "
@@ -2723,7 +2723,7 @@ einsum_sub_op_from_str(PyObject *args, PyObject **str_obj, char **subscripts,
27232723
}
27242724

27252725
/* Get the subscripts string */
2726-
subscripts_str = PyTuple_GET_ITEM(args, 0);
2726+
subscripts_str = args[0];
27272727
if (PyUnicode_Check(subscripts_str)) {
27282728
*str_obj = PyUnicode_AsASCIIString(subscripts_str);
27292729
if (*str_obj == NULL) {
@@ -2740,15 +2740,13 @@ einsum_sub_op_from_str(PyObject *args, PyObject **str_obj, char **subscripts,
27402740
}
27412741

27422742
/* Set the operands to NULL */
2743-
for (i = 0; i < nop; ++i) {
2743+
for (Py_ssize_t i = 0; i < nop; ++i) {
27442744
op[i] = NULL;
27452745
}
27462746

27472747
/* Get the operands */
2748-
for (i = 0; i < nop; ++i) {
2749-
PyObject *obj = PyTuple_GET_ITEM(args, i+1);
2750-
2751-
op[i] = (PyArrayObject *)PyArray_FROM_OF(obj, NPY_ARRAY_ENSUREARRAY);
2748+
for (Py_ssize_t i = 0; i < nop; ++i) {
2749+
op[i] = (PyArrayObject *)PyArray_FROM_OF(args[i+1], NPY_ARRAY_ENSUREARRAY);
27522750
if (op[i] == NULL) {
27532751
goto fail;
27542752
}
@@ -2757,7 +2755,7 @@ einsum_sub_op_from_str(PyObject *args, PyObject **str_obj, char **subscripts,
27572755
return nop;
27582756

27592757
fail:
2760-
for (i = 0; i < nop; ++i) {
2758+
for (Py_ssize_t i = 0; i < nop; ++i) {
27612759
Py_XDECREF(op[i]);
27622760
op[i] = NULL;
27632761
}
@@ -2861,13 +2859,12 @@ einsum_list_to_subscripts(PyObject *obj, char *subscripts, int subsize)
28612859
* Returns -1 on error, number of operands placed in op otherwise.
28622860
*/
28632861
static int
2864-
einsum_sub_op_from_lists(PyObject *args,
2865-
char *subscripts, int subsize, PyArrayObject **op)
2862+
einsum_sub_op_from_lists(Py_ssize_t nargs, PyObject *const *args,
2863+
char *subscripts, int subsize, PyArrayObject **op)
28662864
{
28672865
int subindex = 0;
2868-
npy_intp i, nop;
28692866

2870-
nop = PyTuple_Size(args)/2;
2867+
Py_ssize_t nop = nargs / 2;
28712868

28722869
if (nop == 0) {
28732870
PyErr_SetString(PyExc_ValueError, "must provide at least an "
@@ -2880,15 +2877,12 @@ einsum_sub_op_from_lists(PyObject *args,
28802877
}
28812878

28822879
/* Set the operands to NULL */
2883-
for (i = 0; i < nop; ++i) {
2880+
for (Py_ssize_t i = 0; i < nop; ++i) {
28842881
op[i] = NULL;
28852882
}
28862883

28872884
/* Get the operands and build the subscript string */
2888-
for (i = 0; i < nop; ++i) {
2889-
PyObject *obj = PyTuple_GET_ITEM(args, 2*i);
2890-
int n;
2891-
2885+
for (Py_ssize_t i = 0; i < nop; ++i) {
28922886
/* Comma between the subscripts for each operand */
28932887
if (i != 0) {
28942888
subscripts[subindex++] = ',';
@@ -2899,25 +2893,21 @@ einsum_sub_op_from_lists(PyObject *args,
28992893
}
29002894
}
29012895

2902-
op[i] = (PyArrayObject *)PyArray_FROM_OF(obj, NPY_ARRAY_ENSUREARRAY);
2896+
op[i] = (PyArrayObject *)PyArray_FROM_OF(args[2*i], NPY_ARRAY_ENSUREARRAY);
29032897
if (op[i] == NULL) {
29042898
goto fail;
29052899
}
29062900

2907-
obj = PyTuple_GET_ITEM(args, 2*i+1);
2908-
n = einsum_list_to_subscripts(obj, subscripts+subindex,
2909-
subsize-subindex);
2901+
int n = einsum_list_to_subscripts(
2902+
args[2*i + 1], subscripts+subindex, subsize-subindex);
29102903
if (n < 0) {
29112904
goto fail;
29122905
}
29132906
subindex += n;
29142907
}
29152908

29162909
/* Add the '->' to the string if provided */
2917-
if (PyTuple_Size(args) == 2*nop+1) {
2918-
PyObject *obj;
2919-
int n;
2920-
2910+
if (nargs == 2*nop+1) {
29212911
if (subindex + 2 >= subsize) {
29222912
PyErr_SetString(PyExc_ValueError,
29232913
"subscripts list is too long");
@@ -2926,9 +2916,8 @@ einsum_sub_op_from_lists(PyObject *args,
29262916
subscripts[subindex++] = '-';
29272917
subscripts[subindex++] = '>';
29282918

2929-
obj = PyTuple_GET_ITEM(args, 2*nop);
2930-
n = einsum_list_to_subscripts(obj, subscripts+subindex,
2931-
subsize-subindex);
2919+
int n = einsum_list_to_subscripts(
2920+
args[2*nop], subscripts+subindex, subsize-subindex);
29322921
if (n < 0) {
29332922
goto fail;
29342923
}
@@ -2941,7 +2930,7 @@ einsum_sub_op_from_lists(PyObject *args,
29412930
return nop;
29422931

29432932
fail:
2944-
for (i = 0; i < nop; ++i) {
2933+
for (Py_ssize_t i = 0; i < nop; ++i) {
29452934
Py_XDECREF(op[i]);
29462935
op[i] = NULL;
29472936
}
@@ -2950,108 +2939,74 @@ einsum_sub_op_from_lists(PyObject *args,
29502939
}
29512940

29522941
static PyObject *
2953-
array_einsum(PyObject *NPY_UNUSED(dummy), PyObject *args, PyObject *kwds)
2942+
array_einsum(PyObject *NPY_UNUSED(dummy),
2943+
PyObject *const *args, Py_ssize_t nargsf, PyObject *kwnames)
29542944
{
29552945
char *subscripts = NULL, subscripts_buffer[256];
29562946
PyObject *str_obj = NULL, *str_key_obj = NULL;
2957-
PyObject *arg0;
2958-
int i, nop;
2947+
int nop;
29592948
PyArrayObject *op[NPY_MAXARGS];
29602949
NPY_ORDER order = NPY_KEEPORDER;
29612950
NPY_CASTING casting = NPY_SAFE_CASTING;
2951+
PyObject *out_obj = NULL;
29622952
PyArrayObject *out = NULL;
29632953
PyArray_Descr *dtype = NULL;
29642954
PyObject *ret = NULL;
2955+
NPY_PREPARE_ARGPARSER;
2956+
2957+
Py_ssize_t nargs = PyVectorcall_NARGS(nargsf);
29652958

2966-
if (PyTuple_GET_SIZE(args) < 1) {
2959+
if (nargs < 1) {
29672960
PyErr_SetString(PyExc_ValueError,
29682961
"must specify the einstein sum subscripts string "
29692962
"and at least one operand, or at least one operand "
29702963
"and its corresponding subscripts list");
29712964
return NULL;
29722965
}
2973-
arg0 = PyTuple_GET_ITEM(args, 0);
29742966

29752967
/* einsum('i,j', a, b), einsum('i,j->ij', a, b) */
2976-
if (PyBytes_Check(arg0) || PyUnicode_Check(arg0)) {
2977-
nop = einsum_sub_op_from_str(args, &str_obj, &subscripts, op);
2968+
if (PyBytes_Check(args[0]) || PyUnicode_Check(args[0])) {
2969+
nop = einsum_sub_op_from_str(nargs, args, &str_obj, &subscripts, op);
29782970
}
29792971
/* einsum(a, [0], b, [1]), einsum(a, [0], b, [1], [0,1]) */
29802972
else {
2981-
nop = einsum_sub_op_from_lists(args, subscripts_buffer,
2982-
sizeof(subscripts_buffer), op);
2973+
nop = einsum_sub_op_from_lists(nargs, args, subscripts_buffer,
2974+
sizeof(subscripts_buffer), op);
29832975
subscripts = subscripts_buffer;
29842976
}
29852977
if (nop <= 0) {
29862978
goto finish;
29872979
}
29882980

29892981
/* Get the keyword arguments */
2990-
if (kwds != NULL) {
2991-
PyObject *key, *value;
2992-
Py_ssize_t pos = 0;
2993-
while (PyDict_Next(kwds, &pos, &key, &value)) {
2994-
char *str = NULL;
2995-
2996-
Py_XDECREF(str_key_obj);
2997-
str_key_obj = PyUnicode_AsASCIIString(key);
2998-
if (str_key_obj != NULL) {
2999-
key = str_key_obj;
3000-
}
3001-
3002-
str = PyBytes_AsString(key);
3003-
3004-
if (str == NULL) {
3005-
PyErr_Clear();
3006-
PyErr_SetString(PyExc_TypeError, "invalid keyword");
3007-
goto finish;
3008-
}
3009-
3010-
if (strcmp(str,"out") == 0) {
3011-
if (PyArray_Check(value)) {
3012-
out = (PyArrayObject *)value;
3013-
}
3014-
else {
3015-
PyErr_SetString(PyExc_TypeError,
3016-
"keyword parameter out must be an "
3017-
"array for einsum");
3018-
goto finish;
3019-
}
3020-
}
3021-
else if (strcmp(str,"order") == 0) {
3022-
if (!PyArray_OrderConverter(value, &order)) {
3023-
goto finish;
3024-
}
3025-
}
3026-
else if (strcmp(str,"casting") == 0) {
3027-
if (!PyArray_CastingConverter(value, &casting)) {
3028-
goto finish;
3029-
}
3030-
}
3031-
else if (strcmp(str,"dtype") == 0) {
3032-
if (!PyArray_DescrConverter2(value, &dtype)) {
3033-
goto finish;
3034-
}
3035-
}
3036-
else {
3037-
PyErr_Format(PyExc_TypeError,
3038-
"'%s' is an invalid keyword for einsum",
3039-
str);
3040-
goto finish;
3041-
}
2982+
if (kwnames != NULL) {
2983+
if (npy_parse_arguments("einsum", args+nargs, 0, kwnames,
2984+
"$out", NULL, &out_obj,
2985+
"$order", &PyArray_OrderConverter, &order,
2986+
"$casting", &PyArray_CastingConverter, &casting,
2987+
"$dtype", &PyArray_DescrConverter2, &dtype,
2988+
NULL, NULL, NULL) < 0) {
2989+
goto finish;
30422990
}
2991+
if (out_obj != NULL && !PyArray_Check(out_obj)) {
2992+
PyErr_SetString(PyExc_TypeError,
2993+
"keyword parameter out must be an "
2994+
"array for einsum");
2995+
goto finish;
2996+
}
2997+
out = (PyArrayObject *)out_obj;
30432998
}
30442999

30453000
ret = (PyObject *)PyArray_EinsteinSum(subscripts, nop, op, dtype,
3046-
order, casting, out);
3001+
order, casting, out);
30473002

30483003
/* If no output was supplied, possibly convert to a scalar */
30493004
if (ret != NULL && out == NULL) {
30503005
ret = PyArray_Return((PyArrayObject *)ret);
30513006
}
30523007

30533008
finish:
3054-
for (i = 0; i < nop; ++i) {
3009+
for (Py_ssize_t i = 0; i < nop; ++i) {
30553010
Py_XDECREF(op[i]);
30563011
}
30573012
Py_XDECREF(dtype);
@@ -4518,7 +4473,7 @@ static struct PyMethodDef array_module_methods[] = {
45184473
METH_FASTCALL, NULL},
45194474
{"c_einsum",
45204475
(PyCFunction)array_einsum,
4521-
METH_VARARGS|METH_KEYWORDS, NULL},
4476+
METH_FASTCALL|METH_KEYWORDS, NULL},
45224477
{"correlate",
45234478
(PyCFunction)array_correlate,
45244479
METH_FASTCALL | METH_KEYWORDS, NULL},

0 commit comments

Comments
 (0)