@@ -60,11 +60,14 @@ def mean(x, n):
6060"""
6161
6262import argparse
63+ import datetime
6364import inspect
6465import json
66+ import platform
6567import timeit
6668from ast import literal_eval
6769from functools import partial
70+ from importlib .metadata import PackageNotFoundError , version
6871from itertools import product
6972from pathlib import Path
7073
@@ -80,6 +83,69 @@ class SkipBenchmarkException(Exception):
8083 """Exception to be raised to skip benchmark for some parameter set."""
8184
8285
86+ def _get_version_or_none (package_name ):
87+ """Get installed version of package or `None` if package not found."""
88+ try :
89+ return version (package_name )
90+ except PackageNotFoundError :
91+ return None
92+
93+
94+ def _get_cpu_info ():
95+ """Get details of CPU from cpuinfo if available or None if not."""
+    try:
+        import cpuinfo
+
+        return cpuinfo.get_cpu_info()
+    except ImportError:
+        return None
+
+
+def _get_gpu_memory_mebibytes(device):
+    """Try to get GPU memory available in mebibytes (MiB)."""
+    memory_stats = device.memory_stats()
+    if memory_stats is None:
+        return None
+    bytes_limit = memory_stats.get("bytes_limit")
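+    # 1 MiB = 2**20 bytes, so integer-divide the byte count to convert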
+    return bytes_limit // 2**20 if bytes_limit is not None else None
+
+
+def _get_gpu_info():
+    """Get details of GPU devices available from JAX or None if JAX not available."""
+    try:
+        import jax
+
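+        # Report only devices JAX identifies as GPUs; CPU (and any other)
+        # devices are filtered out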
+        return [
+            {
+                "kind": d.device_kind,
+                "memory_available / MiB": _get_gpu_memory_mebibytes(d),
+            }
+            for d in jax.devices()
+            if d.platform == "gpu"
+        ]
+    except ImportError:
+        return None
+
+
+def _get_cuda_info():
+    """Try to get information on versions of CUDA libraries."""
+    try:
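+        # cuda_versions comes from a private JAX module (jax._src), so this
+        # import may break between JAX releases; treat it as best-effort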
+        from jax._src.lib import cuda_versions
+
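+        # cuda_versions is None when jaxlib was built without CUDA support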
+        if cuda_versions is None:
+            return None
+        return {
+            "cuda_runtime_version": cuda_versions.cuda_runtime_get_version(),
+            "cuda_runtime_build_version": cuda_versions.cuda_runtime_build_version(),
+            "cudnn_version": cuda_versions.cudnn_get_version(),
+            "cudnn_build_version": cuda_versions.cudnn_build_version(),
+            "cufft_version": cuda_versions.cufft_get_version(),
+            "cufft_build_version": cuda_versions.cufft_build_version(),
+        }
+    except ImportError:
+        return None
+
+
 def skip(message):
     """Skip benchmark for a particular parameter set with explanatory message.

@@ -145,9 +211,11 @@ def _parse_parameter_overrides(parameter_overrides):
     )


-def _parse_cli_arguments():
+def _parse_cli_arguments(description):
     """Parse command line arguments passed for controlling benchmark runs"""
-    parser = argparse.ArgumentParser("Run benchmarks")
+    parser = argparse.ArgumentParser(
+        description=description, formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
     parser.add_argument(
         "-number-runs",
         type=int,
@@ -174,6 +242,15 @@ def _parse_cli_arguments():
     parser.add_argument(
         "-output-file", type=Path, help="File path to write JSON formatted results to."
     )
+    parser.add_argument(
+        "--run-once-and-discard",
+        action="store_true",
+        help=(
+            "Run benchmark function once first without recording time to "
+            "ignore the effect of any initial one-off costs such as just-in-time "
+            "compilation."
+        ),
+    )
     return parser.parse_args()


@@ -204,6 +281,7 @@ def run_benchmarks(
     number_repeats,
     print_results=True,
     parameter_overrides=None,
+    run_once_and_discard=False,
 ):
     """Run a set of benchmarks.

@@ -217,14 +295,17 @@ def run_benchmarks(
         print_results: Whether to print benchmark results to stdout.
         parameter_overrides: Dictionary specifying any overrides for parameter values
             set in `benchmark` decorator.
+        run_once_and_discard: Whether to run benchmark function once first without
+            recording time to ignore the effect of any initial one-off costs such as
+            just-in-time compilation.

     Returns:
         Dictionary containing timing (and potentially memory usage) results for each
         parameter set of each benchmark function.
     """
     results = {}
     for benchmark in benchmarks:
-        results[benchmark.__name__] = {}
+        results[benchmark.__name__] = []
         if print_results:
             print(benchmark.__name__)
         parameters = benchmark.parameters.copy()
@@ -234,13 +315,15 @@ def run_benchmarks(
             try:
                 precomputes = benchmark.setup(**parameter_set)
                 benchmark_function = partial(benchmark, **precomputes, **parameter_set)
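+                # Untimed warm-up call so that one-off costs such as
+                # just-in-time compilation are excluded from recorded times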
+                if run_once_and_discard:
+                    benchmark_function()
                 run_times = [
                     time / number_runs
                     for time in timeit.repeat(
                         benchmark_function, number=number_runs, repeat=number_repeats
                     )
                 ]
-                results[benchmark.__name__] = {**parameter_set, "times / s": run_times}
+                results_entry = {**parameter_set, "times / s": run_times}
                 if MEMORY_PROFILER_AVAILABLE:
                     baseline_memory = memory_profiler.memory_usage(max_usage=True)
                     peak_memory = (
@@ -253,7 +336,8 @@ def run_benchmarks(
                         )
                         - baseline_memory
                     )
-                    results[benchmark.__name__]["peak_memory / MiB"] = peak_memory
+                    results_entry["peak_memory / MiB"] = peak_memory
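+                # Append one entry per parameter set rather than overwriting,
+                # so results for all parameter sets are retained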
+                results[benchmark.__name__].append(results_entry)
                 if print_results:
                     print(
                         (
@@ -262,9 +346,9 @@ def run_benchmarks(
                             else " "
                         )
                         + f"min(time): {min(run_times):>#7.2g}s, "
-                        + f"max(time): {max(run_times):>#7.2g}s, "
+                        + f"max(time): {max(run_times):>#7.2g}s"
                         + (
-                            f"peak mem.: {peak_memory:>#7.2g}MiB"
+                            f", peak mem.: {peak_memory:>#7.2g}MiB"
                             if MEMORY_PROFILER_AVAILABLE
                             else ""
                         )
@@ -288,18 +372,42 @@ def parse_args_collect_and_run_benchmarks(module=None):
         Dictionary containing timing (and potentially memory usage) results for each
         parameter set of each benchmark function.
     """
-    args = _parse_cli_arguments()
-    parameter_overrides = _parse_parameter_overrides(args.parameter_overrides)
     if module is None:
         frame = inspect.stack()[1]
         module = inspect.getmodule(frame[0])
+    args = _parse_cli_arguments(module.__doc__)
+    parameter_overrides = _parse_parameter_overrides(args.parameter_overrides)
     results = run_benchmarks(
         benchmarks=collect_benchmarks(module),
         number_runs=args.number_runs,
         number_repeats=args.repeats,
         parameter_overrides=parameter_overrides,
+        run_once_and_discard=args.run_once_and_discard,
     )
     if args.output_file is not None:
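+        # Capture package versions and system details alongside the results so
+        # that runs from different environments can be identified and compared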
+        package_versions = {
+            f"{package}_version": _get_version_or_none(package)
+            for package in ("s2fft", "jax", "numpy")
+        }
+        system_info = {
+            "architecture": platform.architecture(),
+            "machine": platform.machine(),
+            "node": platform.node(),
+            "processor": platform.processor(),
+            "python_version": platform.python_version(),
+            "release": platform.release(),
+            "system": platform.system(),
+            "cpu_info": _get_cpu_info(),
+            "gpu_info": _get_gpu_info(),
+            "cuda_info": _get_cuda_info(),
+            **package_versions,
+        }
         with open(args.output_file, "w") as f:
-            json.dump(results, f)
+            output = {
+                "date_time": datetime.datetime.now().isoformat(),
+                "benchmark_module": module.__name__,
+                "system_info": system_info,
+                "results": results,
+            }
+            json.dump(output, f, indent=True)
     return results
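
A minimal usage sketch of the new flag (the script path is hypothetical; the
flags are the ones defined above):

    python benchmarks/transform.py -number-runs 10 --run-once-and-discard -output-file results.json

The resulting JSON file then contains the timestamp, benchmark module name,
system metadata, and per-parameter-set timings under the keys built in the
output dict.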