1010# mypy: disable-error-code=no-any-decorated
1111# pyright: reportUnknownParameterType=false,reportMissingParameterType=false
1212
13-
14- @pytest .mark .parametrize (
13+ param_assert_equal_close = pytest .mark .parametrize (
1514 "func" ,
1615 [
1716 xp_assert_equal ,
2120 ),
2221 ],
2322)
23+
24+
25+ @param_assert_equal_close
2426def test_assert_close_equal_basic (xp : ModuleType , func : Callable [..., None ]): # type: ignore[no-any-explicit]
2527 func (xp .asarray (0 ), xp .asarray (0 ))
2628 func (xp .asarray ([1 , 2 ]), xp .asarray ([1 , 2 ]))
@@ -40,16 +42,7 @@ def test_assert_close_equal_basic(xp: ModuleType, func: Callable[..., None]): #
4042
4143@pytest .mark .skip_xp_backend (Backend .NUMPY , reason = "test other ns vs. numpy" )
4244@pytest .mark .skip_xp_backend (Backend .NUMPY_READONLY , reason = "test other ns vs. numpy" )
43- @pytest .mark .parametrize (
44- "func" ,
45- [
46- xp_assert_equal ,
47- pytest .param (
48- xp_assert_close ,
49- marks = pytest .mark .skip_xp_backend (Backend .SPARSE , reason = "no isdtype" ),
50- ),
51- ],
52- )
45+ @param_assert_equal_close
5346def test_assert_close_equal_namespace (xp : ModuleType , func : Callable [..., None ]): # type: ignore[no-any-explicit]
5447 with pytest .raises (AssertionError ):
5548 func (xp .asarray (0 ), np .asarray (0 ))
@@ -68,3 +61,30 @@ def test_assert_close_tolerance(xp: ModuleType):
6861 xp_assert_close (xp .asarray ([100.0 ]), xp .asarray ([102.0 ]), atol = 3 )
6962 with pytest .raises (AssertionError ):
7063 xp_assert_close (xp .asarray ([100.0 ]), xp .asarray ([102.0 ]), atol = 1 )
64+
65+
66+ @param_assert_equal_close
67+ @pytest .mark .skip_xp_backend (Backend .SPARSE , reason = "no bool indexing by sparse arrays" )
68+ def test_assert_close_equal_none_shape (xp : ModuleType , func : Callable [..., None ]): # type: ignore[no-any-explicit]
69+ """On dask and other lazy backends, test that a shape with NaN's or None's
70+ can be compared to a real shape.
71+ """
72+ a = xp .asarray ([1 , 2 ])
73+ a = a [a > 1 ]
74+
75+ func (a , xp .asarray ([2 ]))
76+ with pytest .raises (AssertionError ):
77+ func (a , xp .asarray ([2 , 3 ]))
78+ with pytest .raises (AssertionError ):
79+ func (a , xp .asarray (2 ))
80+ with pytest .raises (AssertionError ):
81+ func (a , xp .asarray ([3 ]))
82+
83+ # Swap actual and desired
84+ func (xp .asarray ([2 ]), a )
85+ with pytest .raises (AssertionError ):
86+ func (xp .asarray ([2 , 3 ]), a )
87+ with pytest .raises (AssertionError ):
88+ func (xp .asarray (2 ), a )
89+ with pytest .raises (AssertionError ):
90+ func (xp .asarray ([3 ]), a )
0 commit comments