Skip to content

Commit 718fea9

Browse files
authored
Merge pull request #54 from peytondmurray/ufunc-minmax
Add min and max ufuncs
2 parents fae5b53 + 00b0311 commit 718fea9

File tree

4 files changed

+155
-27
lines changed

4 files changed

+155
-27
lines changed

stringdtype/stringdtype/src/casts.c

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -52,13 +52,16 @@ string_to_string(PyArrayMethod_Context *NPY_UNUSED(context),
5252
npy_intp in_stride = strides[0];
5353
npy_intp out_stride = strides[1];
5454

55-
ss *s = NULL, *os = NULL;
55+
ss *s = NULL;
5656

5757
while (N--) {
58-
load_string(in, &s);
59-
os = (ss *)out;
60-
ssfree(os);
61-
if (ssdup(s, os) < 0) {
58+
// *out* may be reallocated later; *in->buf* may point to a statically
59+
// allocated empty ss struct, so we need to load the string into an
60+
// intermediate buffer *s* to avoid the possibility of freeing static
61+
// data later on.
62+
load_string(in, (ss **)&s);
63+
ssfree((ss *)out);
64+
if (ssdup((ss *)s, (ss *)out) < 0) {
6265
gil_error(PyExc_MemoryError, "ssdup failed");
6366
return -1;
6467
}

stringdtype/stringdtype/src/dtype.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,9 @@ new_stringdtype_instance(void);
2626
int
2727
init_string_dtype(void);
2828

29+
int
30+
compare(void *, void *, void *);
31+
2932
// from dtypemeta.h, not public in numpy
3033
#define NPY_DTYPE(descr) ((PyArray_DTypeMeta *)Py_TYPE(descr))
3134

stringdtype/stringdtype/src/umath.c

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,101 @@
1313
#include "string.h"
1414
#include "umath.h"
1515

16+
static int
17+
minmax_resolve_descriptors(struct PyArrayMethodObject_tag *NPY_UNUSED(method),
18+
PyArray_DTypeMeta *NPY_UNUSED(dtypes[]),
19+
PyArray_Descr *given_descrs[],
20+
PyArray_Descr *loop_descrs[],
21+
npy_intp *NPY_UNUSED(view_offset))
22+
{
23+
Py_INCREF(given_descrs[0]);
24+
loop_descrs[0] = given_descrs[0];
25+
Py_INCREF(given_descrs[1]);
26+
loop_descrs[1] = given_descrs[1];
27+
28+
StringDTypeObject *new = new_stringdtype_instance();
29+
if (new == NULL) {
30+
return -1;
31+
}
32+
loop_descrs[2] = (PyArray_Descr *)new;
33+
34+
return NPY_NO_CASTING;
35+
}
36+
37+
static int
38+
maximum_strided_loop(PyArrayMethod_Context *NPY_UNUSED(context),
39+
char *const data[], npy_intp const dimensions[],
40+
npy_intp const strides[], NpyAuxData *NPY_UNUSED(auxdata))
41+
{
42+
npy_intp N = dimensions[0];
43+
char *in1 = data[0];
44+
char *in2 = data[1];
45+
char *out = data[2];
46+
npy_intp in1_stride = strides[0];
47+
npy_intp in2_stride = strides[1];
48+
npy_intp out_stride = strides[2];
49+
50+
while (N--) {
51+
if (compare(in1, in2, NULL) > 0) {
52+
// Only copy *out* to *in1* if they point to different locations;
53+
// for *arr.max()* they point to the same address.
54+
if (in1 != out) {
55+
ssfree((ss *)out);
56+
if (ssdup((ss *)in1, (ss *)out) < 0) {
57+
return -1;
58+
}
59+
}
60+
}
61+
else {
62+
ssfree((ss *)out);
63+
if (ssdup((ss *)in2, (ss *)out) < 0) {
64+
return -1;
65+
}
66+
}
67+
in1 += in1_stride;
68+
in2 += in2_stride;
69+
out += out_stride;
70+
}
71+
72+
return 0;
73+
}
74+
75+
static int
76+
minimum_strided_loop(PyArrayMethod_Context *NPY_UNUSED(context),
77+
char *const data[], npy_intp const dimensions[],
78+
npy_intp const strides[], NpyAuxData *NPY_UNUSED(auxdata))
79+
{
80+
npy_intp N = dimensions[0];
81+
char *in1 = data[0];
82+
char *in2 = data[1];
83+
char *out = data[2];
84+
npy_intp in1_stride = strides[0];
85+
npy_intp in2_stride = strides[1];
86+
npy_intp out_stride = strides[2];
87+
88+
while (N--) {
89+
if (compare(in1, in2, NULL) < 0) {
90+
if (in1 != out) {
91+
ssfree((ss *)out);
92+
if (ssdup((ss *)in1, (ss *)out) < 0) {
93+
return -1;
94+
}
95+
}
96+
}
97+
else {
98+
ssfree((ss *)out);
99+
if (ssdup((ss *)in2, (ss *)out) < 0) {
100+
return -1;
101+
}
102+
}
103+
in1 += in1_stride;
104+
in2 += in2_stride;
105+
out += out_stride;
106+
}
107+
108+
return 0;
109+
}
110+
16111
static int
17112
string_equal_strided_loop(PyArrayMethod_Context *NPY_UNUSED(context),
18113
char *const data[], npy_intp const dimensions[],
@@ -270,6 +365,19 @@ init_ufuncs(void)
270365
goto error;
271366
}
272367

368+
PyArray_DTypeMeta *minmax_dtypes[] = {&StringDType, &StringDType,
369+
&StringDType};
370+
if (init_ufunc(numpy, "maximum", minmax_dtypes,
371+
&minmax_resolve_descriptors, &maximum_strided_loop,
372+
"string_maximum", 2, 1, NPY_NO_CASTING, 0) < 0) {
373+
goto error;
374+
}
375+
if (init_ufunc(numpy, "minimum", minmax_dtypes,
376+
&minmax_resolve_descriptors, &minimum_strided_loop,
377+
"string_minimum", 2, 1, NPY_NO_CASTING, 0) < 0) {
378+
goto error;
379+
}
380+
273381
Py_DECREF(numpy);
274382
return 0;
275383

stringdtype/tests/test_stringdtype.py

Lines changed: 36 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212

1313
@pytest.fixture
1414
def string_list():
15-
return ["abc", "def", "ghi"]
15+
return ["abc", "def", "ghi", "A¢☃€ 😊", "Abc", "DEF"]
1616

1717

1818
def test_scalar_creation():
@@ -44,13 +44,7 @@ def test_array_creation_utf8(data):
4444

4545

4646
def test_array_creation_scalars(string_list):
47-
arr = np.array(
48-
[
49-
StringScalar("abc"),
50-
StringScalar("def"),
51-
StringScalar("ghi"),
52-
]
53-
)
47+
arr = np.array([StringScalar(s) for s in string_list])
5448
assert repr(arr) == repr(np.array(string_list, dtype=StringDType()))
5549

5650

@@ -69,29 +63,30 @@ def test_bad_scalars(data):
6963

7064

7165
@pytest.mark.parametrize(
72-
("string_list"),
66+
("strings"),
7367
[
7468
["this", "is", "an", "array"],
7569
["€", "", "😊"],
7670
["A¢☃€ 😊", " A☃€¢😊", "☃€😊 A¢", "😊☃A¢ €"],
7771
],
7872
)
79-
def test_unicode_casts(string_list):
80-
arr = np.array(string_list, dtype=np.unicode_).astype(StringDType())
81-
expected = np.array(string_list, dtype=StringDType())
73+
def test_unicode_casts(strings):
74+
arr = np.array(strings, dtype=np.unicode_).astype(StringDType())
75+
expected = np.array(strings, dtype=StringDType())
8276
np.testing.assert_array_equal(arr, expected)
8377

84-
arr = np.array(string_list, dtype=StringDType())
78+
arr = np.array(strings, dtype=StringDType())
79+
8580
np.testing.assert_array_equal(
86-
arr.astype("U8"), np.array(string_list, dtype="U8")
81+
arr.astype("U8"), np.array(strings, dtype="U8")
8782
)
8883
np.testing.assert_array_equal(arr.astype("U8").astype(StringDType()), arr)
8984
np.testing.assert_array_equal(
90-
arr.astype("U3"), np.array(string_list, dtype="U3")
85+
arr.astype("U3"), np.array(strings, dtype="U3")
9186
)
9287
np.testing.assert_array_equal(
9388
arr.astype("U3").astype(StringDType()),
94-
np.array([s[:3] for s in string_list], dtype=StringDType()),
89+
np.array([s[:3] for s in strings], dtype=StringDType()),
9590
)
9691

9792

@@ -107,10 +102,14 @@ def test_additional_unicode_cast(string_list):
107102

108103

109104
def test_insert_scalar(string_list):
105+
"""Test that inserting a scalar works."""
110106
dtype = StringDType()
111107
arr = np.array(string_list, dtype=dtype)
112108
arr[1] = StringScalar("what")
113-
assert repr(arr) == repr(np.array(["abc", "what", "ghi"], dtype=dtype))
109+
np.testing.assert_array_equal(
110+
arr,
111+
np.array(string_list[:1] + ["what"] + string_list[2:], dtype=dtype),
112+
)
114113

115114

116115
def test_equality_promotion(string_list):
@@ -128,8 +127,8 @@ def test_isnan(string_list):
128127
)
129128

130129

131-
def test_memory_usage(string_list):
132-
sarr = np.array(string_list, dtype=StringDType())
130+
def test_memory_usage():
131+
sarr = np.array(["abc", "def", "ghi"], dtype=StringDType())
133132
# 4 bytes for each ASCII string buffer in string_list
134133
# (three characters and null terminator)
135134
# plus enough bytes for the size_t length
@@ -258,16 +257,16 @@ def test_arrfuncs_empty(arrfunc, expected):
258257

259258

260259
@pytest.mark.parametrize(
261-
("string_list", "cast_answer", "any_answer", "all_answer"),
260+
("strings", "cast_answer", "any_answer", "all_answer"),
262261
[
263262
[["hello", "world"], [True, True], True, True],
264263
[["", ""], [False, False], False, False],
265264
[["hello", ""], [True, False], True, False],
266265
[["", "world"], [False, True], True, False],
267266
],
268267
)
269-
def test_bool_cast(string_list, cast_answer, any_answer, all_answer):
270-
sarr = np.array(string_list, dtype=StringDType())
268+
def test_bool_cast(strings, cast_answer, any_answer, all_answer):
269+
sarr = np.array(strings, dtype=StringDType())
271270
np.testing.assert_array_equal(sarr.astype("bool"), cast_answer)
272271

273272
assert np.any(sarr) == any_answer
@@ -285,3 +284,18 @@ def test_take(string_list):
285284
out[0] = "hello"
286285
res = sarr.take(np.arange(len(string_list)), out=out)
287286
np.testing.assert_array_equal(res, out)
287+
288+
289+
@pytest.mark.parametrize(
290+
"ufunc,func",
291+
[
292+
("min", min),
293+
("max", max),
294+
],
295+
)
296+
def test_ufuncs_minmax(string_list, ufunc, func):
297+
"""Test that the min/max ufuncs match Python builtin min/max behavior."""
298+
arr = np.array(string_list, dtype=StringDType())
299+
np.testing.assert_array_equal(
300+
getattr(arr, ufunc)(), np.array(func(string_list), dtype=StringDType())
301+
)

0 commit comments

Comments
 (0)