Skip to content

Commit 080804c

Browse files
Fix logging_test failures on Linux with only the NVIDIA driver installed.
Some GPU tests in //tests/logging_test fail on Linux machines that have only the NVIDIA driver installed when we use hermetic CUDA (CUDA itself isn't installed on the machine). Reason: the method `tsl::Env::Default()->GetExecutablePath()` doesn't work properly when Python is run with the command flag (`-c`). As a result, the subprocess couldn't get the path to the logging_test.py file and convert it to the path of the runtime directory where the hermetic CUDA libraries are placed. Solution: save the Python program to a file in the runtime directory, then run the script from that file. PiperOrigin-RevId: 738152663
1 parent 54691b1 commit 080804c

File tree

1 file changed

+63
-89
lines changed

1 file changed

+63
-89
lines changed

tests/logging_test.py

Lines changed: 63 additions & 89 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,9 @@
1515
import contextlib
1616
import io
1717
import logging
18+
import os
1819
import platform
1920
import re
20-
import shlex
2121
import subprocess
2222
import sys
2323
import tempfile
@@ -78,6 +78,31 @@ def capture_jax_logs():
7878
logger.removeHandler(handler)
7979

8080

81+
# Saves and runs script from the file in order to fix the problem with
82+
# `tsl::Env::Default()->GetExecutablePath()` not working properly with
83+
# command flag.
84+
def _run(program, env_var = {}):
85+
# strip the leading whitespace from the program script
86+
program = re.sub(r"^\s+", "", program, flags=re.MULTILINE)
87+
88+
with tempfile.NamedTemporaryFile(
89+
mode="w+", encoding="utf-8", suffix=".py", dir=os.getcwd()
90+
) as f:
91+
f.write(textwrap.dedent(program))
92+
f.flush()
93+
python = sys.executable
94+
assert "python" in python
95+
if env_var:
96+
env_var.update(os.environ)
97+
else:
98+
env_var = os.environ
99+
100+
# Make sure C++ logging is at default level for the test process.
101+
p = subprocess.run([python, f.name], env=env_var, capture_output=True, text=True)
102+
103+
return type("", (object,), { "stdout": p.stdout, "stderr": p.stderr })
104+
105+
81106
class LoggingTest(jtu.JaxTestCase):
82107

83108
@unittest.skipIf(platform.system() == "Windows",
@@ -90,36 +115,25 @@ def test_no_log_spam(self):
90115
if sys.executable is None:
91116
raise self.skipTest("test requires access to python binary")
92117

93-
# Save script in file to fix the problem with
94-
# `tsl::Env::Default()->GetExecutablePath()` not working properly with
95-
# command flag.
96-
with tempfile.NamedTemporaryFile(
97-
mode="w+", encoding="utf-8", suffix=".py"
98-
) as f:
99-
f.write(textwrap.dedent("""
118+
o = _run("""
100119
import jax
101120
jax.device_count()
102121
f = jax.jit(lambda x: x + 1)
103122
f(1)
104123
f(2)
105124
jax.numpy.add(1, 1)
106-
"""))
107-
python = sys.executable
108-
assert "python" in python
109-
# Make sure C++ logging is at default level for the test process.
110-
proc = subprocess.run([python, f.name], capture_output=True)
111-
112-
lines = proc.stdout.split(b"\n")
113-
lines.extend(proc.stderr.split(b"\n"))
114-
allowlist = [
115-
b"",
116-
(
117-
b"An NVIDIA GPU may be present on this machine, but a"
118-
b" CUDA-enabled jaxlib is not installed. Falling back to cpu."
119-
),
120-
]
121-
lines = [l for l in lines if l not in allowlist]
122-
self.assertEmpty(lines)
125+
""")
126+
127+
lines = o.stdout.split("\n")
128+
lines.extend(o.stderr.split("\n"))
129+
allowlist = [
130+
(
131+
"An NVIDIA GPU may be present on this machine, but a"
132+
" CUDA-enabled jaxlib is not installed. Falling back to cpu."
133+
),
134+
]
135+
lines = [l for l in lines if l in allowlist]
136+
self.assertEmpty(lines)
123137

124138
def test_debug_logging(self):
125139
# Warmup so we don't get "No GPU/TPU" warning later.
@@ -164,19 +178,12 @@ def test_subprocess_stderr_info_logging(self):
164178
if sys.executable is None:
165179
raise self.skipTest("test requires access to python binary")
166180

167-
program = """
168-
import jax # this prints INFO logging from backend imports
169-
jax.jit(lambda x: x)(1) # this prints logs to DEBUG (from compilation)
170-
"""
171-
172-
# strip the leading whitespace from the program script
173-
program = re.sub(r"^\s+", "", program, flags=re.MULTILINE)
181+
o = _run("""
182+
import jax # this prints INFO logging from backend imports
183+
jax.jit(lambda x: x)(1) # this prints logs to DEBUG (from compilation)
184+
""", { "JAX_LOGGING_LEVEL": "INFO" })
174185

175-
# test INFO
176-
cmd = shlex.split(f"env JAX_LOGGING_LEVEL=INFO {sys.executable} -c"
177-
f" '{program}'")
178-
p = subprocess.run(cmd, capture_output=True, text=True)
179-
log_output = p.stderr
186+
log_output = o.stderr
180187
info_lines = log_output.split("\n")
181188
self.assertGreater(len(info_lines), 0)
182189
self.assertIn("INFO", log_output)
@@ -194,22 +201,14 @@ def test_subprocess_stderr_debug_logging(self):
194201
jax.jit(lambda x: x)(1) # this prints logs to DEBUG (from compilation)
195202
"""
196203

197-
# strip the leading whitespace from the program script
198-
program = re.sub(r"^\s+", "", program, flags=re.MULTILINE)
204+
o = _run(program, { "JAX_LOGGING_LEVEL": "DEBUG" })
199205

200-
# test DEBUG
201-
cmd = shlex.split(f"env JAX_LOGGING_LEVEL=DEBUG {sys.executable} -c"
202-
f" '{program}'")
203-
p = subprocess.run(cmd, capture_output=True, text=True)
204-
log_output = p.stderr
206+
log_output = o.stderr
205207
self.assertIn("INFO", log_output)
206208
self.assertIn("DEBUG", log_output)
207209

208-
# test JAX_DEBUG_MODULES
209-
cmd = shlex.split(f"env JAX_DEBUG_LOG_MODULES=jax {sys.executable} -c"
210-
f" '{program}'")
211-
p = subprocess.run(cmd, capture_output=True, text=True)
212-
log_output = p.stderr
210+
o = _run(program, { "JAX_DEBUG_LOG_MODULES": "jax" })
211+
log_output = o.stderr
213212
self.assertIn("DEBUG", log_output)
214213

215214
@jtu.skip_on_devices("tpu")
@@ -220,22 +219,15 @@ def test_subprocess_toggling_logging_level(self):
220219
raise self.skipTest("test requires access to python binary")
221220

222221
_separator = "---------------------------"
223-
program = f"""
222+
o = _run(f"""
224223
import sys
225224
import jax # this prints INFO logging from backend imports
226225
jax.jit(lambda x: x)(1) # this prints logs to DEBUG (from compilation)
227226
jax.config.update("jax_logging_level", None)
228227
sys.stderr.write("{_separator}")
229228
jax.jit(lambda x: x)(1) # should not log anything now
230-
"""
231-
232-
# strip the leading whitespace from the program script
233-
program = re.sub(r"^\s+", "", program, flags=re.MULTILINE)
234-
235-
cmd = shlex.split(f"env JAX_LOGGING_LEVEL=DEBUG {sys.executable} -c"
236-
f" '{program}'")
237-
p = subprocess.run(cmd, capture_output=True, text=True)
238-
log_output = p.stderr
229+
""", {"JAX_LOGGING_LEVEL": "DEBUG"})
230+
log_output = o.stderr
239231
m = re.search(_separator, log_output)
240232
self.assertTrue(m is not None)
241233
log_output_verbose = log_output[:m.start()]
@@ -252,19 +244,13 @@ def test_subprocess_double_logging_absent(self):
252244
if sys.executable is None:
253245
raise self.skipTest("test requires access to python binary")
254246

255-
program = """
247+
o = _run("""
256248
import jax # this prints INFO logging from backend imports
257249
jax.config.update("jax_debug_log_modules", "jax._src.compiler,jax._src.dispatch")
258250
jax.jit(lambda x: x)(1) # this prints logs to DEBUG (from compilation)
259-
"""
260-
261-
# strip the leading whitespace from the program script
262-
program = re.sub(r"^\s+", "", program, flags=re.MULTILINE)
251+
""", { "JAX_LOGGING_LEVEL": "DEBUG" })
263252

264-
cmd = shlex.split(f"env JAX_LOGGING_LEVEL=DEBUG {sys.executable} -c"
265-
f" '{program}'")
266-
p = subprocess.run(cmd, capture_output=True, text=True)
267-
log_output = p.stderr
253+
log_output = o.stderr
268254
self.assertNotEmpty(log_output)
269255
log_lines = log_output.strip().split("\n")
270256
# only one tracing line should be printed, if there's more than one
@@ -285,31 +271,19 @@ def test_subprocess_cpp_logging_level(self):
285271
jax.distributed.initialize("127.0.0.1:12345", num_processes=1, process_id=0)
286272
"""
287273

288-
# strip the leading whitespace from the program script
289-
program = re.sub(r"^\s+", "", program, flags=re.MULTILINE)
274+
o = _run(program, { "JAX_LOGGING_LEVEL": "DEBUG" })
275+
self.assertIn("Initializing CoordinationService", o.stderr)
290276

291-
# verbose logging: DEBUG, VERBOSE
292-
cmd = shlex.split(f"env JAX_LOGGING_LEVEL=DEBUG {sys.executable} -c"
293-
f" '{program}'")
294-
p = subprocess.run(cmd, capture_output=True, text=True)
295-
self.assertIn("Initializing CoordinationService", p.stderr)
296-
297-
cmd = shlex.split(f"env JAX_LOGGING_LEVEL=INFO {sys.executable} -c"
298-
f" '{program}'")
299-
p = subprocess.run(cmd, capture_output=True, text=True)
300-
self.assertIn("Initializing CoordinationService", p.stderr)
277+
o = _run(program, { "JAX_LOGGING_LEVEL": "INFO" })
278+
self.assertIn("Initializing CoordinationService", o.stderr)
301279

302280
# verbose logging: WARNING, None
303-
cmd = shlex.split(f"env JAX_LOGGING_LEVEL=WARNING {sys.executable} -c"
304-
f" '{program}'")
305-
p = subprocess.run(cmd, capture_output=True, text=True)
306-
self.assertNotIn("Initializing CoordinationService", p.stderr)
307-
308-
cmd = shlex.split(f"{sys.executable} -c"
309-
f" '{program}'")
310-
p = subprocess.run(cmd, capture_output=True, text=True)
281+
o = _run(program, { "JAX_LOGGING_LEVEL": "WARNING" })
282+
self.assertNotIn("Initializing CoordinationService", o.stderr)
283+
284+
o = _run(program)
311285
if int(_default_TF_CPP_MIN_LOG_LEVEL) >= 1:
312-
self.assertNotIn("Initializing CoordinationService", p.stderr)
286+
self.assertNotIn("Initializing CoordinationService", o.stderr)
313287

314288
if __name__ == "__main__":
315289
absltest.main(testLoader=jtu.JaxTestLoader())

0 commit comments

Comments
 (0)