Commit 1a6cce0

Add type hints to benchmarking module
Parent: 651c14f

File changed: benchmarks/benchmarking.py (+43 −30 lines)

@@ -83,10 +83,12 @@ def mean(x, n):
 import timeit
 import tracemalloc
 from ast import literal_eval
+from collections.abc import Callable, Iterable
 from functools import partial
 from importlib.metadata import PackageNotFoundError, version
 from itertools import product
 from pathlib import Path
+from types import ModuleType
 from typing import Any, NamedTuple
 
 import jax
@@ -97,7 +99,7 @@ class SkipBenchmarkException(Exception):
     """Exception to be raised to skip benchmark for some parameter set."""
 
 
-def _get_version_or_none(package_name):
+def _get_version_or_none(package_name: str) -> str | None:
     """Get installed version of package or `None` if package not found."""
     try:
         return version(package_name)
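
The annotation documents the helper's existing contract: a version string when the package resolves, `None` otherwise. A usage sketch (illustrative, not part of the commit):

    print(_get_version_or_none("jax"))             # e.g. "0.4.30"
    print(_get_version_or_none("no-such-package")) # None
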
@@ -115,15 +117,15 @@ def _get_cpu_info():
         return None
 
 
-def _get_gpu_memory_in_bytes(device):
+def _get_gpu_memory_in_bytes(device: jax.Device) -> int | None:
     """Try to get GPU memory available in bytes."""
     memory_stats = device.memory_stats()
     if memory_stats is None:
         return None
     return memory_stats.get("bytes_limit")
 
 
-def _get_gpu_info():
+def _get_gpu_info() -> dict[str, str | int]:
     """Get details of GPU devices available from JAX or None if JAX not available."""
     try:
         import jax
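
The new `jax.Device` annotation reflects what the helper consumes: JAX device objects expose `memory_stats()`, which returns a dictionary of allocator counters, or `None` on backends without stats. A hedged sketch of the same query done by hand:

    import jax

    for device in jax.devices():
        stats = device.memory_stats()  # dict of counters, or None
        limit = None if stats is None else stats.get("bytes_limit")
        print(device, limit)
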
@@ -140,7 +142,7 @@ def _get_gpu_info():
         return None
 
 
-def _get_cuda_info():
+def _get_cuda_info() -> dict[str, str]:
     """Try to get information on versions of CUDA libraries."""
     try:
         from jax._src.lib import cuda_versions
@@ -159,7 +161,7 @@ def _get_cuda_info():
         return None
 
 
-def skip(message):
+def skip(message: str) -> None:
     """Skip benchmark for a particular parameter set with explanatory message.
 
     Args:
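
`skip` presumably raises the `SkipBenchmarkException` defined earlier so the runner can drop that parameter set. A hypothetical guard inside a setup function (the parameter name and threshold are invented for illustration):

    def setup_solver(dimension):
        if dimension > 10_000:
            skip(f"dimension {dimension} too large for available memory")
        ...
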
@@ -176,17 +178,22 @@ class BenchmarkSetup(NamedTuple):
     jit_benchmark: bool = False
 
 
-def benchmark(setup=None, **parameters):
+def benchmark(
+    setup: Callable[..., BenchmarkSetup] | None = None, **parameters
+) -> Callable:
     """Decorator for defining a function to be benchmarked.
 
     Args:
         setup: Function performing any necessary set up for benchmark, and the resource
             usage of which will not be tracked in benchmarking. The function should
-            return a 2-tuple, with first entry a dictionary of values to pass to the
-            benchmark as keyword arguments, corresponding to any precomputed values,
-            and the second entry optionally a reference value specifying the expected
-            'true' numerical output of the benchmarked function to allow computing
-            numerical error, or `None` if there is no relevant reference value.
+            return an instance of the `BenchmarkSetup` named tuple, with first entry a
+            dictionary of values to pass to the benchmark as keyword arguments,
+            corresponding to any precomputed values, the second entry optionally a
+            reference value specifying the expected 'true' numerical output of the
+            benchmarked function to allow computing numerical error, or `None` if
+            there is no relevant reference value, and the third entry a boolean flag
+            indicating whether to use JAX's just-in-time compilation transform when
+            benchmarking the function.
 
     Kwargs:
         Parameter names and associated lists of values over which to run benchmark.
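
To make the revised contract concrete, here is a sketch of a benchmark definition. Only the `jit_benchmark` field name is visible in this diff, so `BenchmarkSetup` is built positionally, and the convention of passing both precomputed values and parameters to the benchmark is assumed:

    import numpy as np

    def setup_mean(size):
        x = np.random.default_rng(1234).standard_normal(size)
        reference = x.mean()  # expected 'true' output, used for numerical error
        return BenchmarkSetup({"x": x}, reference, False)

    @benchmark(setup=setup_mean, size=[10, 100, 1000])
    def mean_benchmark(x, size):
        return x.mean()
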
@@ -209,12 +216,12 @@ def decorator(function):
     return decorator
 
 
-def _parameters_string(parameters):
+def _parameters_string(parameters: dict) -> str:
     """Format parameter values as string for printing benchmark results."""
     return "(" + ", ".join(f"{name}: {val}" for name, val in parameters.items()) + ")"
 
 
-def _format_results_entry(results_entry):
+def _format_results_entry(results_entry: dict) -> str:
     """Format benchmark results entry as a string for printing."""
     return (
         (
@@ -243,20 +250,20 @@ def _format_results_entry(results_entry):
     )
 
 
-def _dict_product(dicts):
+def _dict_product(dicts: dict[str, Iterable[Any]]) -> Iterable[dict[str, Any]]:
     """Generator corresponding to Cartesian product of dictionaries."""
     return (dict(zip(dicts.keys(), values)) for values in product(*dicts.values()))
 
 
-def _parse_value(value):
+def _parse_value(value: str) -> Any:
     """Parse a value passed at command line as a Python literal or string as fallback"""
     try:
         return literal_eval(value)
     except ValueError:
         return str(value)
 
 
-def _parse_parameter_overrides(parameter_overrides):
+def _parse_parameter_overrides(parameter_overrides: list[str]) -> dict[str, Any]:
     """Parse any parameter override values passed as command line arguments"""
     return (
         {
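
The annotations on `_dict_product` and `_parse_value` pin down behaviour that follows directly from the definitions above:

    >>> list(_dict_product({"size": [10, 100], "dtype": ["float32", "float64"]}))
    [{'size': 10, 'dtype': 'float32'}, {'size': 10, 'dtype': 'float64'},
     {'size': 100, 'dtype': 'float32'}, {'size': 100, 'dtype': 'float64'}]
    >>> _parse_value("1e-3")     # a valid Python literal
    0.001
    >>> _parse_value("float32")  # not a literal, falls back to the string
    'float32'
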
@@ -268,7 +275,7 @@ def _parse_parameter_overrides(parameter_overrides):
     )
 
 
-def _parse_cli_arguments(description):
+def _parse_cli_arguments(description: str) -> argparse.Namespace:
     """Parse command line arguments passed for controlling benchmark runs"""
     parser = argparse.ArgumentParser(
         description=description, formatter_class=argparse.ArgumentDefaultsHelpFormatter
@@ -307,7 +314,7 @@ def _parse_cli_arguments(description):
     return parser.parse_args()
 
 
-def _is_benchmark(object):
+def _is_benchmark(object: Any) -> bool:
     """Predicate for testing whether an object is a benchmark function or not."""
     return (
         inspect.isfunction(object)
@@ -316,7 +323,9 @@ def _is_benchmark(object):
     )
 
 
-def collect_benchmarks(module, benchmark_names):
+def collect_benchmarks(
+    module: ModuleType, benchmark_names: list[str]
+) -> list[Callable]:
     """Collect all benchmark functions from a module.
 
     Args:
@@ -335,7 +344,7 @@ def collect_benchmarks(module, benchmark_names):
 
 
 @contextlib.contextmanager
-def trace_memory_allocations(n_frames=1):
+def trace_memory_allocations(n_frames: int = 1) -> Callable[[], tuple[int, int]]:
     """Context manager for tracing memory allocations in managed with block.
 
     Returns a thunk (zero-argument function) which can be called on exit from with block
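
Assumed usage, matching the docstring and the new `Callable[[], tuple[int, int]]` return annotation — the yielded thunk is called after the block exits and presumably reports a (current, peak) pair of traced bytes from `tracemalloc`:

    with trace_memory_allocations() as traced:
        data = [0] * 1_000_000
    current_bytes, peak_bytes = traced()
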
@@ -357,7 +366,9 @@ def trace_memory_allocations(n_frames=1):
         tracemalloc.stop()
 
 
-def _compile_jax_benchmark_and_analyse(benchmark_function, results_entry):
+def _compile_jax_benchmark_and_analyse(
+    benchmark_function: Callable, results_entry: dict
+) -> Callable:
     """Compile a JAX benchmark function and extract cost estimates if available."""
     compiled_benchmark_function = jax.jit(benchmark_function).lower().compile()
     cost_analysis = compiled_benchmark_function.cost_analysis()
@@ -385,12 +396,12 @@ def _compile_jax_benchmark_and_analyse(benchmark_function, results_entry):
 
 
 def run_benchmarks(
-    benchmarks,
-    number_runs,
-    number_repeats,
-    print_results=True,
-    parameter_overrides=None,
-):
+    benchmarks: list[Callable],
+    number_runs: int,
+    number_repeats: int,
+    print_results: bool = True,
+    parameter_overrides: dict[str, Any] | None = None,
+) -> dict[str, Any]:
     """Run a set of benchmarks.
 
     Args:
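
A hypothetical driver for the annotated signature; the module name is invented, and the run/repeat semantics are assumed to mirror `timeit.repeat`:

    import my_benchmarks  # a module defining @benchmark-decorated functions

    results = run_benchmarks(
        benchmarks=collect_benchmarks(my_benchmarks, benchmark_names=["mean"]),
        number_runs=10,
        number_repeats=3,
        print_results=True,
        parameter_overrides={"size": 100},
    )
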
@@ -456,7 +467,7 @@ def run_benchmarks(
     return results
 
 
-def get_system_info():
+def get_system_info() -> dict[str, Any]:
     """Get dictionary of metadata about system.
 
     Returns:
@@ -482,7 +493,9 @@ def get_system_info():
     }
 
 
-def write_json_results_file(output_file, results, benchmark_module):
+def write_json_results_file(
+    output_file: Path, results: dict[str, Any], benchmark_module: str
+) -> None:
     """Write benchmark results and system information to a file in JSON format.
 
     Args:
@@ -500,7 +513,7 @@ def write_json_results_file(output_file, results, benchmark_module):
         json.dump(output, f, indent=True)
 
 
-def parse_args_collect_and_run_benchmarks(module=None):
+def parse_args_collect_and_run_benchmarks(module: ModuleType | None = None) -> None:
     """Collect and run all benchmarks in a module and parse command line arguments.
 
     Args:
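
A typical entry point for a benchmark script built on this module, assuming (not shown in this hunk) that `module=None` makes the function fall back to the calling module:

    if __name__ == "__main__":
        parse_args_collect_and_run_benchmarks()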
