66import pytest
77
88from array_api_extra import (
9- allclose ,
109 at ,
1110 atleast_nd ,
1211 cov ,
@@ -291,7 +290,6 @@ def test_basic(self, a: float, b: float, xp: ModuleType):
291290 b_xp = xp .asarray (b )
292291
293292 xp_assert_equal (isclose (a_xp , b_xp ), xp .asarray (np .isclose (a , b )))
294- xp_assert_equal (allclose (a_xp , b_xp ), xp .asarray (np .allclose (a , b )))
295293
296294 with warnings .catch_warnings ():
297295 warnings .simplefilter ("ignore" )
@@ -328,8 +326,6 @@ def test_equal_nan(self, xp: ModuleType):
328326 b = xp .asarray ([float ("nan" ), 1.0 , float ("nan" )])
329327 xp_assert_equal (isclose (a , b ), xp .asarray ([False , False , False ]))
330328 xp_assert_equal (isclose (a , b , equal_nan = True ), xp .asarray ([True , False , False ]))
331- xp_assert_equal (allclose (a [:1 ], b [:1 ]), xp .asarray (False ))
332- xp_assert_equal (allclose (a [:1 ], b [:1 ], equal_nan = True ), xp .asarray (True ))
333329
334330 @pytest .mark .parametrize ("dtype" , ["float32" , "complex64" , "int32" ])
335331 def test_tolerance (self , dtype : str , xp : ModuleType ):
@@ -339,15 +335,10 @@ def test_tolerance(self, dtype: str, xp: ModuleType):
339335 xp_assert_equal (isclose (a , b ), xp .asarray ([False , False ]))
340336 xp_assert_equal (isclose (a , b , atol = 1 ), xp .asarray ([True , False ]))
341337 xp_assert_equal (isclose (a , b , rtol = 0.01 ), xp .asarray ([True , False ]))
342- xp_assert_equal (allclose (a [:1 ], b [:1 ]), xp .asarray (False ))
343- xp_assert_equal (allclose (a [:1 ], b [:1 ], atol = 1 ), xp .asarray (True ))
344- xp_assert_equal (allclose (a [:1 ], b [:1 ], rtol = 0.01 ), xp .asarray (True ))
345338
346339 # Attempt to trigger division by 0 in rtol on int dtype
347340 xp_assert_equal (isclose (a , b , rtol = 0 ), xp .asarray ([False , False ]))
348341 xp_assert_equal (isclose (a , b , atol = 1 , rtol = 0 ), xp .asarray ([True , False ]))
349- xp_assert_equal (allclose (a [:1 ], b [:1 ], rtol = 0 ), xp .asarray (False ))
350- xp_assert_equal (allclose (a [:1 ], b [:1 ], atol = 1 , rtol = 0 ), xp .asarray (True ))
351342
352343 def test_very_small_numbers (self , xp : ModuleType ):
353344 a = xp .asarray ([1e-9 , 1e-9 ])
@@ -367,12 +358,6 @@ def test_bool_dtype(self, xp: ModuleType):
367358 xp_assert_equal (isclose (a , b , rtol = 1 ), xp .asarray ([True , True , True ]))
368359 xp_assert_equal (isclose (a , b , rtol = 2 ), xp .asarray ([True , True , True ]))
369360
370- xp_assert_equal (allclose (a , b ), xp .asarray (False ))
371- xp_assert_equal (allclose (a , b , atol = 1 ), xp .asarray (True ))
372- xp_assert_equal (allclose (a , b , atol = 2 ), xp .asarray (True ))
373- xp_assert_equal (allclose (a , b , rtol = 1 ), xp .asarray (True ))
374- xp_assert_equal (allclose (a , b , rtol = 2 ), xp .asarray (True ))
375-
376361 # Test broadcasting
377362 xp_assert_equal (
378363 isclose (a , xp .asarray (True ), atol = 1 ), xp .asarray ([True , True , True ])
0 commit comments