Skip to content

Commit 617778d

Browse files
authored
Fix 1D return dtype in argmax/argmin (#704)
1 parent 3c21dfe commit 617778d

File tree

2 files changed

+1
-3
lines changed

2 files changed

+1
-3
lines changed

ci/Numba-array-api-xfails.txt

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,8 +44,6 @@ array_api_tests/test_indexing_functions.py::test_take
4444
array_api_tests/test_linalg.py::test_vecdot
4545
array_api_tests/test_operators_and_elementwise_functions.py::test_ceil
4646
array_api_tests/test_operators_and_elementwise_functions.py::test_trunc
47-
array_api_tests/test_searching_functions.py::test_argmax
48-
array_api_tests/test_searching_functions.py::test_argmin
4947
array_api_tests/test_set_functions.py::test_unique_all
5048
array_api_tests/test_set_functions.py::test_unique_inverse
5149
array_api_tests/test_signatures.py::test_func_signature[unique_all]

sparse/numba_backend/_coo/common.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1456,7 +1456,7 @@ def _compute_minmax_args(
14561456
if not found:
14571457
result_data.append(current_coord + 1)
14581458

1459-
return (result_indices, result_data)
1459+
return (result_indices, np.array(result_data, dtype=np.intp))
14601460

14611461

14621462
def _arg_minmax_common(

0 commit comments

Comments
 (0)