Skip to content

Commit 1791828

Browse files
committed
Normalize benchmark results key naming
1 parent 4e01158 commit 1791828

File tree

1 file changed

+15
-14
lines changed

1 file changed

+15
-14
lines changed

benchmarks/benchmarking.py

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -120,13 +120,12 @@ def _get_cpu_info():
120120
return None
121121

122122

123-
def _get_gpu_memory_mebibytes(device):
124-
"""Try to get GPU memory available in mebibytes (MiB)."""
123+
def _get_gpu_memory_in_bytes(device):
124+
"""Try to get GPU memory available in bytes."""
125125
memory_stats = device.memory_stats()
126126
if memory_stats is None:
127127
return None
128-
bytes_limit = memory_stats.get("bytes_limit")
129-
return bytes_limit // 2**20 if bytes_limit is not None else None
128+
return memory_stats.get("bytes_limit")
130129

131130

132131
def _get_gpu_info():
@@ -137,7 +136,7 @@ def _get_gpu_info():
137136
return [
138137
{
139138
"kind": d.device_kind,
140-
"memory_available / MiB": _get_gpu_memory_mebibytes(d),
139+
"memory_available_in_bytes": _get_gpu_memory_in_bytes(d),
141140
}
142141
for d in jax.devices()
143142
if d.platform == "gpu"
@@ -228,15 +227,15 @@ def _format_results_entry(results_entry):
228227
if len(results_entry["parameters"]) != 0
229228
else " "
230229
)
231-
+ f"min(time): {min(results_entry['times / s']):>#7.2g}s, "
232-
+ f"max(time): {max(results_entry['times / s']):>#7.2g}s"
230+
+ f"min(run times): {min(results_entry['run_times_in_seconds']):>#7.2g}s, "
231+
+ f"max(run times): {max(results_entry['run_times_in_seconds']):>#7.2g}s"
233232
+ (
234-
f", peak memory: {results_entry['peak_memory / MiB']:>#7.2g}MiB"
235-
if "peak_memory / MiB" in results_entry
233+
f", peak memory: {results_entry['peak_memory_in_bytes']:>#7.2g}B"
234+
if "peak_memory_in_bytes" in results_entry
236235
else ""
237236
)
238237
+ (
239-
f", max(abs(error)): {results_entry['error']:>#7.2g}"
238+
f", max(abs(error)): {results_entry['max_abs_error']:>#7.2g}"
240239
if "error" in results_entry
241240
else ""
242241
)
@@ -441,19 +440,21 @@ def run_benchmarks(
441440
# computing numerical error
442441
output = benchmark_function()
443442
if reference_output is not None and output is not None:
444-
results_entry["error"] = abs(reference_output - output).max()
443+
results_entry["max_abs_error"] = abs(
444+
reference_output - output
445+
).max()
445446
run_times = [
446447
time / number_runs
447448
for time in timeit.repeat(
448449
benchmark_function, number=number_runs, repeat=number_repeats
449450
)
450451
]
451-
results_entry["times / s"] = run_times
452+
results_entry["run_times_in_seconds"] = run_times
452453
if MEMORY_PROFILER_AVAILABLE:
453-
results_entry["peak_memory / MiB"] = measure_peak_memory_usage(
454+
results_entry["peak_memory_in_bytes"] = measure_peak_memory_usage(
454455
benchmark_function,
455456
interval=min(run_times) / 20,
456-
)
457+
) * (2**20)
457458
results[benchmark.__name__].append(results_entry)
458459
if print_results:
459460
print(_format_results_entry(results_entry))

0 commit comments

Comments
 (0)