Skip to content

Commit 993cc5b

Browse files
committed
ENH: searchsorted: allow python scalars for x2
cross-ref data-apis/array-api#982
1 parent be9dd9e commit 993cc5b

File tree

2 files changed

+51
-2
lines changed

2 files changed

+51
-2
lines changed

array_api_tests/hypothesis_helpers.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -454,9 +454,12 @@ def scalars(draw, dtypes, finite=False, **kwds):
454454
"""
455455
Strategy to generate a scalar that matches a dtype strategy
456456
457-
dtypes should be one of the shared_* dtypes strategies.
457+
dtypes should be one of the shared_* dtypes strategies or a sequence of dtypes.
458458
"""
459-
dtype = draw(dtypes)
459+
if isinstance(dtypes, Sequence):
460+
dtype = draw(sampled_from(dtypes))
461+
else:
462+
dtype = draw(dtypes)
460463
mM = kwds.pop('mM', None)
461464
if dh.is_int_dtype(dtype):
462465
if mM is None:

array_api_tests/test_searching_functions.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -291,3 +291,49 @@ def test_searchsorted(data):
291291
except Exception as exc:
292292
ph.add_note(exc, repro_snippet)
293293
raise
294+
295+
296+
### @pytest.mark.min_version("2025.12")
297+
@given(data=st.data())
298+
def test_searchsorted_with_scalars(data):
299+
# 1. draw x1, sorter and side exactly the same as in test_searchsorted
300+
x1_dtype = data.draw(st.sampled_from(dh.real_dtypes))
301+
_x1 = data.draw(
302+
st.lists(
303+
xps.from_dtype(x1_dtype, allow_nan=False, allow_infinity=False),
304+
min_size=1,
305+
unique=True
306+
),
307+
label="_x1",
308+
)
309+
x1 = xp.asarray(_x1, dtype=x1_dtype)
310+
if data.draw(st.booleans(), label="use sorter?"):
311+
sorter = xp.argsort(x1)
312+
else:
313+
sorter = None
314+
x1 = xp.sort(x1)
315+
316+
kw = data.draw(hh.kwargs(side=st.sampled_from(["left", "right"])))
317+
318+
# 2. draw x2, a real-valued scalar
319+
# TODO: draw x2 of promotion compatible dtype (int for float x1 etc) -- cf gh-364
320+
x2 = data.draw(hh.scalars(st.just(x1.dtype), finite=True))
321+
322+
# 3. testing: similar to test_searchsorted, modulo `out.shape == ()`
323+
repro_snippet = ph.format_snippet(
324+
f"xp.searchsorted({x1!r}, {x2!r}, sorter={sorter!r}, **kw) with {kw = }"
325+
)
326+
try:
327+
out = xp.searchsorted(x1, x2, sorter=sorter, **kw)
328+
329+
ph.assert_dtype(
330+
"searchsorted",
331+
in_dtype=[x1.dtype], #, x2.dtype
332+
out_dtype=out.dtype,
333+
expected=xp.__array_namespace_info__().default_dtypes()["indexing"],
334+
)
335+
# TODO: values testing
336+
ph.assert_shape("searchsorted", out_shape=out.shape, expected=())
337+
except Exception as exc:
338+
ph.add_note(repro_snippet)
339+
raise

0 commit comments

Comments
 (0)