Skip to content

Commit 630a3cc

Browse files
committed
generalize multiply ufunc to work for all int dtypes
1 parent d2e8ce2 commit 630a3cc

File tree

2 files changed

+205
-72
lines changed

2 files changed

+205
-72
lines changed

stringdtype/stringdtype/src/umath.c

Lines changed: 156 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -14,61 +14,122 @@
1414
#include "umath.h"
1515

1616
static NPY_CASTING
17-
multiply_resolve_descriptors(struct PyArrayMethodObject_tag *NPY_UNUSED(method),
18-
PyArray_DTypeMeta *NPY_UNUSED(dtypes[]),
19-
PyArray_Descr *given_descrs[],
20-
PyArray_Descr *loop_descrs[],
21-
npy_intp *NPY_UNUSED(view_offset))
22-
{
23-
Py_INCREF(given_descrs[0]);
24-
loop_descrs[0] = given_descrs[0];
25-
Py_INCREF(given_descrs[1]);
26-
loop_descrs[1] = given_descrs[1];
27-
Py_INCREF(given_descrs[0]);
28-
loop_descrs[2] = given_descrs[0];
29-
30-
return NPY_NO_CASTING;
31-
}
32-
33-
static int
34-
multiply_strided_loop(PyArrayMethod_Context *NPY_UNUSED(context),
35-
char *const data[], npy_intp const dimensions[],
36-
npy_intp const strides[],
37-
NpyAuxData *NPY_UNUSED(auxdata))
17+
multiply_resolve_descriptors(
18+
struct PyArrayMethodObject_tag *NPY_UNUSED(method),
19+
PyArray_DTypeMeta *dtypes[], PyArray_Descr *given_descrs[],
20+
PyArray_Descr *loop_descrs[], npy_intp *NPY_UNUSED(view_offset))
3821
{
39-
npy_intp N = dimensions[0];
40-
char *in1 = data[0];
41-
char *in2 = data[1];
42-
char *out = data[2];
43-
npy_intp in1_stride = strides[0];
44-
npy_intp in2_stride = strides[1];
45-
npy_intp out_stride = strides[2];
46-
47-
ss *s1 = NULL, *os = NULL;
22+
PyArray_Descr *ldescr = given_descrs[0];
23+
PyArray_Descr *rdescr = given_descrs[1];
24+
Py_INCREF(ldescr);
25+
loop_descrs[0] = ldescr;
26+
Py_INCREF(rdescr);
27+
loop_descrs[1] = rdescr;
4828

49-
while (N--) {
50-
s1 = (ss *)in1;
51-
npy_int64 factor = *(npy_int64 *)in2;
52-
os = (ss *)out;
53-
npy_int64 newlen = (s1->len) * factor;
29+
PyArray_Descr *odescr = NULL;
5430

55-
ssfree(os);
56-
if (ssnewemptylen(newlen, os) < 0) {
57-
return -1;
58-
}
31+
if (dtypes[0] == (PyArray_DTypeMeta *)&StringDType) {
32+
odescr = ldescr;
33+
}
34+
else {
35+
odescr = rdescr;
36+
}
5937

60-
for (int i = 0; i < factor; i++) {
61-
memcpy(os->buf + i * s1->len, s1->buf, s1->len);
62-
}
63-
os->buf[newlen] = '\0';
38+
loop_descrs[2] = (PyArray_Descr *)new_stringdtype_instance(
39+
((StringDTypeObject *)odescr)->na_object);
6440

65-
in1 += in1_stride;
66-
in2 += in2_stride;
67-
out += out_stride;
68-
}
69-
return 0;
41+
return NPY_NO_CASTING;
7042
}
7143

44+
#define MULTIPLY_IMPL(shortname) \
45+
static int multiply_loop_core_##shortname( \
46+
npy_intp N, char *sin, char *iin, char *out, npy_intp s_stride, \
47+
npy_intp i_stride, npy_intp o_stride) \
48+
{ \
49+
ss *is = NULL, *os = NULL; \
50+
\
51+
while (N--) { \
52+
is = (ss *)sin; \
53+
npy_##shortname factor = *(npy_##shortname *)iin; \
54+
os = (ss *)out; \
55+
size_t newlen = (size_t)((is->len) * factor); \
56+
\
57+
ssfree(os); \
58+
if (ssnewemptylen(newlen, os) < 0) { \
59+
return -1; \
60+
} \
61+
\
62+
for (size_t i = 0; i < (size_t)factor; i++) { \
63+
memcpy(os->buf + i * is->len, is->buf, is->len); \
64+
} \
65+
os->buf[newlen] = '\0'; \
66+
\
67+
sin += s_stride; \
68+
iin += i_stride; \
69+
out += o_stride; \
70+
} \
71+
return 0; \
72+
} \
73+
\
74+
static int multiply_right_##shortname##_strided_loop( \
75+
PyArrayMethod_Context *NPY_UNUSED(context), char *const data[], \
76+
npy_intp const dimensions[], npy_intp const strides[], \
77+
NpyAuxData *NPY_UNUSED(auxdata)) \
78+
{ \
79+
npy_intp N = dimensions[0]; \
80+
char *in1 = data[0]; \
81+
char *in2 = data[1]; \
82+
char *out = data[2]; \
83+
npy_intp in1_stride = strides[0]; \
84+
npy_intp in2_stride = strides[1]; \
85+
npy_intp out_stride = strides[2]; \
86+
\
87+
return multiply_loop_core_##shortname(N, in1, in2, out, in1_stride, \
88+
in2_stride, out_stride); \
89+
} \
90+
\
91+
static int multiply_left_##shortname##_strided_loop( \
92+
PyArrayMethod_Context *NPY_UNUSED(context), char *const data[], \
93+
npy_intp const dimensions[], npy_intp const strides[], \
94+
NpyAuxData *NPY_UNUSED(auxdata)) \
95+
{ \
96+
npy_intp N = dimensions[0]; \
97+
char *in1 = data[0]; \
98+
char *in2 = data[1]; \
99+
char *out = data[2]; \
100+
npy_intp in1_stride = strides[0]; \
101+
npy_intp in2_stride = strides[1]; \
102+
npy_intp out_stride = strides[2]; \
103+
\
104+
return multiply_loop_core_##shortname(N, in2, in1, out, in2_stride, \
105+
in1_stride, out_stride); \
106+
}
107+
108+
MULTIPLY_IMPL(int8);
109+
MULTIPLY_IMPL(int16);
110+
MULTIPLY_IMPL(int32);
111+
MULTIPLY_IMPL(int64);
112+
MULTIPLY_IMPL(uint8);
113+
MULTIPLY_IMPL(uint16);
114+
MULTIPLY_IMPL(uint32);
115+
MULTIPLY_IMPL(uint64);
116+
#if NPY_SIZEOF_BYTE == NPY_SIZEOF_SHORT
117+
MULTIPLY_IMPL(byte);
118+
MULTIPLY_IMPL(ubyte);
119+
#endif
120+
#if NPY_SIZEOF_SHORT == NPY_SIZEOF_INT
121+
MULTIPLY_IMPL(short);
122+
MULTIPLY_IMPL(ushort);
123+
#endif
124+
#if NPY_SIZEOF_INT == NPY_SIZEOF_LONG
125+
MULTIPLY_IMPL(long);
126+
MULTIPLY_IMPL(ulong);
127+
#endif
128+
#if NPY_SIZEOF_LONGLONG == NPY_SIZEOF_LONG
129+
MULTIPLY_IMPL(longlong);
130+
MULTIPLY_IMPL(ulonglong);
131+
#endif
132+
72133
static NPY_CASTING
73134
binary_resolve_descriptors(struct PyArrayMethodObject_tag *NPY_UNUSED(method),
74135
PyArray_DTypeMeta *NPY_UNUSED(dtypes[]),
@@ -667,6 +728,29 @@ add_promoter(PyObject *numpy, const char *ufunc_name,
667728
return 0;
668729
}
669730

731+
#define INIT_MULTIPLY(typename, shortname) \
732+
PyArray_DTypeMeta *multiply_right_##shortname##_types[] = { \
733+
(PyArray_DTypeMeta *)&StringDType, &PyArray_##typename##DType, \
734+
(PyArray_DTypeMeta *)&StringDType}; \
735+
\
736+
if (init_ufunc(numpy, "multiply", multiply_right_##shortname##_types, \
737+
&multiply_resolve_descriptors, \
738+
&multiply_right_##shortname##_strided_loop, \
739+
"string_multiply", 2, 1, NPY_NO_CASTING, 0) < 0) { \
740+
goto error; \
741+
} \
742+
\
743+
PyArray_DTypeMeta *multiply_left_##shortname##_types[] = { \
744+
&PyArray_##typename##DType, (PyArray_DTypeMeta *)&StringDType, \
745+
(PyArray_DTypeMeta *)&StringDType}; \
746+
\
747+
if (init_ufunc(numpy, "multiply", multiply_left_##shortname##_types, \
748+
&multiply_resolve_descriptors, \
749+
&multiply_left_##shortname##_strided_loop, \
750+
"string_multiply", 2, 1, NPY_NO_CASTING, 0) < 0) { \
751+
goto error; \
752+
}
753+
670754
int
671755
init_ufuncs(void)
672756
{
@@ -789,17 +873,30 @@ init_ufuncs(void)
789873
goto error;
790874
}
791875

792-
PyArray_DTypeMeta *multiply_types[] = {
793-
(PyArray_DTypeMeta *)&StringDType,
794-
&PyArray_Int64DType,
795-
(PyArray_DTypeMeta *)&StringDType
796-
};
797-
798-
if (init_ufunc(numpy, "multiply", multiply_types,
799-
&multiply_resolve_descriptors, &multiply_strided_loop,
800-
"string_multiply", 2, 1, NPY_NO_CASTING, 0) < 0) {
801-
goto error;
802-
}
876+
INIT_MULTIPLY(Int8, int8);
877+
INIT_MULTIPLY(Int16, int16);
878+
INIT_MULTIPLY(Int32, int32);
879+
INIT_MULTIPLY(Int64, int64);
880+
INIT_MULTIPLY(UInt8, uint8);
881+
INIT_MULTIPLY(UInt16, uint16);
882+
INIT_MULTIPLY(UInt32, uint32);
883+
INIT_MULTIPLY(UInt64, uint64);
884+
#if NPY_SIZEOF_BYTE == NPY_SIZEOF_SHORT
885+
INIT_MULTIPLY(Byte, byte);
886+
INIT_MULTIPLY(UByte, ubyte);
887+
#endif
888+
#if NPY_SIZEOF_SHORT == NPY_SIZEOF_INT
889+
INIT_MULTIPLY(Short, short);
890+
INIT_MULTIPLY(UShort, ushort);
891+
#endif
892+
#if NPY_SIZEOF_INT == NPY_SIZEOF_LONG
893+
INIT_MULTIPLY(Long, long);
894+
INIT_MULTIPLY(ULong, ulong);
895+
#endif
896+
#if NPY_SIZEOF_LONGLONG == NPY_SIZEOF_LONG
897+
INIT_MULTIPLY(LongLong, longlong);
898+
INIT_MULTIPLY(ULongLong, ulonglong);
899+
#endif
803900

804901
Py_DECREF(numpy);
805902
return 0;

stringdtype/tests/test_stringdtype.py

Lines changed: 49 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -450,26 +450,62 @@ def test_ufuncs_minmax(dtype, string_list, ufunc, func):
450450

451451

452452
@pytest.mark.parametrize(
453-
"ufunc,other,func",
453+
"other_strings",
454454
[
455-
("add", "asdf", lambda arr, b: [x + b for x in arr]),
456-
(
457-
"add",
458-
np.array(["a", "b", "c", "d", "e", "f"], dtype=StringDType()),
459-
lambda arr1, arr2: [x + y for x, y in zip(arr1, arr2)],
460-
),
461-
("multiply", 3, lambda arr, b: [x * b for x in arr]),
455+
["abc", "def", "ghi", "🤣", "📵", "😰"],
456+
["🚜", "🙃", "😾", "😹", "🚠", "🚌"],
457+
["🥦", "¨", "⨯", "∰ ", "⨌ ", "⎶ "],
462458
],
463459
)
464-
def test_binary_ufuncs(dtype, string_list, ufunc, other, func):
465-
"""Test the two-argument ufuncs match python builtin behavior."""
466-
arr = np.array(string_list, dtype=StringDType())
460+
def test_ufunc_add(dtype, string_list, other_strings):
461+
arr1 = np.array(string_list, dtype=dtype)
462+
arr2 = np.array(other_strings, dtype=dtype)
467463
np.testing.assert_array_equal(
468-
getattr(np, ufunc)(arr, other),
469-
np.array(func(string_list, other), dtype=StringDType()),
464+
np.add(arr1, arr2),
465+
np.array([a + b for a, b in zip(arr1, arr2)], dtype=dtype),
470466
)
471467

472468

469+
@pytest.mark.parametrize("other", [2, [2, 1, 3, 4, 1, 3]])
470+
@pytest.mark.parametrize(
471+
"other_dtype",
472+
[
473+
"int8",
474+
"int16",
475+
"int32",
476+
"int64",
477+
"uint8",
478+
"uint16",
479+
"uint32",
480+
"uint64",
481+
"short",
482+
"int",
483+
"intp",
484+
"long",
485+
"longlong",
486+
"ushort",
487+
"uint",
488+
"uintp",
489+
"ulong",
490+
"ulonglong",
491+
],
492+
)
493+
def test_ufunc_multiply(dtype, string_list, other, other_dtype):
494+
"""Test the two-argument ufuncs match python builtin behavior."""
495+
arr = np.array(string_list, dtype=StringDType())
496+
other_dtype = np.dtype(other_dtype)
497+
try:
498+
len(other)
499+
result = [s * o for s, o in zip(string_list, other)]
500+
other = np.array(other, dtype=other_dtype)
501+
except TypeError:
502+
other = other_dtype.type(other)
503+
result = [s * other for s in string_list]
504+
505+
np.testing.assert_array_equal(arr * other, result)
506+
np.testing.assert_array_equal(other * arr, result)
507+
508+
473509
def test_create_with_na(dtype):
474510
na_val = dtype.na_object
475511
string_list = ["hello", na_val, "world"]

0 commit comments

Comments
 (0)