1515import contextlib
1616import io
1717import logging
18+ import os
1819import platform
1920import re
20- import shlex
2121import subprocess
2222import sys
2323import tempfile
@@ -78,6 +78,31 @@ def capture_jax_logs():
7878 logger .removeHandler (handler )
7979
8080
81+ # Saves and runs script from the file in order to fix the problem with
82+ # `tsl::Env::Default()->GetExecutablePath()` not working properly with
83+ # command flag.
84+ def _run (program , env_var = {}):
85+ # strip the leading whitespace from the program script
86+ program = re .sub (r"^\s+" , "" , program , flags = re .MULTILINE )
87+
88+ with tempfile .NamedTemporaryFile (
89+ mode = "w+" , encoding = "utf-8" , suffix = ".py" , dir = os .getcwd ()
90+ ) as f :
91+ f .write (textwrap .dedent (program ))
92+ f .flush ()
93+ python = sys .executable
94+ assert "python" in python
95+ if env_var :
96+ env_var .update (os .environ )
97+ else :
98+ env_var = os .environ
99+
100+ # Make sure C++ logging is at default level for the test process.
101+ p = subprocess .run ([python , f .name ], env = env_var , capture_output = True , text = True )
102+
103+ return type ("" , (object ,), { "stdout" : p .stdout , "stderr" : p .stderr })
104+
105+
81106class LoggingTest (jtu .JaxTestCase ):
82107
83108 @unittest .skipIf (platform .system () == "Windows" ,
@@ -90,36 +115,25 @@ def test_no_log_spam(self):
90115 if sys .executable is None :
91116 raise self .skipTest ("test requires access to python binary" )
92117
93- # Save script in file to fix the problem with
94- # `tsl::Env::Default()->GetExecutablePath()` not working properly with
95- # command flag.
96- with tempfile .NamedTemporaryFile (
97- mode = "w+" , encoding = "utf-8" , suffix = ".py"
98- ) as f :
99- f .write (textwrap .dedent ("""
118+ o = _run ("""
100119 import jax
101120 jax.device_count()
102121 f = jax.jit(lambda x: x + 1)
103122 f(1)
104123 f(2)
105124 jax.numpy.add(1, 1)
106- """ ))
107- python = sys .executable
108- assert "python" in python
109- # Make sure C++ logging is at default level for the test process.
110- proc = subprocess .run ([python , f .name ], capture_output = True )
111-
112- lines = proc .stdout .split (b"\n " )
113- lines .extend (proc .stderr .split (b"\n " ))
114- allowlist = [
115- b"" ,
116- (
117- b"An NVIDIA GPU may be present on this machine, but a"
118- b" CUDA-enabled jaxlib is not installed. Falling back to cpu."
119- ),
120- ]
121- lines = [l for l in lines if l not in allowlist ]
122- self .assertEmpty (lines )
125+ """ )
126+
127+ lines = o .stdout .split ("\n " )
128+ lines .extend (o .stderr .split ("\n " ))
129+ allowlist = [
130+ (
131+ "An NVIDIA GPU may be present on this machine, but a"
132+ " CUDA-enabled jaxlib is not installed. Falling back to cpu."
133+ ),
134+ ]
135+ lines = [l for l in lines if l in allowlist ]
136+ self .assertEmpty (lines )
123137
124138 def test_debug_logging (self ):
125139 # Warmup so we don't get "No GPU/TPU" warning later.
@@ -164,19 +178,12 @@ def test_subprocess_stderr_info_logging(self):
164178 if sys .executable is None :
165179 raise self .skipTest ("test requires access to python binary" )
166180
167- program = """
168- import jax # this prints INFO logging from backend imports
169- jax.jit(lambda x: x)(1) # this prints logs to DEBUG (from compilation)
170- """
171-
172- # strip the leading whitespace from the program script
173- program = re .sub (r"^\s+" , "" , program , flags = re .MULTILINE )
181+ o = _run ("""
182+ import jax # this prints INFO logging from backend imports
183+ jax.jit(lambda x: x)(1) # this prints logs to DEBUG (from compilation)
184+ """ , { "JAX_LOGGING_LEVEL" : "INFO" })
174185
175- # test INFO
176- cmd = shlex .split (f"env JAX_LOGGING_LEVEL=INFO { sys .executable } -c"
177- f" '{ program } '" )
178- p = subprocess .run (cmd , capture_output = True , text = True )
179- log_output = p .stderr
186+ log_output = o .stderr
180187 info_lines = log_output .split ("\n " )
181188 self .assertGreater (len (info_lines ), 0 )
182189 self .assertIn ("INFO" , log_output )
@@ -194,22 +201,14 @@ def test_subprocess_stderr_debug_logging(self):
194201 jax.jit(lambda x: x)(1) # this prints logs to DEBUG (from compilation)
195202 """
196203
197- # strip the leading whitespace from the program script
198- program = re .sub (r"^\s+" , "" , program , flags = re .MULTILINE )
204+ o = _run (program , { "JAX_LOGGING_LEVEL" : "DEBUG" })
199205
200- # test DEBUG
201- cmd = shlex .split (f"env JAX_LOGGING_LEVEL=DEBUG { sys .executable } -c"
202- f" '{ program } '" )
203- p = subprocess .run (cmd , capture_output = True , text = True )
204- log_output = p .stderr
206+ log_output = o .stderr
205207 self .assertIn ("INFO" , log_output )
206208 self .assertIn ("DEBUG" , log_output )
207209
208- # test JAX_DEBUG_MODULES
209- cmd = shlex .split (f"env JAX_DEBUG_LOG_MODULES=jax { sys .executable } -c"
210- f" '{ program } '" )
211- p = subprocess .run (cmd , capture_output = True , text = True )
212- log_output = p .stderr
210+ o = _run (program , { "JAX_DEBUG_LOG_MODULES" : "jax" })
211+ log_output = o .stderr
213212 self .assertIn ("DEBUG" , log_output )
214213
215214 @jtu .skip_on_devices ("tpu" )
@@ -220,22 +219,15 @@ def test_subprocess_toggling_logging_level(self):
220219 raise self .skipTest ("test requires access to python binary" )
221220
222221 _separator = "---------------------------"
223- program = f"""
222+ o = _run ( f"""
224223 import sys
225224 import jax # this prints INFO logging from backend imports
226225 jax.jit(lambda x: x)(1) # this prints logs to DEBUG (from compilation)
227226 jax.config.update("jax_logging_level", None)
228227 sys.stderr.write("{ _separator } ")
229228 jax.jit(lambda x: x)(1) # should not log anything now
230- """
231-
232- # strip the leading whitespace from the program script
233- program = re .sub (r"^\s+" , "" , program , flags = re .MULTILINE )
234-
235- cmd = shlex .split (f"env JAX_LOGGING_LEVEL=DEBUG { sys .executable } -c"
236- f" '{ program } '" )
237- p = subprocess .run (cmd , capture_output = True , text = True )
238- log_output = p .stderr
229+ """ , {"JAX_LOGGING_LEVEL" : "DEBUG" })
230+ log_output = o .stderr
239231 m = re .search (_separator , log_output )
240232 self .assertTrue (m is not None )
241233 log_output_verbose = log_output [:m .start ()]
@@ -252,19 +244,13 @@ def test_subprocess_double_logging_absent(self):
252244 if sys .executable is None :
253245 raise self .skipTest ("test requires access to python binary" )
254246
255- program = """
247+ o = _run ( """
256248 import jax # this prints INFO logging from backend imports
257249 jax.config.update("jax_debug_log_modules", "jax._src.compiler,jax._src.dispatch")
258250 jax.jit(lambda x: x)(1) # this prints logs to DEBUG (from compilation)
259- """
260-
261- # strip the leading whitespace from the program script
262- program = re .sub (r"^\s+" , "" , program , flags = re .MULTILINE )
251+ """ , { "JAX_LOGGING_LEVEL" : "DEBUG" })
263252
264- cmd = shlex .split (f"env JAX_LOGGING_LEVEL=DEBUG { sys .executable } -c"
265- f" '{ program } '" )
266- p = subprocess .run (cmd , capture_output = True , text = True )
267- log_output = p .stderr
253+ log_output = o .stderr
268254 self .assertNotEmpty (log_output )
269255 log_lines = log_output .strip ().split ("\n " )
270256 # only one tracing line should be printed, if there's more than one
@@ -285,31 +271,19 @@ def test_subprocess_cpp_logging_level(self):
285271 jax.distributed.initialize("127.0.0.1:12345", num_processes=1, process_id=0)
286272 """
287273
288- # strip the leading whitespace from the program script
289- program = re . sub ( r"^\s+" , "" , program , flags = re . MULTILINE )
274+ o = _run ( program , { "JAX_LOGGING_LEVEL" : "DEBUG" })
275+ self . assertIn ( "Initializing CoordinationService" , o . stderr )
290276
291- # verbose logging: DEBUG, VERBOSE
292- cmd = shlex .split (f"env JAX_LOGGING_LEVEL=DEBUG { sys .executable } -c"
293- f" '{ program } '" )
294- p = subprocess .run (cmd , capture_output = True , text = True )
295- self .assertIn ("Initializing CoordinationService" , p .stderr )
296-
297- cmd = shlex .split (f"env JAX_LOGGING_LEVEL=INFO { sys .executable } -c"
298- f" '{ program } '" )
299- p = subprocess .run (cmd , capture_output = True , text = True )
300- self .assertIn ("Initializing CoordinationService" , p .stderr )
277+ o = _run (program , { "JAX_LOGGING_LEVEL" : "INFO" })
278+ self .assertIn ("Initializing CoordinationService" , o .stderr )
301279
302280 # verbose logging: WARNING, None
303- cmd = shlex .split (f"env JAX_LOGGING_LEVEL=WARNING { sys .executable } -c"
304- f" '{ program } '" )
305- p = subprocess .run (cmd , capture_output = True , text = True )
306- self .assertNotIn ("Initializing CoordinationService" , p .stderr )
307-
308- cmd = shlex .split (f"{ sys .executable } -c"
309- f" '{ program } '" )
310- p = subprocess .run (cmd , capture_output = True , text = True )
281+ o = _run (program , { "JAX_LOGGING_LEVEL" : "WARNING" })
282+ self .assertNotIn ("Initializing CoordinationService" , o .stderr )
283+
284+ o = _run (program )
311285 if int (_default_TF_CPP_MIN_LOG_LEVEL ) >= 1 :
312- self .assertNotIn ("Initializing CoordinationService" , p .stderr )
286+ self .assertNotIn ("Initializing CoordinationService" , o .stderr )
313287
314288if __name__ == "__main__" :
315289 absltest .main (testLoader = jtu .JaxTestLoader ())
0 commit comments