Skip to content

Commit b8db551

Browse files
committed
Make numpy.array_api dtypes issue a warning when compared against numpy dtypes
This is to prevent user error, since something like numpy.array_api.float32 == numpy.float32 gives False. Original NumPy Commit: 595342c933b5db00a9baddbc142676448ffe8228
1 parent d577d46 commit b8db551

File tree

3 files changed

+30
-3
lines changed

3 files changed

+30
-3
lines changed

array_api_strict/_dtypes.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import warnings
2+
13
import numpy as np
24

35
# Note: we wrap the NumPy dtype objects in a bare class, so that none of the
@@ -12,6 +14,15 @@ def __repr__(self):
1214
return f"np.array_api.{self._np_dtype.name}"
1315

1416
def __eq__(self, other):
17+
# See https://github.com/numpy/numpy/pull/25370/files#r1423259515.
18+
# Avoid the user error of array_api_strict.float32 == numpy.float32,
19+
# which gives False. Making == error is probably too egregious, so
20+
# warn instead.
21+
if isinstance(other, np.dtype) or (isinstance(other, type) and issubclass(other, np.generic)):
22+
warnings.warn("""You are comparing a array_api_strict dtype against \
23+
a NumPy native dtype object, but you probably don't want to do this. \
24+
array_api_strict dtype objects compare unequal to their NumPy equivalents. Such \
25+
cross-library comparison is not supported by the standard.""")
1526
if not isinstance(other, _DType):
1627
return NotImplemented
1728
return self._np_dtype == other._np_dtype

array_api_strict/tests/test_creation_functions.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import warnings
2+
13
from numpy.testing import assert_raises
24
import numpy as np
35

@@ -25,7 +27,11 @@ def test_asarray_errors():
2527
# Test various protections against incorrect usage
2628
assert_raises(TypeError, lambda: Array([1]))
2729
assert_raises(TypeError, lambda: asarray(["a"]))
28-
assert_raises(ValueError, lambda: asarray([1.0], dtype=np.float16))
30+
with assert_raises(ValueError), warnings.catch_warnings(record=True) as w:
31+
warnings.simplefilter("always")
32+
asarray([1.0], dtype=np.float16)
33+
assert len(w) == 1
34+
assert issubclass(w[-1].category, UserWarning)
2935
assert_raises(OverflowError, lambda: asarray(2**100))
3036
# Preferably this would be OverflowError
3137
# assert_raises(OverflowError, lambda: asarray([2**100]))

array_api_strict/tests/test_data_type_functions.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import warnings
2+
13
import pytest
24

35
from numpy.testing import assert_raises
@@ -24,7 +26,15 @@ def test_isdtype_strictness():
2426
assert_raises(ValueError, lambda: xp.isdtype(xp.float64, 'f8'))
2527

2628
assert_raises(TypeError, lambda: xp.isdtype(xp.float64, (('integral',),)))
27-
assert_raises(TypeError, lambda: xp.isdtype(xp.float64, np.object_))
29+
with assert_raises(TypeError), warnings.catch_warnings(record=True) as w:
30+
warnings.simplefilter("always")
31+
xp.isdtype(xp.float64, np.object_)
32+
assert len(w) == 1
33+
assert issubclass(w[-1].category, UserWarning)
2834

2935
assert_raises(TypeError, lambda: xp.isdtype(xp.float64, None))
30-
assert_raises(TypeError, lambda: xp.isdtype(xp.float64, np.float64))
36+
with assert_raises(TypeError), warnings.catch_warnings(record=True) as w:
37+
warnings.simplefilter("always")
38+
xp.isdtype(xp.float64, np.float64)
39+
assert len(w) == 1
40+
assert issubclass(w[-1].category, UserWarning)

0 commit comments

Comments
 (0)