@@ -120,13 +120,12 @@ def _get_cpu_info():
120120 return None
121121
122122
123- def _get_gpu_memory_mebibytes (device ):
124- """Try to get GPU memory available in mebibytes (MiB) ."""
123+ def _get_gpu_memory_in_bytes (device ):
124+ """Try to get GPU memory available in bytes ."""
125125 memory_stats = device .memory_stats ()
126126 if memory_stats is None :
127127 return None
128- bytes_limit = memory_stats .get ("bytes_limit" )
129- return bytes_limit // 2 ** 20 if bytes_limit is not None else None
128+ return memory_stats .get ("bytes_limit" )
130129
131130
132131def _get_gpu_info ():
@@ -137,7 +136,7 @@ def _get_gpu_info():
137136 return [
138137 {
139138 "kind" : d .device_kind ,
140- "memory_available / MiB " : _get_gpu_memory_mebibytes (d ),
139+ "memory_available_in_bytes " : _get_gpu_memory_in_bytes (d ),
141140 }
142141 for d in jax .devices ()
143142 if d .platform == "gpu"
@@ -228,15 +227,15 @@ def _format_results_entry(results_entry):
228227 if len (results_entry ["parameters" ]) != 0
229228 else " "
230229 )
231- + f"min(time ): { min (results_entry ['times / s ' ]):>#7.2g} s, "
232- + f"max(time ): { max (results_entry ['times / s ' ]):>#7.2g} s"
230+ + f"min(run times ): { min (results_entry ['run_times_in_seconds ' ]):>#7.2g} s, "
231+ + f"max(run times ): { max (results_entry ['run_times_in_seconds ' ]):>#7.2g} s"
233232 + (
234- f", peak memory: { results_entry ['peak_memory / MiB ' ]:>#7.2g} MiB "
235- if "peak_memory / MiB " in results_entry
233+ f", peak memory: { results_entry ['peak_memory_in_bytes ' ]:>#7.2g} B "
234+ if "peak_memory_in_bytes " in results_entry
236235 else ""
237236 )
238237 + (
239- f", max(abs(error)): { results_entry ['error ' ]:>#7.2g} "
238+ f", max(abs(error)): { results_entry ['max_abs_error ' ]:>#7.2g} "
240239 if "error" in results_entry
241240 else ""
242241 )
@@ -441,19 +440,21 @@ def run_benchmarks(
441440 # computing numerical error
442441 output = benchmark_function ()
443442 if reference_output is not None and output is not None :
444- results_entry ["error" ] = abs (reference_output - output ).max ()
443+ results_entry ["max_abs_error" ] = abs (
444+ reference_output - output
445+ ).max ()
445446 run_times = [
446447 time / number_runs
447448 for time in timeit .repeat (
448449 benchmark_function , number = number_runs , repeat = number_repeats
449450 )
450451 ]
451- results_entry ["times / s " ] = run_times
452+ results_entry ["run_times_in_seconds " ] = run_times
452453 if MEMORY_PROFILER_AVAILABLE :
453- results_entry ["peak_memory / MiB " ] = measure_peak_memory_usage (
454+ results_entry ["peak_memory_in_bytes " ] = measure_peak_memory_usage (
454455 benchmark_function ,
455456 interval = min (run_times ) / 20 ,
456- )
457+ ) * ( 2 ** 20 )
457458 results [benchmark .__name__ ].append (results_entry )
458459 if print_results :
459460 print (_format_results_entry (results_entry ))
0 commit comments