Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion src/array_api_extra/_lib/_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,7 +335,12 @@ def isclose(
atol = int(atol)
if rtol == 0:
return xp.abs(a - b) <= atol
nrtol = int(1.0 / rtol)

try:
nrtol = xp.asarray(int(1.0 / rtol), dtype=b.dtype)
except OverflowError:
return xp.abs(a - b) <= atol

return xp.abs(a - b) <= (atol + xp.abs(b) // nrtol)


Expand Down
7 changes: 7 additions & 0 deletions tests/test_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,6 +354,13 @@ def test_tolerance(self, dtype: str, xp: ModuleType):
xp_assert_equal(isclose(a, b, rtol=0), xp.asarray([False, False]))
xp_assert_equal(isclose(a, b, atol=1, rtol=0), xp.asarray([True, False]))

@pytest.mark.parametrize("dtype", ["int8", "uint8"])
def test_tolerance_integer_overflow(self, dtype: str, xp: ModuleType):
"""1/rtol is too large for dtype"""
a = xp.asarray([100, 100], dtype=getattr(xp, dtype))
b = xp.asarray([100, 101], dtype=getattr(xp, dtype))
xp_assert_equal(isclose(a, b), xp.asarray([True, False]))

def test_very_small_numbers(self, xp: ModuleType):
a = xp.asarray([1e-9, 1e-9])
b = xp.asarray([1.0001e-9, 1.00001e-9])
Expand Down