@@ -83,10 +83,12 @@ def mean(x, n):
8383import timeit
8484import tracemalloc
8585from ast import literal_eval
86+ from collections .abc import Callable , Iterable
8687from functools import partial
8788from importlib .metadata import PackageNotFoundError , version
8889from itertools import product
8990from pathlib import Path
91+ from types import ModuleType
9092from typing import Any , NamedTuple
9193
9294import jax
@@ -97,7 +99,7 @@ class SkipBenchmarkException(Exception):
9799 """Exception to be raised to skip benchmark for some parameter set."""
98100
99101
100- def _get_version_or_none (package_name ) :
102+ def _get_version_or_none (package_name : str ) -> str | None :
101103 """Get installed version of package or `None` if package not found."""
102104 try :
103105 return version (package_name )
@@ -115,15 +117,15 @@ def _get_cpu_info():
115117 return None
116118
117119
118- def _get_gpu_memory_in_bytes (device ) :
120+ def _get_gpu_memory_in_bytes (device : jax . Device ) -> int | None :
119121 """Try to get GPU memory available in bytes."""
120122 memory_stats = device .memory_stats ()
121123 if memory_stats is None :
122124 return None
123125 return memory_stats .get ("bytes_limit" )
124126
125127
126- def _get_gpu_info ():
128+ def _get_gpu_info () -> dict [ str , str | int ] :
127129 """Get details of GPU devices available from JAX or None if JAX not available."""
128130 try :
129131 import jax
@@ -140,7 +142,7 @@ def _get_gpu_info():
140142 return None
141143
142144
143- def _get_cuda_info ():
145+ def _get_cuda_info () -> dict [ str , str ] :
144146 """Try to get information on versions of CUDA libraries."""
145147 try :
146148 from jax ._src .lib import cuda_versions
@@ -159,7 +161,7 @@ def _get_cuda_info():
159161 return None
160162
161163
162- def skip (message ) :
164+ def skip (message : str ) -> None :
163165 """Skip benchmark for a particular parameter set with explanatory message.
164166
165167 Args:
@@ -176,17 +178,22 @@ class BenchmarkSetup(NamedTuple):
176178 jit_benchmark : bool = False
177179
178180
179- def benchmark (setup = None , ** parameters ):
181+ def benchmark (
182+ setup : Callable [..., BenchmarkSetup ] | None = None , ** parameters
183+ ) -> Callable :
180184 """Decorator for defining a function to be benchmark.
181185
182186 Args:
183187 setup: Function performing any necessary set up for benchmark, and the resource
184188 usage of which will not be tracked in benchmarking. The function should
185- return a 2-tuple, with first entry a dictionary of values to pass to the
186- benchmark as keyword arguments, corresponding to any precomputed values,
187- and the second entry optionally a reference value specifying the expected
188- 'true' numerical output of the behchmarked function to allow computing
189- numerical error, or `None` if there is no relevant reference value.
189+ return an instance of `BenchmarkSetup` named tuple, with first entry a
190+ dictionary of values to pass to the benchmark as keyword arguments,
191+ corresponding to any precomputed values, the second entry optionally a
192+ reference value specifying the expected 'true' numerical output of the
193+ behchmarked function to allow computing numerical error, or `None` if there
194+ is no relevant reference value and third entry a boolean flag indicating
195+ whether to use JAX's just-in-time compilation transform to benchmark
196+ function.
190197
191198 Kwargs:
192199 Parameter names and associated lists of values over which to run benchmark.
@@ -209,12 +216,12 @@ def decorator(function):
209216 return decorator
210217
211218
212- def _parameters_string (parameters ) :
219+ def _parameters_string (parameters : dict ) -> str :
213220 """Format parameter values as string for printing benchmark results."""
214221 return "(" + ", " .join (f"{ name } : { val } " for name , val in parameters .items ()) + ")"
215222
216223
217- def _format_results_entry (results_entry ) :
224+ def _format_results_entry (results_entry : dict ) -> str :
218225 """Format benchmark results entry as a string for printing."""
219226 return (
220227 (
@@ -243,20 +250,20 @@ def _format_results_entry(results_entry):
243250 )
244251
245252
246- def _dict_product (dicts ) :
253+ def _dict_product (dicts : dict [ str , Iterable [ Any ]]) -> Iterable [ dict [ str , Any ]] :
247254 """Generator corresponding to Cartesian product of dictionaries."""
248255 return (dict (zip (dicts .keys (), values )) for values in product (* dicts .values ()))
249256
250257
251- def _parse_value (value ) :
258+ def _parse_value (value : str ) -> Any :
252259 """Parse a value passed at command line as a Python literal or string as fallback"""
253260 try :
254261 return literal_eval (value )
255262 except ValueError :
256263 return str (value )
257264
258265
259- def _parse_parameter_overrides (parameter_overrides ) :
266+ def _parse_parameter_overrides (parameter_overrides : list [ str ]) -> dict [ str , Any ] :
260267 """Parse any parameter override values passed as command line arguments"""
261268 return (
262269 {
@@ -268,7 +275,7 @@ def _parse_parameter_overrides(parameter_overrides):
268275 )
269276
270277
271- def _parse_cli_arguments (description ) :
278+ def _parse_cli_arguments (description : str ) -> argparse . Namespace :
272279 """Parse command line arguments passed for controlling benchmark runs"""
273280 parser = argparse .ArgumentParser (
274281 description = description , formatter_class = argparse .ArgumentDefaultsHelpFormatter
@@ -307,7 +314,7 @@ def _parse_cli_arguments(description):
307314 return parser .parse_args ()
308315
309316
310- def _is_benchmark (object ) :
317+ def _is_benchmark (object : Any ) -> bool :
311318 """Predicate for testing whether an object is a benchmark function or not."""
312319 return (
313320 inspect .isfunction (object )
@@ -316,7 +323,9 @@ def _is_benchmark(object):
316323 )
317324
318325
319- def collect_benchmarks (module , benchmark_names ):
326+ def collect_benchmarks (
327+ module : ModuleType , benchmark_names : list [str ]
328+ ) -> list [Callable ]:
320329 """Collect all benchmark functions from a module.
321330
322331 Args:
@@ -335,7 +344,7 @@ def collect_benchmarks(module, benchmark_names):
335344
336345
337346@contextlib .contextmanager
338- def trace_memory_allocations (n_frames = 1 ) :
347+ def trace_memory_allocations (n_frames : int = 1 ) -> Callable [[], tuple [ int , int ]] :
339348 """Context manager for tracing memory allocations in managed with block.
340349
341350 Returns a thunk (zero-argument function) which can be called on exit from with block
@@ -357,7 +366,9 @@ def trace_memory_allocations(n_frames=1):
357366 tracemalloc .stop ()
358367
359368
360- def _compile_jax_benchmark_and_analyse (benchmark_function , results_entry ):
369+ def _compile_jax_benchmark_and_analyse (
370+ benchmark_function : Callable , results_entry : dict
371+ ) -> Callable :
361372 """Compile a JAX benchmark function and extract cost estimates if available."""
362373 compiled_benchmark_function = jax .jit (benchmark_function ).lower ().compile ()
363374 cost_analysis = compiled_benchmark_function .cost_analysis ()
@@ -385,12 +396,12 @@ def _compile_jax_benchmark_and_analyse(benchmark_function, results_entry):
385396
386397
387398def run_benchmarks (
388- benchmarks ,
389- number_runs ,
390- number_repeats ,
391- print_results = True ,
392- parameter_overrides = None ,
393- ):
399+ benchmarks : list [ Callable ] ,
400+ number_runs : int ,
401+ number_repeats : int ,
402+ print_results : bool = True ,
403+ parameter_overrides : dict [ str , Any ] | None = None ,
404+ ) -> dict [ str , Any ] :
394405 """Run a set of benchmarks.
395406
396407 Args:
@@ -456,7 +467,7 @@ def run_benchmarks(
456467 return results
457468
458469
459- def get_system_info ():
470+ def get_system_info () -> dict [ str , Any ] :
460471 """Get dictionary of metadata about system.
461472
462473 Returns:
@@ -482,7 +493,9 @@ def get_system_info():
482493 }
483494
484495
485- def write_json_results_file (output_file , results , benchmark_module ):
496+ def write_json_results_file (
497+ output_file : Path , results : dict [str , Any ], benchmark_module : str
498+ ) -> None :
486499 """Write benchmark results and system information to a file in JSON format.
487500
488501 Args:
@@ -500,7 +513,7 @@ def write_json_results_file(output_file, results, benchmark_module):
500513 json .dump (output , f , indent = True )
501514
502515
503- def parse_args_collect_and_run_benchmarks (module = None ):
516+ def parse_args_collect_and_run_benchmarks (module : ModuleType | None = None ) -> None :
504517 """Collect and run all benchmarks in a module and parse command line arguments.
505518
506519 Args:
0 commit comments