|
13 | 13 | #include "string.h"
|
14 | 14 | #include "umath.h"
|
15 | 15 |
|
| 16 | +static NPY_CASTING |
| 17 | +multiply_resolve_descriptors(struct PyArrayMethodObject_tag *NPY_UNUSED(method), |
| 18 | + PyArray_DTypeMeta *NPY_UNUSED(dtypes[]), |
| 19 | + PyArray_Descr *given_descrs[], |
| 20 | + PyArray_Descr *loop_descrs[], |
| 21 | + npy_intp *NPY_UNUSED(view_offset)) |
| 22 | +{ |
| 23 | + Py_INCREF(given_descrs[0]); |
| 24 | + loop_descrs[0] = given_descrs[0]; |
| 25 | + Py_INCREF(given_descrs[1]); |
| 26 | + loop_descrs[1] = given_descrs[1]; |
| 27 | + Py_INCREF(given_descrs[0]); |
| 28 | + loop_descrs[2] = given_descrs[0]; |
| 29 | + |
| 30 | + return NPY_NO_CASTING; |
| 31 | +} |
| 32 | + |
| 33 | +static int |
| 34 | +multiply_strided_loop(PyArrayMethod_Context *NPY_UNUSED(context), |
| 35 | + char *const data[], npy_intp const dimensions[], |
| 36 | + npy_intp const strides[], |
| 37 | + NpyAuxData *NPY_UNUSED(auxdata)) |
| 38 | +{ |
| 39 | + npy_intp N = dimensions[0]; |
| 40 | + char *in1 = data[0]; |
| 41 | + char *in2 = data[1]; |
| 42 | + char *out = data[2]; |
| 43 | + npy_intp in1_stride = strides[0]; |
| 44 | + npy_intp in2_stride = strides[1]; |
| 45 | + npy_intp out_stride = strides[2]; |
| 46 | + |
| 47 | + ss *s1 = NULL, *os = NULL; |
| 48 | + |
| 49 | + while (N--) { |
| 50 | + s1 = (ss *)in1; |
| 51 | + npy_int64 factor = *(npy_int64 *)in2; |
| 52 | + os = (ss *)out; |
| 53 | + npy_int64 newlen = (s1->len) * factor; |
| 54 | + |
| 55 | + ssfree(os); |
| 56 | + if (ssnewemptylen(newlen, os) < 0) { |
| 57 | + return -1; |
| 58 | + } |
| 59 | + |
| 60 | + for (int i = 0; i < factor; i++) { |
| 61 | + memcpy(os->buf + i * s1->len, s1->buf, s1->len); |
| 62 | + } |
| 63 | + os->buf[newlen] = '\0'; |
| 64 | + |
| 65 | + in1 += in1_stride; |
| 66 | + in2 += in2_stride; |
| 67 | + out += out_stride; |
| 68 | + } |
| 69 | + return 0; |
| 70 | +} |
| 71 | + |
16 | 72 | static NPY_CASTING
|
17 | 73 | binary_resolve_descriptors(struct PyArrayMethodObject_tag *NPY_UNUSED(method),
|
18 | 74 | PyArray_DTypeMeta *NPY_UNUSED(dtypes[]),
|
@@ -733,6 +789,18 @@ init_ufuncs(void)
|
733 | 789 | goto error;
|
734 | 790 | }
|
735 | 791 |
|
| 792 | + PyArray_DTypeMeta *multiply_types[] = { |
| 793 | + (PyArray_DTypeMeta *)&StringDType, |
| 794 | + &PyArray_Int64DType, |
| 795 | + (PyArray_DTypeMeta *)&StringDType |
| 796 | + }; |
| 797 | + |
| 798 | + if (init_ufunc(numpy, "multiply", multiply_types, |
| 799 | + &multiply_resolve_descriptors, &multiply_strided_loop, |
| 800 | + "string_multiply", 2, 1, NPY_NO_CASTING, 0) < 0) { |
| 801 | + goto error; |
| 802 | + } |
| 803 | + |
736 | 804 | Py_DECREF(numpy);
|
737 | 805 | return 0;
|
738 | 806 |
|
|
0 commit comments