@@ -101,12 +101,47 @@ def _get_cpu_info():
101101 return None
102102
103103
104+ def _get_gpu_memory_mebibytes (device ):
105+ """Try to get GPU memory available in mebibytes (MiB)."""
106+ memory_stats = device .memory_stats ()
107+ if memory_stats is None :
108+ return None
109+ bytes_limit = memory_stats .get ("bytes_limit" )
110+ return bytes_limit // 2 ** 20 if bytes_limit is not None else None
111+
112+
104113def _get_gpu_info ():
105114 """Get details of GPU devices available from JAX or None if JAX not available."""
106115 try :
107116 import jax
108117
109- return [d .device_kind for d in jax .devices () if d .platform == "gpu" ]
118+ return [
119+ {
120+ "kind" : d .device_kind ,
121+ "memory_available / MiB" : _get_gpu_memory_mebibytes (d ),
122+ }
123+ for d in jax .devices ()
124+ if d .platform == "gpu"
125+ ]
126+ except ImportError :
127+ return None
128+
129+
130+ def _get_cuda_info ():
131+ """Try to get information on versions of CUDA libraries."""
132+ try :
133+ from jax ._src .lib import cuda_versions
134+
135+ if cuda_versions is None :
136+ return None
137+ return {
138+ "cuda_runtime_version" : cuda_versions .cuda_runtime_get_version (),
139+ "cuda_runtime_build_version" : cuda_versions .cuda_runtime_build_version (),
140+ "cudnn_version" : cuda_versions .cudnn_get_version (),
141+ "cudnn_build_version" : cuda_versions .cudnn_build_version (),
142+ "cufft_version" : cuda_versions .cufft_get_version (),
143+ "cufft_build_version" : cuda_versions .cufft_build_version (),
144+ }
110145 except ImportError :
111146 return None
112147
@@ -364,6 +399,7 @@ def parse_args_collect_and_run_benchmarks(module=None):
364399 "system" : platform .system (),
365400 "cpu_info" : _get_cpu_info (),
366401 "gpu_info" : _get_gpu_info (),
402+ "cuda_info" : _get_cuda_info (),
367403 ** package_versions ,
368404 }
369405 with open (args .output_file , "w" ) as f :
0 commit comments