@@ -469,6 +469,32 @@ def on_profile():
469469 thread_profiler .join ()
470470 self ._check_xspace_pb_exist (logdir )
471471
472+ def testDeviceVersionSavedToMetadata (self ):
473+ print ("testDeviceVersionSavedToMetadata" )
474+ with tempfile .TemporaryDirectory () as tmpdir_string :
475+ tmpdir = pathlib .Path (tmpdir_string )
476+ with jax .profiler .trace (tmpdir ):
477+ jax .pmap (lambda x : jax .lax .psum (x + 1 , 'i' ), axis_name = 'i' )(
478+ jnp .ones (jax .local_device_count ()))
479+
480+ proto_path = tuple (tmpdir .rglob ("*.xplane.pb" ))
481+ self .assertEqual (len (proto_path ), 1 )
482+ proto = proto_path [0 ].read_bytes ()
483+ if jtu .test_device_matches (["tpu" ]):
484+ libtpu_version_index = proto .find (b"libtpu_version" )
485+ print (libtpu_version_index )
486+ self .assertGreater (libtpu_version_index , 0 )
487+ print ("libtpu_version: " , libtpu_version_index )
488+ print (proto [libtpu_version_index :libtpu_version_index + 100 ])
489+ self .assertIn (b"libtpu_version" , proto )
490+ if jtu .test_device_matches (["gpu" ]):
491+ cuda_version_index = proto .find (b"cuda_version" )
492+ print (cuda_version_index )
493+ self .assertGreater (cuda_version_index , 0 )
494+ print ("cuda_version: " , cuda_version_index )
495+ print (proto [cuda_version_index :cuda_version_index + 100 ])
496+ self .assertIn (b"cuda_version" , proto )
497+
472498 @unittest .skip ("Profiler takes >30s on Cloud TPUs" )
473499 @unittest .skipIf (
474500 not (portpicker and _pywrap_profiler_plugin ),
0 commit comments