diff --git a/tests/profiler_test.py b/tests/profiler_test.py
index 3088ced9872a..f8955a673861 100644
--- a/tests/profiler_test.py
+++ b/tests/profiler_test.py
@@ -469,6 +469,30 @@ def on_profile():
     thread_profiler.join()
     self._check_xspace_pb_exist(logdir)
 
+  def testDeviceVersionSavedToMetadata(self):
+    """Checks the trace file mentions the backend's runtime version key.
+
+    Profiles a small `pmap` computation into a temporary directory, asserts
+    that exactly one `*.xplane.pb` file was written, and searches its raw
+    bytes for `libtpu_version` on TPU or `cuda_version` on GPU. On any other
+    backend only the file count is checked.
+    """
+    with tempfile.TemporaryDirectory() as tmpdir_string:
+      tmpdir = pathlib.Path(tmpdir_string)
+      with jax.profiler.trace(tmpdir):
+        jax.pmap(lambda x: jax.lax.psum(x + 1, 'i'), axis_name='i')(
+            jnp.ones(jax.local_device_count()))
+
+      proto_paths = tuple(tmpdir.rglob("*.xplane.pb"))
+      self.assertEqual(len(proto_paths), 1)
+      proto = proto_paths[0].read_bytes()
+      # Plain membership test rather than `find(...) > 0`, so a key located at
+      # offset 0 of the file is still accepted.
+      if jtu.test_device_matches(["tpu"]):
+        self.assertIn(b"libtpu_version", proto)
+      if jtu.test_device_matches(["gpu"]):
+        self.assertIn(b"cuda_version", proto)
+
   @unittest.skip("Profiler takes >30s on Cloud TPUs")
   @unittest.skipIf(
       not (portpicker and _pywrap_profiler_plugin),