Skip to content

Commit 676c2fd

Browse files
authored
Merge pull request #78 from ngoldbaum/add-multiply-ufunc
Add multiply ufunc
2 parents 6fe8cc6 + 630a3cc commit 676c2fd

File tree

2 files changed

+205
-0
lines changed

2 files changed

+205
-0
lines changed

stringdtype/stringdtype/src/umath.c

Lines changed: 165 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,123 @@
1313
#include "string.h"
1414
#include "umath.h"
1515

16+
static NPY_CASTING
17+
multiply_resolve_descriptors(
18+
struct PyArrayMethodObject_tag *NPY_UNUSED(method),
19+
PyArray_DTypeMeta *dtypes[], PyArray_Descr *given_descrs[],
20+
PyArray_Descr *loop_descrs[], npy_intp *NPY_UNUSED(view_offset))
21+
{
22+
PyArray_Descr *ldescr = given_descrs[0];
23+
PyArray_Descr *rdescr = given_descrs[1];
24+
Py_INCREF(ldescr);
25+
loop_descrs[0] = ldescr;
26+
Py_INCREF(rdescr);
27+
loop_descrs[1] = rdescr;
28+
29+
PyArray_Descr *odescr = NULL;
30+
31+
if (dtypes[0] == (PyArray_DTypeMeta *)&StringDType) {
32+
odescr = ldescr;
33+
}
34+
else {
35+
odescr = rdescr;
36+
}
37+
38+
loop_descrs[2] = (PyArray_Descr *)new_stringdtype_instance(
39+
((StringDTypeObject *)odescr)->na_object);
40+
41+
return NPY_NO_CASTING;
42+
}
43+
44+
#define MULTIPLY_IMPL(shortname) \
45+
static int multiply_loop_core_##shortname( \
46+
npy_intp N, char *sin, char *iin, char *out, npy_intp s_stride, \
47+
npy_intp i_stride, npy_intp o_stride) \
48+
{ \
49+
ss *is = NULL, *os = NULL; \
50+
\
51+
while (N--) { \
52+
is = (ss *)sin; \
53+
npy_##shortname factor = *(npy_##shortname *)iin; \
54+
os = (ss *)out; \
55+
size_t newlen = (size_t)((is->len) * factor); \
56+
\
57+
ssfree(os); \
58+
if (ssnewemptylen(newlen, os) < 0) { \
59+
return -1; \
60+
} \
61+
\
62+
for (size_t i = 0; i < (size_t)factor; i++) { \
63+
memcpy(os->buf + i * is->len, is->buf, is->len); \
64+
} \
65+
os->buf[newlen] = '\0'; \
66+
\
67+
sin += s_stride; \
68+
iin += i_stride; \
69+
out += o_stride; \
70+
} \
71+
return 0; \
72+
} \
73+
\
74+
static int multiply_right_##shortname##_strided_loop( \
75+
PyArrayMethod_Context *NPY_UNUSED(context), char *const data[], \
76+
npy_intp const dimensions[], npy_intp const strides[], \
77+
NpyAuxData *NPY_UNUSED(auxdata)) \
78+
{ \
79+
npy_intp N = dimensions[0]; \
80+
char *in1 = data[0]; \
81+
char *in2 = data[1]; \
82+
char *out = data[2]; \
83+
npy_intp in1_stride = strides[0]; \
84+
npy_intp in2_stride = strides[1]; \
85+
npy_intp out_stride = strides[2]; \
86+
\
87+
return multiply_loop_core_##shortname(N, in1, in2, out, in1_stride, \
88+
in2_stride, out_stride); \
89+
} \
90+
\
91+
static int multiply_left_##shortname##_strided_loop( \
92+
PyArrayMethod_Context *NPY_UNUSED(context), char *const data[], \
93+
npy_intp const dimensions[], npy_intp const strides[], \
94+
NpyAuxData *NPY_UNUSED(auxdata)) \
95+
{ \
96+
npy_intp N = dimensions[0]; \
97+
char *in1 = data[0]; \
98+
char *in2 = data[1]; \
99+
char *out = data[2]; \
100+
npy_intp in1_stride = strides[0]; \
101+
npy_intp in2_stride = strides[1]; \
102+
npy_intp out_stride = strides[2]; \
103+
\
104+
return multiply_loop_core_##shortname(N, in2, in1, out, in2_stride, \
105+
in1_stride, out_stride); \
106+
}
107+
108+
MULTIPLY_IMPL(int8);
109+
MULTIPLY_IMPL(int16);
110+
MULTIPLY_IMPL(int32);
111+
MULTIPLY_IMPL(int64);
112+
MULTIPLY_IMPL(uint8);
113+
MULTIPLY_IMPL(uint16);
114+
MULTIPLY_IMPL(uint32);
115+
MULTIPLY_IMPL(uint64);
116+
#if NPY_SIZEOF_BYTE == NPY_SIZEOF_SHORT
117+
MULTIPLY_IMPL(byte);
118+
MULTIPLY_IMPL(ubyte);
119+
#endif
120+
#if NPY_SIZEOF_SHORT == NPY_SIZEOF_INT
121+
MULTIPLY_IMPL(short);
122+
MULTIPLY_IMPL(ushort);
123+
#endif
124+
#if NPY_SIZEOF_INT == NPY_SIZEOF_LONG
125+
MULTIPLY_IMPL(long);
126+
MULTIPLY_IMPL(ulong);
127+
#endif
128+
#if NPY_SIZEOF_LONGLONG == NPY_SIZEOF_LONG
129+
MULTIPLY_IMPL(longlong);
130+
MULTIPLY_IMPL(ulonglong);
131+
#endif
132+
16133
static NPY_CASTING
17134
binary_resolve_descriptors(struct PyArrayMethodObject_tag *NPY_UNUSED(method),
18135
PyArray_DTypeMeta *NPY_UNUSED(dtypes[]),
@@ -611,6 +728,29 @@ add_promoter(PyObject *numpy, const char *ufunc_name,
611728
return 0;
612729
}
613730

731+
#define INIT_MULTIPLY(typename, shortname) \
732+
PyArray_DTypeMeta *multiply_right_##shortname##_types[] = { \
733+
(PyArray_DTypeMeta *)&StringDType, &PyArray_##typename##DType, \
734+
(PyArray_DTypeMeta *)&StringDType}; \
735+
\
736+
if (init_ufunc(numpy, "multiply", multiply_right_##shortname##_types, \
737+
&multiply_resolve_descriptors, \
738+
&multiply_right_##shortname##_strided_loop, \
739+
"string_multiply", 2, 1, NPY_NO_CASTING, 0) < 0) { \
740+
goto error; \
741+
} \
742+
\
743+
PyArray_DTypeMeta *multiply_left_##shortname##_types[] = { \
744+
&PyArray_##typename##DType, (PyArray_DTypeMeta *)&StringDType, \
745+
(PyArray_DTypeMeta *)&StringDType}; \
746+
\
747+
if (init_ufunc(numpy, "multiply", multiply_left_##shortname##_types, \
748+
&multiply_resolve_descriptors, \
749+
&multiply_left_##shortname##_strided_loop, \
750+
"string_multiply", 2, 1, NPY_NO_CASTING, 0) < 0) { \
751+
goto error; \
752+
}
753+
614754
int
615755
init_ufuncs(void)
616756
{
@@ -733,6 +873,31 @@ init_ufuncs(void)
733873
goto error;
734874
}
735875

876+
INIT_MULTIPLY(Int8, int8);
877+
INIT_MULTIPLY(Int16, int16);
878+
INIT_MULTIPLY(Int32, int32);
879+
INIT_MULTIPLY(Int64, int64);
880+
INIT_MULTIPLY(UInt8, uint8);
881+
INIT_MULTIPLY(UInt16, uint16);
882+
INIT_MULTIPLY(UInt32, uint32);
883+
INIT_MULTIPLY(UInt64, uint64);
884+
#if NPY_SIZEOF_BYTE == NPY_SIZEOF_SHORT
885+
INIT_MULTIPLY(Byte, byte);
886+
INIT_MULTIPLY(UByte, ubyte);
887+
#endif
888+
#if NPY_SIZEOF_SHORT == NPY_SIZEOF_INT
889+
INIT_MULTIPLY(Short, short);
890+
INIT_MULTIPLY(UShort, ushort);
891+
#endif
892+
#if NPY_SIZEOF_INT == NPY_SIZEOF_LONG
893+
INIT_MULTIPLY(Long, long);
894+
INIT_MULTIPLY(ULong, ulong);
895+
#endif
896+
#if NPY_SIZEOF_LONGLONG == NPY_SIZEOF_LONG
897+
INIT_MULTIPLY(LongLong, longlong);
898+
INIT_MULTIPLY(ULongLong, ulonglong);
899+
#endif
900+
736901
Py_DECREF(numpy);
737902
return 0;
738903

stringdtype/tests/test_stringdtype.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -466,6 +466,46 @@ def test_ufunc_add(dtype, string_list, other_strings):
466466
)
467467

468468

469+
@pytest.mark.parametrize("other", [2, [2, 1, 3, 4, 1, 3]])
470+
@pytest.mark.parametrize(
471+
"other_dtype",
472+
[
473+
"int8",
474+
"int16",
475+
"int32",
476+
"int64",
477+
"uint8",
478+
"uint16",
479+
"uint32",
480+
"uint64",
481+
"short",
482+
"int",
483+
"intp",
484+
"long",
485+
"longlong",
486+
"ushort",
487+
"uint",
488+
"uintp",
489+
"ulong",
490+
"ulonglong",
491+
],
492+
)
493+
def test_ufunc_multiply(dtype, string_list, other, other_dtype):
494+
"""Test the two-argument ufuncs match python builtin behavior."""
495+
arr = np.array(string_list, dtype=StringDType())
496+
other_dtype = np.dtype(other_dtype)
497+
try:
498+
len(other)
499+
result = [s * o for s, o in zip(string_list, other)]
500+
other = np.array(other, dtype=other_dtype)
501+
except TypeError:
502+
other = other_dtype.type(other)
503+
result = [s * other for s in string_list]
504+
505+
np.testing.assert_array_equal(arr * other, result)
506+
np.testing.assert_array_equal(other * arr, result)
507+
508+
469509
def test_create_with_na(dtype):
470510
na_val = dtype.na_object
471511
string_list = ["hello", na_val, "world"]

0 commit comments

Comments
 (0)