@@ -1343,15 +1343,15 @@ def assertDeprecationWarnsOrRaises(self, deprecation_id: str, message: str):
13431343 else :
13441344 return self .assertWarnsRegex (DeprecationWarning , message )
13451345
1346- def assertArraysEqual (self , x , y , * , check_dtypes = True , err_msg = '' ,
1346+ def assertArraysEqual (self , actual , desired , * , check_dtypes = True , err_msg = '' ,
13471347 allow_object_dtype = False , verbose = True ):
13481348 """Assert that x and y arrays are exactly equal."""
13491349 if check_dtypes :
1350- self .assertDtypesMatch (x , y )
1351- x = np .asarray (x )
1352- y = np .asarray (y )
1350+ self .assertDtypesMatch (actual , desired )
1351+ actual = np .asarray (actual )
1352+ desired = np .asarray (desired )
13531353
1354- if (not allow_object_dtype ) and (x .dtype == object or y .dtype == object ):
1354+ if (not allow_object_dtype ) and (actual .dtype == object or desired .dtype == object ):
13551355 # See https://github.com/jax-ml/jax/issues/17867
13561356 raise TypeError (
13571357 "assertArraysEqual may be poorly behaved when np.asarray casts to dtype=object. "
@@ -1361,57 +1361,57 @@ def assertArraysEqual(self, x, y, *, check_dtypes=True, err_msg='',
13611361
13621362 # Work around https://github.com/numpy/numpy/issues/18992
13631363 with np .errstate (over = 'ignore' ):
1364- np .testing .assert_array_equal (x , y , err_msg = err_msg ,
1364+ np .testing .assert_array_equal (actual , desired , err_msg = err_msg ,
13651365 verbose = verbose )
13661366
1367- def assertArraysAllClose (self , x , y , * , check_dtypes = True , atol = None ,
1367+ def assertArraysAllClose (self , actual , desired , * , check_dtypes = True , atol = None ,
13681368 rtol = None , err_msg = '' ):
1369- """Assert that x and y are close (up to numerical tolerances)."""
1370- self .assertEqual (x .shape , y .shape )
1371- atol = max (tolerance (_dtype (x ), atol ), tolerance (_dtype (y ), atol ))
1372- rtol = max (tolerance (_dtype (x ), rtol ), tolerance (_dtype (y ), rtol ))
1369+ """Assert that actual and desired are close (up to numerical tolerances)."""
1370+ self .assertEqual (actual .shape , desired .shape )
1371+ atol = max (tolerance (_dtype (actual ), atol ), tolerance (_dtype (desired ), atol ))
1372+ rtol = max (tolerance (_dtype (actual ), rtol ), tolerance (_dtype (desired ), rtol ))
13731373
1374- _assert_numpy_allclose (x , y , atol = atol , rtol = rtol , err_msg = err_msg )
1374+ _assert_numpy_allclose (actual , desired , atol = atol , rtol = rtol , err_msg = err_msg )
13751375
13761376 if check_dtypes :
1377- self .assertDtypesMatch (x , y )
1377+ self .assertDtypesMatch (actual , desired )
13781378
1379- def assertDtypesMatch (self , x , y , * , canonicalize_dtypes = True ):
1379+ def assertDtypesMatch (self , actual , desired , * , canonicalize_dtypes = True ):
13801380 if not config .enable_x64 .value and canonicalize_dtypes :
1381- self .assertEqual (_dtypes .canonicalize_dtype (_dtype (x ), allow_extended_dtype = True ),
1382- _dtypes .canonicalize_dtype (_dtype (y ), allow_extended_dtype = True ))
1381+ self .assertEqual (_dtypes .canonicalize_dtype (_dtype (actual ), allow_extended_dtype = True ),
1382+ _dtypes .canonicalize_dtype (_dtype (desired ), allow_extended_dtype = True ))
13831383 else :
1384- self .assertEqual (_dtype (x ), _dtype (y ))
1384+ self .assertEqual (_dtype (actual ), _dtype (desired ))
13851385
1386- def assertAllClose (self , x , y , * , check_dtypes = True , atol = None , rtol = None ,
1386+ def assertAllClose (self , actual , desired , * , check_dtypes = True , atol = None , rtol = None ,
13871387 canonicalize_dtypes = True , err_msg = '' ):
1388- """Assert that x and y , either arrays or nested tuples/lists, are close."""
1389- if isinstance (x , dict ):
1390- self .assertIsInstance (y , dict )
1391- self .assertEqual (set (x .keys ()), set (y .keys ()))
1392- for k in x .keys ():
1393- self .assertAllClose (x [k ], y [k ], check_dtypes = check_dtypes , atol = atol ,
1388+ """Assert that actual and desired , either arrays or nested tuples/lists, are close."""
1389+ if isinstance (actual , dict ):
1390+ self .assertIsInstance (desired , dict )
1391+ self .assertEqual (set (actual .keys ()), set (desired .keys ()))
1392+ for k in actual .keys ():
1393+ self .assertAllClose (actual [k ], desired [k ], check_dtypes = check_dtypes , atol = atol ,
13941394 rtol = rtol , canonicalize_dtypes = canonicalize_dtypes ,
13951395 err_msg = err_msg )
1396- elif is_sequence (x ) and not hasattr (x , '__array__' ):
1397- self .assertTrue (is_sequence (y ) and not hasattr (y , '__array__' ))
1398- self .assertEqual (len (x ), len (y ))
1399- for x_elt , y_elt in zip (x , y ):
1400- self .assertAllClose (x_elt , y_elt , check_dtypes = check_dtypes , atol = atol ,
1396+ elif is_sequence (actual ) and not hasattr (actual , '__array__' ):
1397+ self .assertTrue (is_sequence (desired ) and not hasattr (desired , '__array__' ))
1398+ self .assertEqual (len (actual ), len (desired ))
1399+ for actual_elt , desired_elt in zip (actual , desired ):
1400+ self .assertAllClose (actual_elt , desired_elt , check_dtypes = check_dtypes , atol = atol ,
14011401 rtol = rtol , canonicalize_dtypes = canonicalize_dtypes ,
14021402 err_msg = err_msg )
1403- elif hasattr (x , '__array__' ) or np .isscalar (x ):
1404- self .assertTrue (hasattr (y , '__array__' ) or np .isscalar (y ))
1403+ elif hasattr (actual , '__array__' ) or np .isscalar (actual ):
1404+ self .assertTrue (hasattr (desired , '__array__' ) or np .isscalar (desired ))
14051405 if check_dtypes :
1406- self .assertDtypesMatch (x , y , canonicalize_dtypes = canonicalize_dtypes )
1407- x = np .asarray (x )
1408- y = np .asarray (y )
1409- self .assertArraysAllClose (x , y , check_dtypes = False , atol = atol , rtol = rtol ,
1406+ self .assertDtypesMatch (actual , desired , canonicalize_dtypes = canonicalize_dtypes )
1407+ actual = np .asarray (actual )
1408+ desired = np .asarray (desired )
1409+ self .assertArraysAllClose (actual , desired , check_dtypes = False , atol = atol , rtol = rtol ,
14101410 err_msg = err_msg )
1411- elif x == y :
1411+ elif actual == desired :
14121412 return
14131413 else :
1414- raise TypeError ((type (x ), type (y )))
1414+ raise TypeError ((type (actual ), type (desired )))
14151415
14161416 def assertMultiLineStrippedEqual (self , expected , what ):
14171417 """Asserts two strings are equal, after dedenting and stripping each line."""
@@ -1426,7 +1426,6 @@ def assertMultiLineStrippedEqual(self, expected, what):
14261426 self .assertMultiLineEqual (expected_clean , what_clean ,
14271427 msg = f"Found\n { what } \n Expecting\n { expected } " )
14281428
1429-
14301429 @contextmanager
14311430 def assertNoWarnings (self ):
14321431 with test_warning_util .raise_on_warnings ():
@@ -1496,9 +1495,9 @@ def wrapped_fun(*args):
14961495 python_should_be_executing = False
14971496 compiled_ans = cfun (* args )
14981497
1499- self .assertAllClose (python_ans , monitored_ans , check_dtypes = check_dtypes ,
1498+ self .assertAllClose (monitored_ans , python_ans , check_dtypes = check_dtypes ,
15001499 atol = atol or tol , rtol = rtol or tol )
1501- self .assertAllClose (python_ans , compiled_ans , check_dtypes = check_dtypes ,
1500+ self .assertAllClose (compiled_ans , python_ans , check_dtypes = check_dtypes ,
15021501 atol = atol or tol , rtol = rtol or tol )
15031502
15041503 args = args_maker ()
@@ -1509,7 +1508,7 @@ def wrapped_fun(*args):
15091508 python_should_be_executing = False
15101509 compiled_ans = cfun (* args )
15111510
1512- self .assertAllClose (python_ans , compiled_ans , check_dtypes = check_dtypes ,
1511+ self .assertAllClose (compiled_ans , python_ans , check_dtypes = check_dtypes ,
15131512 atol = atol or tol , rtol = rtol or tol )
15141513
15151514 def _CheckAgainstNumpy (self , numpy_reference_op , lax_op , args_maker ,
0 commit comments