Skip to content

Commit d2e8ce2

Browse files
peytondmurrayngoldbaum
authored andcommitted
Add a multiply ufunc
1 parent 6fe8cc6 commit d2e8ce2

File tree

2 files changed

+81
-9
lines changed

2 files changed

+81
-9
lines changed

stringdtype/stringdtype/src/umath.c

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,62 @@
1313
#include "string.h"
1414
#include "umath.h"
1515

16+
static NPY_CASTING
17+
multiply_resolve_descriptors(struct PyArrayMethodObject_tag *NPY_UNUSED(method),
18+
PyArray_DTypeMeta *NPY_UNUSED(dtypes[]),
19+
PyArray_Descr *given_descrs[],
20+
PyArray_Descr *loop_descrs[],
21+
npy_intp *NPY_UNUSED(view_offset))
22+
{
23+
Py_INCREF(given_descrs[0]);
24+
loop_descrs[0] = given_descrs[0];
25+
Py_INCREF(given_descrs[1]);
26+
loop_descrs[1] = given_descrs[1];
27+
Py_INCREF(given_descrs[0]);
28+
loop_descrs[2] = given_descrs[0];
29+
30+
return NPY_NO_CASTING;
31+
}
32+
33+
static int
34+
multiply_strided_loop(PyArrayMethod_Context *NPY_UNUSED(context),
35+
char *const data[], npy_intp const dimensions[],
36+
npy_intp const strides[],
37+
NpyAuxData *NPY_UNUSED(auxdata))
38+
{
39+
npy_intp N = dimensions[0];
40+
char *in1 = data[0];
41+
char *in2 = data[1];
42+
char *out = data[2];
43+
npy_intp in1_stride = strides[0];
44+
npy_intp in2_stride = strides[1];
45+
npy_intp out_stride = strides[2];
46+
47+
ss *s1 = NULL, *os = NULL;
48+
49+
while (N--) {
50+
s1 = (ss *)in1;
51+
npy_int64 factor = *(npy_int64 *)in2;
52+
os = (ss *)out;
53+
npy_int64 newlen = (s1->len) * factor;
54+
55+
ssfree(os);
56+
if (ssnewemptylen(newlen, os) < 0) {
57+
return -1;
58+
}
59+
60+
for (int i = 0; i < factor; i++) {
61+
memcpy(os->buf + i * s1->len, s1->buf, s1->len);
62+
}
63+
os->buf[newlen] = '\0';
64+
65+
in1 += in1_stride;
66+
in2 += in2_stride;
67+
out += out_stride;
68+
}
69+
return 0;
70+
}
71+
1672
static NPY_CASTING
1773
binary_resolve_descriptors(struct PyArrayMethodObject_tag *NPY_UNUSED(method),
1874
PyArray_DTypeMeta *NPY_UNUSED(dtypes[]),
@@ -733,6 +789,18 @@ init_ufuncs(void)
733789
goto error;
734790
}
735791

792+
PyArray_DTypeMeta *multiply_types[] = {
793+
(PyArray_DTypeMeta *)&StringDType,
794+
&PyArray_Int64DType,
795+
(PyArray_DTypeMeta *)&StringDType
796+
};
797+
798+
if (init_ufunc(numpy, "multiply", multiply_types,
799+
&multiply_resolve_descriptors, &multiply_strided_loop,
800+
"string_multiply", 2, 1, NPY_NO_CASTING, 0) < 0) {
801+
goto error;
802+
}
803+
736804
Py_DECREF(numpy);
737805
return 0;
738806

stringdtype/tests/test_stringdtype.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -450,19 +450,23 @@ def test_ufuncs_minmax(dtype, string_list, ufunc, func):
450450

451451

452452
@pytest.mark.parametrize(
453-
"other_strings",
453+
"ufunc,other,func",
454454
[
455-
["abc", "def", "ghi", "🤣", "📵", "😰"],
456-
["🚜", "🙃", "😾", "😹", "🚠", "🚌"],
457-
["🥦", "¨", "⨯", "∰ ", "⨌ ", "⎶ "],
455+
("add", "asdf", lambda arr, b: [x + b for x in arr]),
456+
(
457+
"add",
458+
np.array(["a", "b", "c", "d", "e", "f"], dtype=StringDType()),
459+
lambda arr1, arr2: [x + y for x, y in zip(arr1, arr2)],
460+
),
461+
("multiply", 3, lambda arr, b: [x * b for x in arr]),
458462
],
459463
)
460-
def test_ufunc_add(dtype, string_list, other_strings):
461-
arr1 = np.array(string_list, dtype=dtype)
462-
arr2 = np.array(other_strings, dtype=dtype)
464+
def test_binary_ufuncs(dtype, string_list, ufunc, other, func):
465+
"""Test the two-argument ufuncs match python builtin behavior."""
466+
arr = np.array(string_list, dtype=StringDType())
463467
np.testing.assert_array_equal(
464-
np.add(arr1, arr2),
465-
np.array([a + b for a, b in zip(arr1, arr2)], dtype=dtype),
468+
getattr(np, ufunc)(arr, other),
469+
np.array(func(string_list, other), dtype=StringDType()),
466470
)
467471

468472

0 commit comments

Comments
 (0)