@@ -1348,15 +1348,15 @@ def assertDeprecationWarnsOrRaises(self, deprecation_id: str, message: str):
13481348 else :
13491349 return self .assertWarnsRegex (DeprecationWarning , message )
13501350
1351- def assertArraysEqual (self , x , y , * , check_dtypes = True , err_msg = '' ,
1351+ def assertArraysEqual (self , actual , desired , * , check_dtypes = True , err_msg = '' ,
13521352 allow_object_dtype = False , verbose = True ):
13531353 """Assert that x and y arrays are exactly equal."""
13541354 if check_dtypes :
1355- self .assertDtypesMatch (x , y )
1356- x = np .asarray (x )
1357- y = np .asarray (y )
1355+ self .assertDtypesMatch (actual , desired )
1356+ actual = np .asarray (actual )
1357+ desired = np .asarray (desired )
13581358
1359- if (not allow_object_dtype ) and (x .dtype == object or y .dtype == object ):
1359+ if (not allow_object_dtype ) and (actual .dtype == object or desired .dtype == object ):
13601360 # See https://github.com/jax-ml/jax/issues/17867
13611361 raise TypeError (
13621362 "assertArraysEqual may be poorly behaved when np.asarray casts to dtype=object. "
@@ -1366,57 +1366,57 @@ def assertArraysEqual(self, x, y, *, check_dtypes=True, err_msg='',
13661366
13671367 # Work around https://github.com/numpy/numpy/issues/18992
13681368 with np .errstate (over = 'ignore' ):
1369- np .testing .assert_array_equal (x , y , err_msg = err_msg ,
1369+ np .testing .assert_array_equal (actual , desired , err_msg = err_msg ,
13701370 verbose = verbose )
13711371
1372- def assertArraysAllClose (self , x , y , * , check_dtypes = True , atol = None ,
1372+ def assertArraysAllClose (self , actual , desired , * , check_dtypes = True , atol = None ,
13731373 rtol = None , err_msg = '' ):
1374- """Assert that x and y are close (up to numerical tolerances)."""
1375- self .assertEqual (x .shape , y .shape )
1376- atol = max (tolerance (_dtype (x ), atol ), tolerance (_dtype (y ), atol ))
1377- rtol = max (tolerance (_dtype (x ), rtol ), tolerance (_dtype (y ), rtol ))
1374+ """Assert that actual and desired are close (up to numerical tolerances)."""
1375+ self .assertEqual (actual .shape , desired .shape )
1376+ atol = max (tolerance (_dtype (actual ), atol ), tolerance (_dtype (desired ), atol ))
1377+ rtol = max (tolerance (_dtype (actual ), rtol ), tolerance (_dtype (desired ), rtol ))
13781378
1379- _assert_numpy_allclose (x , y , atol = atol , rtol = rtol , err_msg = err_msg )
1379+ _assert_numpy_allclose (actual , desired , atol = atol , rtol = rtol , err_msg = err_msg )
13801380
13811381 if check_dtypes :
1382- self .assertDtypesMatch (x , y )
1382+ self .assertDtypesMatch (actual , desired )
13831383
1384- def assertDtypesMatch (self , x , y , * , canonicalize_dtypes = True ):
1384+ def assertDtypesMatch (self , actual , desired , * , canonicalize_dtypes = True ):
13851385 if not config .enable_x64 .value and canonicalize_dtypes :
1386- self .assertEqual (_dtypes .canonicalize_dtype (_dtype (x ), allow_extended_dtype = True ),
1387- _dtypes .canonicalize_dtype (_dtype (y ), allow_extended_dtype = True ))
1386+ self .assertEqual (_dtypes .canonicalize_dtype (_dtype (actual ), allow_extended_dtype = True ),
1387+ _dtypes .canonicalize_dtype (_dtype (desired ), allow_extended_dtype = True ))
13881388 else :
1389- self .assertEqual (_dtype (x ), _dtype (y ))
1389+ self .assertEqual (_dtype (actual ), _dtype (desired ))
13901390
1391- def assertAllClose (self , x , y , * , check_dtypes = True , atol = None , rtol = None ,
1391+ def assertAllClose (self , actual , desired , * , check_dtypes = True , atol = None , rtol = None ,
13921392 canonicalize_dtypes = True , err_msg = '' ):
1393- """Assert that x and y , either arrays or nested tuples/lists, are close."""
1394- if isinstance (x , dict ):
1395- self .assertIsInstance (y , dict )
1396- self .assertEqual (set (x .keys ()), set (y .keys ()))
1397- for k in x .keys ():
1398- self .assertAllClose (x [k ], y [k ], check_dtypes = check_dtypes , atol = atol ,
1393+ """Assert that actual and desired , either arrays or nested tuples/lists, are close."""
1394+ if isinstance (actual , dict ):
1395+ self .assertIsInstance (desired , dict )
1396+ self .assertEqual (set (actual .keys ()), set (desired .keys ()))
1397+ for k in actual .keys ():
1398+ self .assertAllClose (actual [k ], desired [k ], check_dtypes = check_dtypes , atol = atol ,
13991399 rtol = rtol , canonicalize_dtypes = canonicalize_dtypes ,
14001400 err_msg = err_msg )
1401- elif is_sequence (x ) and not hasattr (x , '__array__' ):
1402- self .assertTrue (is_sequence (y ) and not hasattr (y , '__array__' ))
1403- self .assertEqual (len (x ), len (y ))
1404- for x_elt , y_elt in zip (x , y ):
1405- self .assertAllClose (x_elt , y_elt , check_dtypes = check_dtypes , atol = atol ,
1401+ elif is_sequence (actual ) and not hasattr (actual , '__array__' ):
1402+ self .assertTrue (is_sequence (desired ) and not hasattr (desired , '__array__' ))
1403+ self .assertEqual (len (actual ), len (desired ))
1404+ for actual_elt , desired_elt in zip (actual , desired ):
1405+ self .assertAllClose (actual_elt , desired_elt , check_dtypes = check_dtypes , atol = atol ,
14061406 rtol = rtol , canonicalize_dtypes = canonicalize_dtypes ,
14071407 err_msg = err_msg )
1408- elif hasattr (x , '__array__' ) or np .isscalar (x ):
1409- self .assertTrue (hasattr (y , '__array__' ) or np .isscalar (y ))
1408+ elif hasattr (actual , '__array__' ) or np .isscalar (actual ):
1409+ self .assertTrue (hasattr (desired , '__array__' ) or np .isscalar (desired ))
14101410 if check_dtypes :
1411- self .assertDtypesMatch (x , y , canonicalize_dtypes = canonicalize_dtypes )
1412- x = np .asarray (x )
1413- y = np .asarray (y )
1414- self .assertArraysAllClose (x , y , check_dtypes = False , atol = atol , rtol = rtol ,
1411+ self .assertDtypesMatch (actual , desired , canonicalize_dtypes = canonicalize_dtypes )
1412+ actual = np .asarray (actual )
1413+ desired = np .asarray (desired )
1414+ self .assertArraysAllClose (actual , desired , check_dtypes = False , atol = atol , rtol = rtol ,
14151415 err_msg = err_msg )
1416- elif x == y :
1416+ elif actual == desired :
14171417 return
14181418 else :
1419- raise TypeError ((type (x ), type (y )))
1419+ raise TypeError ((type (actual ), type (desired )))
14201420
14211421 def assertMultiLineStrippedEqual (self , expected , what ):
14221422 """Asserts two strings are equal, after dedenting and stripping each line."""
@@ -1431,7 +1431,6 @@ def assertMultiLineStrippedEqual(self, expected, what):
14311431 self .assertMultiLineEqual (expected_clean , what_clean ,
14321432 msg = f"Found\n { what } \n Expecting\n { expected } " )
14331433
1434-
14351434 @contextmanager
14361435 def assertNoWarnings (self ):
14371436 with test_warning_util .raise_on_warnings ():
@@ -1501,9 +1500,9 @@ def wrapped_fun(*args):
15011500 python_should_be_executing = False
15021501 compiled_ans = cfun (* args )
15031502
1504- self .assertAllClose (python_ans , monitored_ans , check_dtypes = check_dtypes ,
1503+ self .assertAllClose (monitored_ans , python_ans , check_dtypes = check_dtypes ,
15051504 atol = atol or tol , rtol = rtol or tol )
1506- self .assertAllClose (python_ans , compiled_ans , check_dtypes = check_dtypes ,
1505+ self .assertAllClose (compiled_ans , python_ans , check_dtypes = check_dtypes ,
15071506 atol = atol or tol , rtol = rtol or tol )
15081507
15091508 args = args_maker ()
@@ -1514,7 +1513,7 @@ def wrapped_fun(*args):
15141513 python_should_be_executing = False
15151514 compiled_ans = cfun (* args )
15161515
1517- self .assertAllClose (python_ans , compiled_ans , check_dtypes = check_dtypes ,
1516+ self .assertAllClose (compiled_ans , python_ans , check_dtypes = check_dtypes ,
15181517 atol = atol or tol , rtol = rtol or tol )
15191518
15201519 def _CheckAgainstNumpy (self , numpy_reference_op , lax_op , args_maker ,
0 commit comments