Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 46 additions & 0 deletions array_api_tests/test_searching_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,3 +291,49 @@ def test_searchsorted(data):
except Exception as exc:
ph.add_note(exc, repro_snippet)
raise


### @pytest.mark.min_version("2025.12")
@given(data=st.data())
def test_searchsorted_with_scalars(data):
# 1. draw x1, sorter and side exactly the same as in test_searchsorted
x1_dtype = data.draw(st.sampled_from(dh.real_dtypes))
_x1 = data.draw(
st.lists(
xps.from_dtype(x1_dtype, allow_nan=False, allow_infinity=False),
min_size=1,
unique=True
),
label="_x1",
)
x1 = xp.asarray(_x1, dtype=x1_dtype)
if data.draw(st.booleans(), label="use sorter?"):
sorter = xp.argsort(x1)
else:
sorter = None
x1 = xp.sort(x1)

kw = data.draw(hh.kwargs(side=st.sampled_from(["left", "right"])))

# 2. draw x2, a real-valued scalar
# TODO: draw x2 of promotion compatible dtype (int for float x1 etc) -- cf gh-364
x2 = data.draw(hh.scalars(st.just(x1.dtype), finite=True))

# 3. testing: similar to test_searchsorted, modulo `out.shape == ()`
repro_snippet = ph.format_snippet(
f"xp.searchsorted({x1!r}, {x2!r}, sorter={sorter!r}, **kw) with {kw = }"
)
try:
out = xp.searchsorted(x1, x2, sorter=sorter, **kw)

ph.assert_dtype(
"searchsorted",
in_dtype=[x1.dtype], #, x2.dtype
out_dtype=out.dtype,
expected=xp.__array_namespace_info__().default_dtypes()["indexing"],
)
# TODO: values testing
ph.assert_shape("searchsorted", out_shape=out.shape, expected=())
except Exception as exc:
ph.add_note(exc, repro_snippet)
raise
Loading