Skip to content

Commit 0bd2328

Browse files
authored
Merge pull request #47 from peytondmurray/pyarray-arrfunc-argmax
Add `argmax` `PyArray_ArrFunc`
2 parents b38c791 + cb10af0 commit 0bd2328

File tree

3 files changed

+56
-10
lines changed

3 files changed

+56
-10
lines changed

.pre-commit-config.yaml

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -73,13 +73,11 @@ repos:
7373
- id: check-added-large-files
7474
- id: check-ast
7575
- repo: https://github.com/charliermarsh/ruff-pre-commit
76-
rev: v0.0.217
76+
rev: v0.0.254
7777
hooks:
7878
- id: ruff
79-
# Respect `exclude` and `extend-exclude` settings.
80-
args: ["--force-exclude"]
8179
- repo: https://github.com/pre-commit/mirrors-prettier
82-
rev: v3.0.0-alpha.4
80+
rev: v3.0.0-alpha.6
8381
hooks:
8482
- id: prettier
8583
types:
@@ -88,7 +86,7 @@ repos:
8886
yaml,
8987
]
9088
- repo: https://github.com/pycqa/isort
91-
rev: 5.11.4
89+
rev: 5.12.0
9290
hooks:
9391
- id: isort
9492
name: isort (python)
@@ -99,7 +97,7 @@ repos:
9997
name: isort (pyi)
10098
types: [pyi]
10199
- repo: https://github.com/psf/black
102-
rev: 22.12.0
100+
rev: 23.1.0
103101
hooks:
104102
- id: black
105103
name: 'black for asciidtype'

stringdtype/stringdtype/src/dtype.c

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -179,13 +179,30 @@ nonzero(void *data, void *NPY_UNUSED(arr))
179179
// Implementation of PyArray_CompareFunc.
180180
// Compares unicode strings by their code points.
181181
int
182-
compare_strings(char **a, char **b, PyArrayObject *NPY_UNUSED(arr))
182+
compare(void *a, void *b, void *NPY_UNUSED(arr))
183183
{
184-
ss *ss_a = (ss *)a;
185-
ss *ss_b = (ss *)b;
184+
ss *ss_a = NULL;
185+
ss *ss_b = NULL;
186+
load_string(a, &ss_a);
187+
load_string(b, &ss_b);
186188
return strcmp(ss_a->buf, ss_b->buf);
187189
}
188190

191+
// PyArray_ArgFunc
192+
// The max element is the one with the highest unicode code point.
193+
int
194+
argmax(void *data, npy_intp n, npy_intp *max_ind, void *arr)
195+
{
196+
ss *dptr = (ss *)data;
197+
*max_ind = 0;
198+
for (int i = 1; i < n; i++) {
199+
if (compare(&dptr[i], &dptr[*max_ind], arr) > 0) {
200+
*max_ind = i;
201+
}
202+
}
203+
return 0;
204+
}
205+
189206
static StringDTypeObject *
190207
stringdtype_ensure_canonical(StringDTypeObject *self)
191208
{
@@ -232,8 +249,9 @@ static PyType_Slot StringDType_Slots[] = {
232249
{NPY_DT_setitem, &stringdtype_setitem},
233250
{NPY_DT_getitem, &stringdtype_getitem},
234251
{NPY_DT_ensure_canonical, &stringdtype_ensure_canonical},
235-
{NPY_DT_PyArray_ArrFuncs_compare, &compare_strings},
236252
{NPY_DT_PyArray_ArrFuncs_nonzero, &nonzero},
253+
{NPY_DT_PyArray_ArrFuncs_compare, &compare},
254+
{NPY_DT_PyArray_ArrFuncs_argmax, &argmax},
237255
{NPY_DT_get_clear_loop, &stringdtype_get_clear_loop},
238256
{0, NULL}};
239257

stringdtype/tests/test_stringdtype.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -212,3 +212,33 @@ def test_creation_functions():
212212

213213
def test_is_numeric():
214214
assert not StringDType._is_numeric
215+
216+
217+
@pytest.mark.parametrize(
218+
"strings",
219+
[
220+
["left", "right", "leftovers", "righty", "up", "down"],
221+
["🤣🤣", "🤣", "📵", "😰"],
222+
["🚜", "🙃", "😾"],
223+
["😹", "🚠", "🚌"],
224+
["A¢☃€ 😊", " A☃€¢😊", "☃€😊 A¢", "😊☃A¢ €"],
225+
],
226+
)
227+
def test_argmax(strings):
228+
"""Test that argmax matches what python calculates as the argmax."""
229+
arr = np.array(strings, dtype=StringDType())
230+
assert np.argmax(arr) == strings.index(max(strings))
231+
232+
233+
@pytest.mark.parametrize(
234+
"arrfunc,expected",
235+
[
236+
[np.sort, np.empty(10, dtype=StringDType())],
237+
[np.nonzero, (np.array([], dtype=np.int64),)],
238+
[np.argmax, 0],
239+
],
240+
)
241+
def test_arrfuncs_empty(arrfunc, expected):
242+
arr = np.empty(10, dtype=StringDType())
243+
result = arrfunc(arr)
244+
np.testing.assert_array_equal(result, expected, strict=True)

0 commit comments

Comments
 (0)