|
13 | 13 | #include "string.h"
|
14 | 14 | #include "umath.h"
|
15 | 15 |
|
| 16 | +static NPY_CASTING |
| 17 | +multiply_resolve_descriptors( |
| 18 | + struct PyArrayMethodObject_tag *NPY_UNUSED(method), |
| 19 | + PyArray_DTypeMeta *dtypes[], PyArray_Descr *given_descrs[], |
| 20 | + PyArray_Descr *loop_descrs[], npy_intp *NPY_UNUSED(view_offset)) |
| 21 | +{ |
| 22 | + PyArray_Descr *ldescr = given_descrs[0]; |
| 23 | + PyArray_Descr *rdescr = given_descrs[1]; |
| 24 | + Py_INCREF(ldescr); |
| 25 | + loop_descrs[0] = ldescr; |
| 26 | + Py_INCREF(rdescr); |
| 27 | + loop_descrs[1] = rdescr; |
| 28 | + |
| 29 | + PyArray_Descr *odescr = NULL; |
| 30 | + |
| 31 | + if (dtypes[0] == (PyArray_DTypeMeta *)&StringDType) { |
| 32 | + odescr = ldescr; |
| 33 | + } |
| 34 | + else { |
| 35 | + odescr = rdescr; |
| 36 | + } |
| 37 | + |
| 38 | + loop_descrs[2] = (PyArray_Descr *)new_stringdtype_instance( |
| 39 | + ((StringDTypeObject *)odescr)->na_object); |
| 40 | + |
| 41 | + return NPY_NO_CASTING; |
| 42 | +} |
| 43 | + |
| 44 | +#define MULTIPLY_IMPL(shortname) \ |
| 45 | + static int multiply_loop_core_##shortname( \ |
| 46 | + npy_intp N, char *sin, char *iin, char *out, npy_intp s_stride, \ |
| 47 | + npy_intp i_stride, npy_intp o_stride) \ |
| 48 | + { \ |
| 49 | + ss *is = NULL, *os = NULL; \ |
| 50 | + \ |
| 51 | + while (N--) { \ |
| 52 | + is = (ss *)sin; \ |
| 53 | + npy_##shortname factor = *(npy_##shortname *)iin; \ |
| 54 | + os = (ss *)out; \ |
| 55 | + size_t newlen = (size_t)((is->len) * factor); \ |
| 56 | + \ |
| 57 | + ssfree(os); \ |
| 58 | + if (ssnewemptylen(newlen, os) < 0) { \ |
| 59 | + return -1; \ |
| 60 | + } \ |
| 61 | + \ |
| 62 | + for (size_t i = 0; i < (size_t)factor; i++) { \ |
| 63 | + memcpy(os->buf + i * is->len, is->buf, is->len); \ |
| 64 | + } \ |
| 65 | + os->buf[newlen] = '\0'; \ |
| 66 | + \ |
| 67 | + sin += s_stride; \ |
| 68 | + iin += i_stride; \ |
| 69 | + out += o_stride; \ |
| 70 | + } \ |
| 71 | + return 0; \ |
| 72 | + } \ |
| 73 | + \ |
| 74 | + static int multiply_right_##shortname##_strided_loop( \ |
| 75 | + PyArrayMethod_Context *NPY_UNUSED(context), char *const data[], \ |
| 76 | + npy_intp const dimensions[], npy_intp const strides[], \ |
| 77 | + NpyAuxData *NPY_UNUSED(auxdata)) \ |
| 78 | + { \ |
| 79 | + npy_intp N = dimensions[0]; \ |
| 80 | + char *in1 = data[0]; \ |
| 81 | + char *in2 = data[1]; \ |
| 82 | + char *out = data[2]; \ |
| 83 | + npy_intp in1_stride = strides[0]; \ |
| 84 | + npy_intp in2_stride = strides[1]; \ |
| 85 | + npy_intp out_stride = strides[2]; \ |
| 86 | + \ |
| 87 | + return multiply_loop_core_##shortname(N, in1, in2, out, in1_stride, \ |
| 88 | + in2_stride, out_stride); \ |
| 89 | + } \ |
| 90 | + \ |
| 91 | + static int multiply_left_##shortname##_strided_loop( \ |
| 92 | + PyArrayMethod_Context *NPY_UNUSED(context), char *const data[], \ |
| 93 | + npy_intp const dimensions[], npy_intp const strides[], \ |
| 94 | + NpyAuxData *NPY_UNUSED(auxdata)) \ |
| 95 | + { \ |
| 96 | + npy_intp N = dimensions[0]; \ |
| 97 | + char *in1 = data[0]; \ |
| 98 | + char *in2 = data[1]; \ |
| 99 | + char *out = data[2]; \ |
| 100 | + npy_intp in1_stride = strides[0]; \ |
| 101 | + npy_intp in2_stride = strides[1]; \ |
| 102 | + npy_intp out_stride = strides[2]; \ |
| 103 | + \ |
| 104 | + return multiply_loop_core_##shortname(N, in2, in1, out, in2_stride, \ |
| 105 | + in1_stride, out_stride); \ |
| 106 | + } |
| 107 | + |
| 108 | +MULTIPLY_IMPL(int8); |
| 109 | +MULTIPLY_IMPL(int16); |
| 110 | +MULTIPLY_IMPL(int32); |
| 111 | +MULTIPLY_IMPL(int64); |
| 112 | +MULTIPLY_IMPL(uint8); |
| 113 | +MULTIPLY_IMPL(uint16); |
| 114 | +MULTIPLY_IMPL(uint32); |
| 115 | +MULTIPLY_IMPL(uint64); |
| 116 | +#if NPY_SIZEOF_BYTE == NPY_SIZEOF_SHORT |
| 117 | +MULTIPLY_IMPL(byte); |
| 118 | +MULTIPLY_IMPL(ubyte); |
| 119 | +#endif |
| 120 | +#if NPY_SIZEOF_SHORT == NPY_SIZEOF_INT |
| 121 | +MULTIPLY_IMPL(short); |
| 122 | +MULTIPLY_IMPL(ushort); |
| 123 | +#endif |
| 124 | +#if NPY_SIZEOF_INT == NPY_SIZEOF_LONG |
| 125 | +MULTIPLY_IMPL(long); |
| 126 | +MULTIPLY_IMPL(ulong); |
| 127 | +#endif |
| 128 | +#if NPY_SIZEOF_LONGLONG == NPY_SIZEOF_LONG |
| 129 | +MULTIPLY_IMPL(longlong); |
| 130 | +MULTIPLY_IMPL(ulonglong); |
| 131 | +#endif |
| 132 | + |
16 | 133 | static NPY_CASTING
|
17 | 134 | binary_resolve_descriptors(struct PyArrayMethodObject_tag *NPY_UNUSED(method),
|
18 | 135 | PyArray_DTypeMeta *NPY_UNUSED(dtypes[]),
|
@@ -611,6 +728,29 @@ add_promoter(PyObject *numpy, const char *ufunc_name,
|
611 | 728 | return 0;
|
612 | 729 | }
|
613 | 730 |
|
| 731 | +#define INIT_MULTIPLY(typename, shortname) \ |
| 732 | + PyArray_DTypeMeta *multiply_right_##shortname##_types[] = { \ |
| 733 | + (PyArray_DTypeMeta *)&StringDType, &PyArray_##typename##DType, \ |
| 734 | + (PyArray_DTypeMeta *)&StringDType}; \ |
| 735 | + \ |
| 736 | + if (init_ufunc(numpy, "multiply", multiply_right_##shortname##_types, \ |
| 737 | + &multiply_resolve_descriptors, \ |
| 738 | + &multiply_right_##shortname##_strided_loop, \ |
| 739 | + "string_multiply", 2, 1, NPY_NO_CASTING, 0) < 0) { \ |
| 740 | + goto error; \ |
| 741 | + } \ |
| 742 | + \ |
| 743 | + PyArray_DTypeMeta *multiply_left_##shortname##_types[] = { \ |
| 744 | + &PyArray_##typename##DType, (PyArray_DTypeMeta *)&StringDType, \ |
| 745 | + (PyArray_DTypeMeta *)&StringDType}; \ |
| 746 | + \ |
| 747 | + if (init_ufunc(numpy, "multiply", multiply_left_##shortname##_types, \ |
| 748 | + &multiply_resolve_descriptors, \ |
| 749 | + &multiply_left_##shortname##_strided_loop, \ |
| 750 | + "string_multiply", 2, 1, NPY_NO_CASTING, 0) < 0) { \ |
| 751 | + goto error; \ |
| 752 | + } |
| 753 | + |
614 | 754 | int
|
615 | 755 | init_ufuncs(void)
|
616 | 756 | {
|
@@ -733,6 +873,31 @@ init_ufuncs(void)
|
733 | 873 | goto error;
|
734 | 874 | }
|
735 | 875 |
|
| 876 | + INIT_MULTIPLY(Int8, int8); |
| 877 | + INIT_MULTIPLY(Int16, int16); |
| 878 | + INIT_MULTIPLY(Int32, int32); |
| 879 | + INIT_MULTIPLY(Int64, int64); |
| 880 | + INIT_MULTIPLY(UInt8, uint8); |
| 881 | + INIT_MULTIPLY(UInt16, uint16); |
| 882 | + INIT_MULTIPLY(UInt32, uint32); |
| 883 | + INIT_MULTIPLY(UInt64, uint64); |
| 884 | +#if NPY_SIZEOF_BYTE == NPY_SIZEOF_SHORT |
| 885 | + INIT_MULTIPLY(Byte, byte); |
| 886 | + INIT_MULTIPLY(UByte, ubyte); |
| 887 | +#endif |
| 888 | +#if NPY_SIZEOF_SHORT == NPY_SIZEOF_INT |
| 889 | + INIT_MULTIPLY(Short, short); |
| 890 | + INIT_MULTIPLY(UShort, ushort); |
| 891 | +#endif |
| 892 | +#if NPY_SIZEOF_INT == NPY_SIZEOF_LONG |
| 893 | + INIT_MULTIPLY(Long, long); |
| 894 | + INIT_MULTIPLY(ULong, ulong); |
| 895 | +#endif |
| 896 | +#if NPY_SIZEOF_LONGLONG == NPY_SIZEOF_LONG |
| 897 | + INIT_MULTIPLY(LongLong, longlong); |
| 898 | + INIT_MULTIPLY(ULongLong, ulonglong); |
| 899 | +#endif |
| 900 | + |
736 | 901 | Py_DECREF(numpy);
|
737 | 902 | return 0;
|
738 | 903 |
|
|
0 commit comments