Skip to content

Commit 6e04b14

Browse files
authored
Merge pull request #48 from peytondmurray/pyarray-arrfunc-argmin
Add argmin PyArray_ArrFunc
2 parents a5e9efc + aaa89f0 commit 6e04b14

File tree

2 files changed

+20
-2
lines changed

2 files changed

+20
-2
lines changed

stringdtype/stringdtype/src/dtype.c

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,21 @@ argmax(void *data, npy_intp n, npy_intp *max_ind, void *arr)
203203
return 0;
204204
}
205205

206+
// PyArray_ArgFunc
207+
// The min element is the one with the lowest unicode code point.
208+
int
209+
argmin(void *data, npy_intp n, npy_intp *min_ind, void *arr)
210+
{
211+
ss *dptr = (ss *)data;
212+
*min_ind = 0;
213+
for (int i = 1; i < n; i++) {
214+
if (compare(&dptr[i], &dptr[*min_ind], arr) < 0) {
215+
*min_ind = i;
216+
}
217+
}
218+
return 0;
219+
}
220+
206221
static StringDTypeObject *
207222
stringdtype_ensure_canonical(StringDTypeObject *self)
208223
{
@@ -252,6 +267,7 @@ static PyType_Slot StringDType_Slots[] = {
252267
{NPY_DT_PyArray_ArrFuncs_nonzero, &nonzero},
253268
{NPY_DT_PyArray_ArrFuncs_compare, &compare},
254269
{NPY_DT_PyArray_ArrFuncs_argmax, &argmax},
270+
{NPY_DT_PyArray_ArrFuncs_argmin, &argmin},
255271
{NPY_DT_get_clear_loop, &stringdtype_get_clear_loop},
256272
{0, NULL}};
257273

stringdtype/tests/test_stringdtype.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -225,9 +225,10 @@ def test_is_numeric():
225225
],
226226
)
227227
def test_argmax(strings):
228-
"""Test that argmax matches what python calculates as the argmax."""
228+
"""Test that argmax/argmin matches what python calculates."""
229229
arr = np.array(strings, dtype=StringDType())
230230
assert np.argmax(arr) == strings.index(max(strings))
231+
assert np.argmin(arr) == strings.index(min(strings))
231232

232233

233234
@pytest.mark.parametrize(
@@ -236,14 +237,15 @@ def test_argmax(strings):
236237
[np.sort, np.empty(10, dtype=StringDType())],
237238
[np.nonzero, (np.array([], dtype=np.int64),)],
238239
[np.argmax, 0],
240+
[np.argmin, 0],
239241
],
240242
)
241243
def test_arrfuncs_empty(arrfunc, expected):
242244
arr = np.empty(10, dtype=StringDType())
243245
result = arrfunc(arr)
244246
np.testing.assert_array_equal(result, expected, strict=True)
245247

246-
248+
247249
@pytest.mark.parametrize(
248250
("string_list", "cast_answer", "any_answer", "all_answer"),
249251
[

0 commit comments

Comments
 (0)