Skip to content

Commit afdbcd0

Browse files
Merge pull request jax-ml#27255 from pearu:pearu/assertAllClose
PiperOrigin-RevId: 738496186
2 parents cf21f73 + 5a5415b commit afdbcd0

File tree

1 file changed

+40
-41
lines changed

1 file changed

+40
-41
lines changed

jax/_src/test_util.py

Lines changed: 40 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1348,15 +1348,15 @@ def assertDeprecationWarnsOrRaises(self, deprecation_id: str, message: str):
13481348
else:
13491349
return self.assertWarnsRegex(DeprecationWarning, message)
13501350

1351-
def assertArraysEqual(self, x, y, *, check_dtypes=True, err_msg='',
1351+
def assertArraysEqual(self, actual, desired, *, check_dtypes=True, err_msg='',
13521352
allow_object_dtype=False, verbose=True):
13531353
"""Assert that x and y arrays are exactly equal."""
13541354
if check_dtypes:
1355-
self.assertDtypesMatch(x, y)
1356-
x = np.asarray(x)
1357-
y = np.asarray(y)
1355+
self.assertDtypesMatch(actual, desired)
1356+
actual = np.asarray(actual)
1357+
desired = np.asarray(desired)
13581358

1359-
if (not allow_object_dtype) and (x.dtype == object or y.dtype == object):
1359+
if (not allow_object_dtype) and (actual.dtype == object or desired.dtype == object):
13601360
# See https://github.com/jax-ml/jax/issues/17867
13611361
raise TypeError(
13621362
"assertArraysEqual may be poorly behaved when np.asarray casts to dtype=object. "
@@ -1366,57 +1366,57 @@ def assertArraysEqual(self, x, y, *, check_dtypes=True, err_msg='',
13661366

13671367
# Work around https://github.com/numpy/numpy/issues/18992
13681368
with np.errstate(over='ignore'):
1369-
np.testing.assert_array_equal(x, y, err_msg=err_msg,
1369+
np.testing.assert_array_equal(actual, desired, err_msg=err_msg,
13701370
verbose=verbose)
13711371

1372-
def assertArraysAllClose(self, x, y, *, check_dtypes=True, atol=None,
1372+
def assertArraysAllClose(self, actual, desired, *, check_dtypes=True, atol=None,
13731373
rtol=None, err_msg=''):
1374-
"""Assert that x and y are close (up to numerical tolerances)."""
1375-
self.assertEqual(x.shape, y.shape)
1376-
atol = max(tolerance(_dtype(x), atol), tolerance(_dtype(y), atol))
1377-
rtol = max(tolerance(_dtype(x), rtol), tolerance(_dtype(y), rtol))
1374+
"""Assert that actual and desired are close (up to numerical tolerances)."""
1375+
self.assertEqual(actual.shape, desired.shape)
1376+
atol = max(tolerance(_dtype(actual), atol), tolerance(_dtype(desired), atol))
1377+
rtol = max(tolerance(_dtype(actual), rtol), tolerance(_dtype(desired), rtol))
13781378

1379-
_assert_numpy_allclose(x, y, atol=atol, rtol=rtol, err_msg=err_msg)
1379+
_assert_numpy_allclose(actual, desired, atol=atol, rtol=rtol, err_msg=err_msg)
13801380

13811381
if check_dtypes:
1382-
self.assertDtypesMatch(x, y)
1382+
self.assertDtypesMatch(actual, desired)
13831383

1384-
def assertDtypesMatch(self, x, y, *, canonicalize_dtypes=True):
1384+
def assertDtypesMatch(self, actual, desired, *, canonicalize_dtypes=True):
13851385
if not config.enable_x64.value and canonicalize_dtypes:
1386-
self.assertEqual(_dtypes.canonicalize_dtype(_dtype(x), allow_extended_dtype=True),
1387-
_dtypes.canonicalize_dtype(_dtype(y), allow_extended_dtype=True))
1386+
self.assertEqual(_dtypes.canonicalize_dtype(_dtype(actual), allow_extended_dtype=True),
1387+
_dtypes.canonicalize_dtype(_dtype(desired), allow_extended_dtype=True))
13881388
else:
1389-
self.assertEqual(_dtype(x), _dtype(y))
1389+
self.assertEqual(_dtype(actual), _dtype(desired))
13901390

1391-
def assertAllClose(self, x, y, *, check_dtypes=True, atol=None, rtol=None,
1391+
def assertAllClose(self, actual, desired, *, check_dtypes=True, atol=None, rtol=None,
13921392
canonicalize_dtypes=True, err_msg=''):
1393-
"""Assert that x and y, either arrays or nested tuples/lists, are close."""
1394-
if isinstance(x, dict):
1395-
self.assertIsInstance(y, dict)
1396-
self.assertEqual(set(x.keys()), set(y.keys()))
1397-
for k in x.keys():
1398-
self.assertAllClose(x[k], y[k], check_dtypes=check_dtypes, atol=atol,
1393+
"""Assert that actual and desired, either arrays or nested tuples/lists, are close."""
1394+
if isinstance(actual, dict):
1395+
self.assertIsInstance(desired, dict)
1396+
self.assertEqual(set(actual.keys()), set(desired.keys()))
1397+
for k in actual.keys():
1398+
self.assertAllClose(actual[k], desired[k], check_dtypes=check_dtypes, atol=atol,
13991399
rtol=rtol, canonicalize_dtypes=canonicalize_dtypes,
14001400
err_msg=err_msg)
1401-
elif is_sequence(x) and not hasattr(x, '__array__'):
1402-
self.assertTrue(is_sequence(y) and not hasattr(y, '__array__'))
1403-
self.assertEqual(len(x), len(y))
1404-
for x_elt, y_elt in zip(x, y):
1405-
self.assertAllClose(x_elt, y_elt, check_dtypes=check_dtypes, atol=atol,
1401+
elif is_sequence(actual) and not hasattr(actual, '__array__'):
1402+
self.assertTrue(is_sequence(desired) and not hasattr(desired, '__array__'))
1403+
self.assertEqual(len(actual), len(desired))
1404+
for actual_elt, desired_elt in zip(actual, desired):
1405+
self.assertAllClose(actual_elt, desired_elt, check_dtypes=check_dtypes, atol=atol,
14061406
rtol=rtol, canonicalize_dtypes=canonicalize_dtypes,
14071407
err_msg=err_msg)
1408-
elif hasattr(x, '__array__') or np.isscalar(x):
1409-
self.assertTrue(hasattr(y, '__array__') or np.isscalar(y))
1408+
elif hasattr(actual, '__array__') or np.isscalar(actual):
1409+
self.assertTrue(hasattr(desired, '__array__') or np.isscalar(desired))
14101410
if check_dtypes:
1411-
self.assertDtypesMatch(x, y, canonicalize_dtypes=canonicalize_dtypes)
1412-
x = np.asarray(x)
1413-
y = np.asarray(y)
1414-
self.assertArraysAllClose(x, y, check_dtypes=False, atol=atol, rtol=rtol,
1411+
self.assertDtypesMatch(actual, desired, canonicalize_dtypes=canonicalize_dtypes)
1412+
actual = np.asarray(actual)
1413+
desired = np.asarray(desired)
1414+
self.assertArraysAllClose(actual, desired, check_dtypes=False, atol=atol, rtol=rtol,
14151415
err_msg=err_msg)
1416-
elif x == y:
1416+
elif actual == desired:
14171417
return
14181418
else:
1419-
raise TypeError((type(x), type(y)))
1419+
raise TypeError((type(actual), type(desired)))
14201420

14211421
def assertMultiLineStrippedEqual(self, expected, what):
14221422
"""Asserts two strings are equal, after dedenting and stripping each line."""
@@ -1431,7 +1431,6 @@ def assertMultiLineStrippedEqual(self, expected, what):
14311431
self.assertMultiLineEqual(expected_clean, what_clean,
14321432
msg=f"Found\n{what}\nExpecting\n{expected}")
14331433

1434-
14351434
@contextmanager
14361435
def assertNoWarnings(self):
14371436
with test_warning_util.raise_on_warnings():
@@ -1501,9 +1500,9 @@ def wrapped_fun(*args):
15011500
python_should_be_executing = False
15021501
compiled_ans = cfun(*args)
15031502

1504-
self.assertAllClose(python_ans, monitored_ans, check_dtypes=check_dtypes,
1503+
self.assertAllClose(monitored_ans, python_ans, check_dtypes=check_dtypes,
15051504
atol=atol or tol, rtol=rtol or tol)
1506-
self.assertAllClose(python_ans, compiled_ans, check_dtypes=check_dtypes,
1505+
self.assertAllClose(compiled_ans, python_ans, check_dtypes=check_dtypes,
15071506
atol=atol or tol, rtol=rtol or tol)
15081507

15091508
args = args_maker()
@@ -1514,7 +1513,7 @@ def wrapped_fun(*args):
15141513
python_should_be_executing = False
15151514
compiled_ans = cfun(*args)
15161515

1517-
self.assertAllClose(python_ans, compiled_ans, check_dtypes=check_dtypes,
1516+
self.assertAllClose(compiled_ans, python_ans, check_dtypes=check_dtypes,
15181517
atol=atol or tol, rtol=rtol or tol)
15191518

15201519
def _CheckAgainstNumpy(self, numpy_reference_op, lax_op, args_maker,

0 commit comments

Comments
 (0)