Skip to content

Commit 197c6da

Browse files
Add metadata for CUDA and libtpu versions
PiperOrigin-RevId: 834487912
1 parent 83d0952 commit 197c6da

File tree

1 file changed

+26
-0
lines changed

1 file changed

+26
-0
lines changed

tests/profiler_test.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -469,6 +469,32 @@ def on_profile():
469469
thread_profiler.join()
470470
self._check_xspace_pb_exist(logdir)
471471

472+
def testDeviceVersionSavedToMetadata(self):
473+
print("testDeviceVersionSavedToMetadata")
474+
with tempfile.TemporaryDirectory() as tmpdir_string:
475+
tmpdir = pathlib.Path(tmpdir_string)
476+
with jax.profiler.trace(tmpdir):
477+
jax.pmap(lambda x: jax.lax.psum(x + 1, 'i'), axis_name='i')(
478+
jnp.ones(jax.local_device_count()))
479+
480+
proto_path = tuple(tmpdir.rglob("*.xplane.pb"))
481+
self.assertEqual(len(proto_path), 1)
482+
proto = proto_path[0].read_bytes()
483+
if jtu.test_device_matches(["tpu"]):
484+
libtpu_version_index = proto.find(b"libtpu_version")
485+
print(libtpu_version_index)
486+
self.assertGreater(libtpu_version_index, 0)
487+
print("libtpu_version: ", libtpu_version_index)
488+
print(proto[libtpu_version_index:libtpu_version_index + 100])
489+
self.assertIn(b"libtpu_version", proto)
490+
if jtu.test_device_matches(["gpu"]):
491+
cuda_version_index = proto.find(b"cuda_version")
492+
print(cuda_version_index)
493+
self.assertGreater(cuda_version_index, 0)
494+
print("cuda_version: ", cuda_version_index)
495+
print(proto[cuda_version_index:cuda_version_index + 100])
496+
self.assertIn(b"cuda_version", proto)
497+
472498
@unittest.skip("Profiler takes >30s on Cloud TPUs")
473499
@unittest.skipIf(
474500
not (portpicker and _pywrap_profiler_plugin),

0 commit comments

Comments
 (0)