|
48 | 48 | lazy_xp_function(sinc, static_argnames="xp") |
49 | 49 |
|
50 | 50 |
|
51 | | -NUMPY_GE2 = int(np.__version__.split(".")[0]) >= 2 |
| 51 | +NUMPY_VERSION = tuple(int(v) for v in np.__version__.split(".")[2]) |
52 | 52 |
|
53 | 53 |
|
54 | 54 | class TestApplyWhere: |
@@ -224,7 +224,7 @@ def test_hypothesis( # type: ignore[explicit-any,decorated-any] |
224 | 224 | ): |
225 | 225 | if ( |
226 | 226 | library in (Backend.NUMPY, Backend.NUMPY_READONLY) |
227 | | - and not NUMPY_GE2 |
| 227 | + and NUMPY_VERSION < (2, 0) |
228 | 228 | and dtype is np.float32 |
229 | 229 | ): |
230 | 230 | pytest.xfail(reason="NumPy 1.x dtype promotion for scalars") |
@@ -842,7 +842,11 @@ def test_all_equal(self, xp: ModuleType): |
842 | 842 | @pytest.mark.xfail_xp_backend( |
843 | 843 | Backend.SPARSE, reason="Non-compliant equal_nan=True behaviour" |
844 | 844 | ) |
845 | | - def test_nan(self, xp: ModuleType): |
| 845 | + def test_nan(self, xp: ModuleType, library: Backend): |
| 846 | + is_numpy = library in (Backend.NUMPY, Backend.NUMPY_READONLY) |
| 847 | + if is_numpy and NUMPY_VERSION < (1, 24): |
| 848 | + pytest.xfail("NumPy <1.24 has no equal_nan kwarg in unique") |
| 849 | + |
846 | 850 | # Each NaN is counted separately |
847 | 851 | a = xp.asarray([xp.nan, 123.0, xp.nan]) |
848 | 852 | xp_assert_equal(nunique(a), xp.asarray(3)) |
|
0 commit comments