Skip to content

Commit 5a5415b

Browse files
committed
Rename arguments x, y of assertAllClose and friends to actual, expected.
1 parent ff751ec commit 5a5415b

File tree

1 file changed

+40
-41
lines changed

1 file changed

+40
-41
lines changed

jax/_src/test_util.py

Lines changed: 40 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1343,15 +1343,15 @@ def assertDeprecationWarnsOrRaises(self, deprecation_id: str, message: str):
13431343
else:
13441344
return self.assertWarnsRegex(DeprecationWarning, message)
13451345

1346-
def assertArraysEqual(self, x, y, *, check_dtypes=True, err_msg='',
1346+
def assertArraysEqual(self, actual, desired, *, check_dtypes=True, err_msg='',
13471347
allow_object_dtype=False, verbose=True):
13481348
"""Assert that x and y arrays are exactly equal."""
13491349
if check_dtypes:
1350-
self.assertDtypesMatch(x, y)
1351-
x = np.asarray(x)
1352-
y = np.asarray(y)
1350+
self.assertDtypesMatch(actual, desired)
1351+
actual = np.asarray(actual)
1352+
desired = np.asarray(desired)
13531353

1354-
if (not allow_object_dtype) and (x.dtype == object or y.dtype == object):
1354+
if (not allow_object_dtype) and (actual.dtype == object or desired.dtype == object):
13551355
# See https://github.com/jax-ml/jax/issues/17867
13561356
raise TypeError(
13571357
"assertArraysEqual may be poorly behaved when np.asarray casts to dtype=object. "
@@ -1361,57 +1361,57 @@ def assertArraysEqual(self, x, y, *, check_dtypes=True, err_msg='',
13611361

13621362
# Work around https://github.com/numpy/numpy/issues/18992
13631363
with np.errstate(over='ignore'):
1364-
np.testing.assert_array_equal(x, y, err_msg=err_msg,
1364+
np.testing.assert_array_equal(actual, desired, err_msg=err_msg,
13651365
verbose=verbose)
13661366

1367-
def assertArraysAllClose(self, x, y, *, check_dtypes=True, atol=None,
1367+
def assertArraysAllClose(self, actual, desired, *, check_dtypes=True, atol=None,
13681368
rtol=None, err_msg=''):
1369-
"""Assert that x and y are close (up to numerical tolerances)."""
1370-
self.assertEqual(x.shape, y.shape)
1371-
atol = max(tolerance(_dtype(x), atol), tolerance(_dtype(y), atol))
1372-
rtol = max(tolerance(_dtype(x), rtol), tolerance(_dtype(y), rtol))
1369+
"""Assert that actual and desired are close (up to numerical tolerances)."""
1370+
self.assertEqual(actual.shape, desired.shape)
1371+
atol = max(tolerance(_dtype(actual), atol), tolerance(_dtype(desired), atol))
1372+
rtol = max(tolerance(_dtype(actual), rtol), tolerance(_dtype(desired), rtol))
13731373

1374-
_assert_numpy_allclose(x, y, atol=atol, rtol=rtol, err_msg=err_msg)
1374+
_assert_numpy_allclose(actual, desired, atol=atol, rtol=rtol, err_msg=err_msg)
13751375

13761376
if check_dtypes:
1377-
self.assertDtypesMatch(x, y)
1377+
self.assertDtypesMatch(actual, desired)
13781378

1379-
def assertDtypesMatch(self, x, y, *, canonicalize_dtypes=True):
1379+
def assertDtypesMatch(self, actual, desired, *, canonicalize_dtypes=True):
13801380
if not config.enable_x64.value and canonicalize_dtypes:
1381-
self.assertEqual(_dtypes.canonicalize_dtype(_dtype(x), allow_extended_dtype=True),
1382-
_dtypes.canonicalize_dtype(_dtype(y), allow_extended_dtype=True))
1381+
self.assertEqual(_dtypes.canonicalize_dtype(_dtype(actual), allow_extended_dtype=True),
1382+
_dtypes.canonicalize_dtype(_dtype(desired), allow_extended_dtype=True))
13831383
else:
1384-
self.assertEqual(_dtype(x), _dtype(y))
1384+
self.assertEqual(_dtype(actual), _dtype(desired))
13851385

1386-
def assertAllClose(self, x, y, *, check_dtypes=True, atol=None, rtol=None,
1386+
def assertAllClose(self, actual, desired, *, check_dtypes=True, atol=None, rtol=None,
13871387
canonicalize_dtypes=True, err_msg=''):
1388-
"""Assert that x and y, either arrays or nested tuples/lists, are close."""
1389-
if isinstance(x, dict):
1390-
self.assertIsInstance(y, dict)
1391-
self.assertEqual(set(x.keys()), set(y.keys()))
1392-
for k in x.keys():
1393-
self.assertAllClose(x[k], y[k], check_dtypes=check_dtypes, atol=atol,
1388+
"""Assert that actual and desired, either arrays or nested tuples/lists, are close."""
1389+
if isinstance(actual, dict):
1390+
self.assertIsInstance(desired, dict)
1391+
self.assertEqual(set(actual.keys()), set(desired.keys()))
1392+
for k in actual.keys():
1393+
self.assertAllClose(actual[k], desired[k], check_dtypes=check_dtypes, atol=atol,
13941394
rtol=rtol, canonicalize_dtypes=canonicalize_dtypes,
13951395
err_msg=err_msg)
1396-
elif is_sequence(x) and not hasattr(x, '__array__'):
1397-
self.assertTrue(is_sequence(y) and not hasattr(y, '__array__'))
1398-
self.assertEqual(len(x), len(y))
1399-
for x_elt, y_elt in zip(x, y):
1400-
self.assertAllClose(x_elt, y_elt, check_dtypes=check_dtypes, atol=atol,
1396+
elif is_sequence(actual) and not hasattr(actual, '__array__'):
1397+
self.assertTrue(is_sequence(desired) and not hasattr(desired, '__array__'))
1398+
self.assertEqual(len(actual), len(desired))
1399+
for actual_elt, desired_elt in zip(actual, desired):
1400+
self.assertAllClose(actual_elt, desired_elt, check_dtypes=check_dtypes, atol=atol,
14011401
rtol=rtol, canonicalize_dtypes=canonicalize_dtypes,
14021402
err_msg=err_msg)
1403-
elif hasattr(x, '__array__') or np.isscalar(x):
1404-
self.assertTrue(hasattr(y, '__array__') or np.isscalar(y))
1403+
elif hasattr(actual, '__array__') or np.isscalar(actual):
1404+
self.assertTrue(hasattr(desired, '__array__') or np.isscalar(desired))
14051405
if check_dtypes:
1406-
self.assertDtypesMatch(x, y, canonicalize_dtypes=canonicalize_dtypes)
1407-
x = np.asarray(x)
1408-
y = np.asarray(y)
1409-
self.assertArraysAllClose(x, y, check_dtypes=False, atol=atol, rtol=rtol,
1406+
self.assertDtypesMatch(actual, desired, canonicalize_dtypes=canonicalize_dtypes)
1407+
actual = np.asarray(actual)
1408+
desired = np.asarray(desired)
1409+
self.assertArraysAllClose(actual, desired, check_dtypes=False, atol=atol, rtol=rtol,
14101410
err_msg=err_msg)
1411-
elif x == y:
1411+
elif actual == desired:
14121412
return
14131413
else:
1414-
raise TypeError((type(x), type(y)))
1414+
raise TypeError((type(actual), type(desired)))
14151415

14161416
def assertMultiLineStrippedEqual(self, expected, what):
14171417
"""Asserts two strings are equal, after dedenting and stripping each line."""
@@ -1426,7 +1426,6 @@ def assertMultiLineStrippedEqual(self, expected, what):
14261426
self.assertMultiLineEqual(expected_clean, what_clean,
14271427
msg=f"Found\n{what}\nExpecting\n{expected}")
14281428

1429-
14301429
@contextmanager
14311430
def assertNoWarnings(self):
14321431
with test_warning_util.raise_on_warnings():
@@ -1496,9 +1495,9 @@ def wrapped_fun(*args):
14961495
python_should_be_executing = False
14971496
compiled_ans = cfun(*args)
14981497

1499-
self.assertAllClose(python_ans, monitored_ans, check_dtypes=check_dtypes,
1498+
self.assertAllClose(monitored_ans, python_ans, check_dtypes=check_dtypes,
15001499
atol=atol or tol, rtol=rtol or tol)
1501-
self.assertAllClose(python_ans, compiled_ans, check_dtypes=check_dtypes,
1500+
self.assertAllClose(compiled_ans, python_ans, check_dtypes=check_dtypes,
15021501
atol=atol or tol, rtol=rtol or tol)
15031502

15041503
args = args_maker()
@@ -1509,7 +1508,7 @@ def wrapped_fun(*args):
15091508
python_should_be_executing = False
15101509
compiled_ans = cfun(*args)
15111510

1512-
self.assertAllClose(python_ans, compiled_ans, check_dtypes=check_dtypes,
1511+
self.assertAllClose(compiled_ans, python_ans, check_dtypes=check_dtypes,
15131512
atol=atol or tol, rtol=rtol or tol)
15141513

15151514
def _CheckAgainstNumpy(self, numpy_reference_op, lax_op, args_maker,

0 commit comments

Comments
 (0)