Skip to content

Commit 742fdad

Browse files
committed
Record GPU memory + CUDA info in results
1 parent e917f05 commit 742fdad

File tree

1 file changed

+37
-1
lines changed

1 file changed

+37
-1
lines changed

benchmarks/benchmarking.py

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -101,12 +101,47 @@ def _get_cpu_info():
101101
return None
102102

103103

104+
def _get_gpu_memory_mebibytes(device):
105+
"""Try to get GPU memory available in mebibytes (MiB)."""
106+
memory_stats = device.memory_stats()
107+
if memory_stats is None:
108+
return None
109+
bytes_limit = memory_stats.get("bytes_limit")
110+
return bytes_limit // 2**20 if bytes_limit is not None else None
111+
112+
104113
def _get_gpu_info():
105114
"""Get details of GPU devices available from JAX or None if JAX not available."""
106115
try:
107116
import jax
108117

109-
return [d.device_kind for d in jax.devices() if d.platform == "gpu"]
118+
return [
119+
{
120+
"kind": d.device_kind,
121+
"memory_available / MiB": _get_gpu_memory_mebibytes(d),
122+
}
123+
for d in jax.devices()
124+
if d.platform == "gpu"
125+
]
126+
except ImportError:
127+
return None
128+
129+
130+
def _get_cuda_info():
131+
"""Try to get information on versions of CUDA libraries."""
132+
try:
133+
from jax._src.lib import cuda_versions
134+
135+
if cuda_versions is None:
136+
return None
137+
return {
138+
"cuda_runtime_version": cuda_versions.cuda_runtime_get_version(),
139+
"cuda_runtime_build_version": cuda_versions.cuda_runtime_build_version(),
140+
"cudnn_version": cuda_versions.cudnn_get_version(),
141+
"cudnn_build_version": cuda_versions.cudnn_build_version(),
142+
"cufft_version": cuda_versions.cufft_get_version(),
143+
"cufft_build_version": cuda_versions.cufft_build_version(),
144+
}
110145
except ImportError:
111146
return None
112147

@@ -364,6 +399,7 @@ def parse_args_collect_and_run_benchmarks(module=None):
364399
"system": platform.system(),
365400
"cpu_info": _get_cpu_info(),
366401
"gpu_info": _get_gpu_info(),
402+
"cuda_info": _get_cuda_info(),
367403
**package_versions,
368404
}
369405
with open(args.output_file, "w") as f:

0 commit comments

Comments
 (0)