3535import time
3636from typing import Any , TextIO
3737import unittest
38- import warnings
3938import zlib
4039
4140from absl .testing import absltest
4948from jax ._src import dtypes as _dtypes
5049from jax ._src import lib as _jaxlib
5150from jax ._src import monitoring
51+ from jax ._src import test_warning_util
5252from jax ._src import xla_bridge
5353from jax ._src import util
5454from jax ._src import mesh as mesh_lib
118118)
119119
120120TEST_NUM_THREADS = config .int_flag (
121- 'jax_test_num_threads' , 0 ,
121+ 'jax_test_num_threads' , int ( os . getenv ( 'JAX_TEST_NUM_THREADS' , '0' )) ,
122122 help = 'Number of threads to use for running tests. 0 means run everything '
123123 'in the main thread. Using > 1 thread is experimental.'
124124)
@@ -1076,7 +1076,7 @@ def stopTest(self, test: unittest.TestCase):
10761076 with self .lock :
10771077 # We assume test_result is an ABSL _TextAndXMLTestResult, so we can
10781078 # override how it gets the time.
1079- time_getter = self .test_result . time_getter
1079+ time_getter = getattr ( self .test_result , " time_getter" , None )
10801080 try :
10811081 self .test_result .time_getter = lambda : self .start_time
10821082 self .test_result .startTest (test )
@@ -1085,7 +1085,8 @@ def stopTest(self, test: unittest.TestCase):
10851085 self .test_result .time_getter = lambda : stop_time
10861086 self .test_result .stopTest (test )
10871087 finally :
1088- self .test_result .time_getter = time_getter
1088+ if time_getter is not None :
1089+ self .test_result .time_getter = time_getter
10891090
10901091 def addSuccess (self , test : unittest .TestCase ):
10911092 self .actions .append (lambda : self .test_result .addSuccess (test ))
@@ -1120,6 +1121,8 @@ def run(self, result: unittest.TestResult, debug: bool = False) -> unittest.Test
11201121 if TEST_NUM_THREADS .value <= 0 :
11211122 return super ().run (result )
11221123
1124+ test_warning_util .install_threadsafe_warning_handlers ()
1125+
11231126 executor = ThreadPoolExecutor (TEST_NUM_THREADS .value )
11241127 lock = threading .Lock ()
11251128 futures = []
@@ -1368,11 +1371,44 @@ def assertMultiLineStrippedEqual(self, expected, what):
13681371 self .assertMultiLineEqual (expected_clean , what_clean ,
13691372 msg = f"Found\n { what } \n Expecting\n { expected } " )
13701373
1374+
13711375 @contextmanager
13721376 def assertNoWarnings (self ):
1373- with warnings .catch_warnings ():
1374- warnings .simplefilter ("error" )
1377+ with test_warning_util .raise_on_warnings ():
1378+ yield
1379+
1380+ # We replace assertWarns and assertWarnsRegex with functions that use the
1381+ # thread-safe warning utilities. Unlike the unittest versions these only
1382+ # function as context managers.
1383+ @contextmanager
1384+ def assertWarns (self , warning , * , msg = None ):
1385+ with test_warning_util .record_warnings () as ws :
1386+ yield
1387+ for w in ws :
1388+ if not isinstance (w .message , warning ):
1389+ continue
1390+ if msg is not None and msg not in str (w .message ):
1391+ continue
1392+ return
1393+ self .fail (f"Expected warning not found { warning } :'{ msg } ', got "
1394+ f"{ ws } " )
1395+
1396+ @contextmanager
1397+ def assertWarnsRegex (self , warning , regex ):
1398+ if regex is not None :
1399+ regex = re .compile (regex )
1400+
1401+ with test_warning_util .record_warnings () as ws :
13751402 yield
1403+ for w in ws :
1404+ if not isinstance (w .message , warning ):
1405+ continue
1406+ if regex is not None and not regex .search (str (w .message )):
1407+ continue
1408+ return
1409+ self .fail (f"Expected warning not found { warning } :'{ regex } ', got "
1410+ f"{ ws } " )
1411+
13761412
13771413 def _CompileAndCheck (self , fun , args_maker , * , check_dtypes = True , tol = None ,
13781414 rtol = None , atol = None , check_cache_misses = True ):
@@ -1449,11 +1485,7 @@ def assertNotDeleted(self, x):
14491485 self .assertFalse (x .is_deleted ())
14501486
14511487
1452- @contextmanager
1453- def ignore_warning (* , message = '' , category = Warning , ** kw ):
1454- with warnings .catch_warnings ():
1455- warnings .filterwarnings ("ignore" , message = message , category = category , ** kw )
1456- yield
1488+ ignore_warning = test_warning_util .ignore_warning
14571489
14581490# -------------------- Mesh parametrization helpers --------------------
14591491
@@ -1768,9 +1800,8 @@ def make_axis_points(size):
17681800 logtiny = finfo .minexp / prec_dps_ratio
17691801 axis_points = np .zeros (3 + 2 * size , dtype = finfo .dtype )
17701802
1771- with warnings . catch_warnings ( ):
1803+ with ignore_warning ( category = RuntimeWarning ):
17721804 # Silence RuntimeWarning: overflow encountered in cast
1773- warnings .simplefilter ("ignore" )
17741805 half_neg_line = - np .logspace (logmin , logtiny , size , dtype = finfo .dtype )
17751806 half_line = - half_neg_line [::- 1 ]
17761807 axis_points [- size - 1 :- 1 ] = half_line
0 commit comments