diff --git a/benchmarks/README.md b/benchmarks/README.md index bbcd39f3..ab449949 100644 --- a/benchmarks/README.md +++ b/benchmarks/README.md @@ -8,15 +8,24 @@ and/or systems. If the [`memory_profiler` package](https://github.com/pythonprofilers/memory_profiler) is installed an estimate of the peak (main) memory usage of the benchmarked functions will also be recorded. +If the [`py-cpuinfo` package](https://pypi.org/project/py-cpuinfo/) +is installed additional information about the CPU of the system benchmarks are run on will be +recorded in JSON output. ## Description The benchmark scripts are as follows: - * `wigner.py` contains benchmarks for Wigner transforms (forward and inverse) - * `spherical.py` contains benchmarks for spherical transforms (forward and inverse) - + * `spherical.py` contains benchmarks for on-the-fly implementations of spherical + transforms (forward and inverse). + * `precompute_spherical.py` contains benchmarks for precompute implementations of + spherical transforms (forward and inverse). + * `wigner.py` contains benchmarks for on-the-fly implementations of Wigner-d + transforms (forward and inverse). + * `precompute_wigner.py` contains benchmarks for precompute implementations of + Wigner-d transforms (forward and inverse). + The `benchmarking.py` module contains shared utility functions for defining and running the benchmarks. @@ -29,22 +38,26 @@ the JSON formatted benchmark results to. Pass a `--help` argument to the script display the usage message: ``` -usage: Run benchmarks [-h] [-number-runs NUMBER_RUNS] [-repeats REPEATS] - [-parameter-overrides [PARAMETER_OVERRIDES [PARAMETER_OVERRIDES ...]]] - [-output-file OUTPUT_FILE] +usage: spherical.py [-h] [-number-runs NUMBER_RUNS] [-repeats REPEATS] + [-parameter-overrides [PARAMETER_OVERRIDES ...]] [-output-file OUTPUT_FILE] + [--run-once-and-discard] + +Benchmarks for on-the-fly spherical transforms. 
-optional arguments: +options: -h, --help show this help message and exit -number-runs NUMBER_RUNS - Number of times to run the benchmark in succession in each - timing run. - -repeats REPEATS Number of times to repeat the benchmark runs. - -parameter-overrides [PARAMETER_OVERRIDES [PARAMETER_OVERRIDES ...]] - Override for values to use for benchmark parameter. A parameter - name followed by space separated list of values to use. May be - specified multiple times to override multiple parameters. + Number of times to run the benchmark in succession in each timing run. (default: 10) + -repeats REPEATS Number of times to repeat the benchmark runs. (default: 3) + -parameter-overrides [PARAMETER_OVERRIDES ...] + Override for values to use for benchmark parameter. A parameter name followed by space + separated list of values to use. May be specified multiple times to override multiple + parameters. (default: None) -output-file OUTPUT_FILE - File path to write JSON formatted results to. + File path to write JSON formatted results to. (default: None) + --run-once-and-discard + Run benchmark function once first without recording time to ignore the effect of any initial + one-off costs such as just-in-time compilation. 
(default: False) ``` For example to run the spherical transform benchmarks using only the JAX implementations, @@ -52,7 +65,7 @@ running on a CPU (in double-precision) for `L` values 64, 128, 256, 512 and 1024 would run from the root of the repository: ```sh -JAX_PLATFORM_NAME=cpu JAX_ENABLE_X64=1 python benchmarks/spherical.py -p L 64 128 256 512 1024 -p method jax +JAX_PLATFORM_NAME=cpu JAX_ENABLE_X64=1 python benchmarks/spherical.py --run-once-and-discard -p L 64 128 256 512 1024 -p method jax ``` Note the usage of environment variables `JAX_PLATFORM_NAME` and `JAX_ENABLE_X64` to diff --git a/benchmarks/benchmarking.py b/benchmarks/benchmarking.py index 5a9496fb..faea7c1b 100644 --- a/benchmarks/benchmarking.py +++ b/benchmarks/benchmarking.py @@ -60,11 +60,14 @@ def mean(x, n): """ import argparse +import datetime import inspect import json +import platform import timeit from ast import literal_eval from functools import partial +from importlib.metadata import PackageNotFoundError, version from itertools import product from pathlib import Path @@ -80,6 +83,69 @@ class SkipBenchmarkException(Exception): """Exception to be raised to skip benchmark for some parameter set.""" +def _get_version_or_none(package_name): + """Get installed version of package or `None` if package not found.""" + try: + return version(package_name) + except PackageNotFoundError: + return None + + +def _get_cpu_info(): + """Get details of CPU from cpuinfo if available or None if not.""" + try: + import cpuinfo + + return cpuinfo.get_cpu_info() + except ImportError: + return None + + +def _get_gpu_memory_mebibytes(device): + """Try to get GPU memory available in mebibytes (MiB).""" + memory_stats = device.memory_stats() + if memory_stats is None: + return None + bytes_limit = memory_stats.get("bytes_limit") + return bytes_limit // 2**20 if bytes_limit is not None else None + + +def _get_gpu_info(): + """Get details of GPU devices available from JAX or None if JAX not available.""" + try: + 
import jax + + return [ + { + "kind": d.device_kind, + "memory_available / MiB": _get_gpu_memory_mebibytes(d), + } + for d in jax.devices() + if d.platform == "gpu" + ] + except ImportError: + return None + + +def _get_cuda_info(): + """Try to get information on versions of CUDA libraries.""" + try: + from jax._src.lib import cuda_versions + + if cuda_versions is None: + return None + return { + "cuda_runtime_version": cuda_versions.cuda_runtime_get_version(), + "cuda_runtime_build_version": cuda_versions.cuda_runtime_build_version(), + "cudnn_version": cuda_versions.cudnn_get_version(), + "cudnn_build_version": cuda_versions.cudnn_build_version(), + "cufft_version": cuda_versions.cufft_get_version(), + "cufft_build_version": cuda_versions.cufft_build_version(), + } + except ImportError: + return None + + def skip(message): """Skip benchmark for a particular parameter set with explanatory message. @@ -145,9 +211,11 @@ def _parse_parameter_overrides(parameter_overrides): ) -def _parse_cli_arguments(): +def _parse_cli_arguments(description): """Parse command line arguments passed for controlling benchmark runs""" - parser = argparse.ArgumentParser("Run benchmarks") + parser = argparse.ArgumentParser( + description=description, formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) parser.add_argument( "-number-runs", type=int, @@ -174,6 +242,15 @@ def _parse_cli_arguments(): parser.add_argument( "-output-file", type=Path, help="File path to write JSON formatted results to." ) + parser.add_argument( + "--run-once-and-discard", + action="store_true", + help=( + "Run benchmark function once first without recording time to " + "ignore the effect of any initial one-off costs such as just-in-time " + "compilation." + ), + ) return parser.parse_args() @@ -204,6 +281,7 @@ def run_benchmarks( number_repeats, print_results=True, parameter_overrides=None, + run_once_and_discard=False, ): """Run a set of benchmarks. 
@@ -217,6 +295,9 @@ def run_benchmarks( print_results: Whether to print benchmark results to stdout. parameter_overrides: Dictionary specifying any overrides for parameter values set in `benchmark` decorator. + run_once_and_discard: Whether to run benchmark function once first without + recording time to ignore the effect of any initial one-off costs such as + just-in-time compilation. Returns: Dictionary containing timing (and potentially memory usage) results for each @@ -224,7 +305,7 @@ def run_benchmarks( """ results = {} for benchmark in benchmarks: - results[benchmark.__name__] = {} + results[benchmark.__name__] = [] if print_results: print(benchmark.__name__) parameters = benchmark.parameters.copy() @@ -234,13 +315,15 @@ def run_benchmarks( try: precomputes = benchmark.setup(**parameter_set) benchmark_function = partial(benchmark, **precomputes, **parameter_set) + if run_once_and_discard: + benchmark_function() run_times = [ time / number_runs for time in timeit.repeat( benchmark_function, number=number_runs, repeat=number_repeats ) ] - results[benchmark.__name__] = {**parameter_set, "times / s": run_times} + results_entry = {**parameter_set, "times / s": run_times} if MEMORY_PROFILER_AVAILABLE: baseline_memory = memory_profiler.memory_usage(max_usage=True) peak_memory = ( @@ -253,7 +336,8 @@ def run_benchmarks( ) - baseline_memory ) - results[benchmark.__name__]["peak_memory / MiB"] = peak_memory + results_entry["peak_memory / MiB"] = peak_memory + results[benchmark.__name__].append(results_entry) if print_results: print( ( @@ -262,9 +346,9 @@ def run_benchmarks( else " " ) + f"min(time): {min(run_times):>#7.2g}s, " - + f"max(time): {max(run_times):>#7.2g}s, " + + f"max(time): {max(run_times):>#7.2g}s" + ( - f"peak mem.: {peak_memory:>#7.2g}MiB" + f", peak mem.: {peak_memory:>#7.2g}MiB" if MEMORY_PROFILER_AVAILABLE else "" ) @@ -288,18 +372,42 @@ def parse_args_collect_and_run_benchmarks(module=None): Dictionary containing timing (and potentially memory 
usage) results for each parameters set of each benchmark function. """ - args = _parse_cli_arguments() - parameter_overrides = _parse_parameter_overrides(args.parameter_overrides) if module is None: frame = inspect.stack()[1] module = inspect.getmodule(frame[0]) + args = _parse_cli_arguments(module.__doc__) + parameter_overrides = _parse_parameter_overrides(args.parameter_overrides) results = run_benchmarks( benchmarks=collect_benchmarks(module), number_runs=args.number_runs, number_repeats=args.repeats, parameter_overrides=parameter_overrides, + run_once_and_discard=args.run_once_and_discard, ) if args.output_file is not None: + package_versions = { + f"{package}_version": _get_version_or_none(package) + for package in ("s2fft", "jax", "numpy") + } + system_info = { + "architecture": platform.architecture(), + "machine": platform.machine(), + "node": platform.node(), + "processor": platform.processor(), + "python_version": platform.python_version(), + "release": platform.release(), + "system": platform.system(), + "cpu_info": _get_cpu_info(), + "gpu_info": _get_gpu_info(), + "cuda_info": _get_cuda_info(), + **package_versions, + } with open(args.output_file, "w") as f: - json.dump(results, f) + output = { + "date_time": datetime.datetime.now().isoformat(), + "benchmark_module": module.__name__, + "system_info": system_info, + "results": results, + } + json.dump(output, f, indent=True) return results diff --git a/benchmarks/precompute_spherical.py b/benchmarks/precompute_spherical.py new file mode 100644 index 00000000..1af2a2aa --- /dev/null +++ b/benchmarks/precompute_spherical.py @@ -0,0 +1,115 @@ +"""Benchmarks for precompute spherical transforms.""" + +import numpy as np +import pyssht +from benchmarking import benchmark, parse_args_collect_and_run_benchmarks, skip + +import s2fft +import s2fft.precompute_transforms +from s2fft.sampling import s2_samples as samples + +L_VALUES = [8, 16, 32, 64, 128, 256] +SPIN_VALUES = [0] +SAMPLING_VALUES = ["mw"] 
+METHOD_VALUES = ["numpy", "jax"] +REALITY_VALUES = [True] +RECURSION_VALUES = ["auto"] + + +def setup_forward(method, L, sampling, spin, reality, recursion): + if reality and spin != 0: + skip("Reality only valid for scalar fields (spin=0).") + rng = np.random.default_rng() + flm = s2fft.utils.signal_generator.generate_flm(rng, L, spin=spin, reality=reality) + f = pyssht.inverse( + samples.flm_2d_to_1d(flm, L), + L, + Method=sampling.upper(), + Spin=spin, + Reality=reality, + ) + kernel_function = ( + s2fft.precompute_transforms.construct.spin_spherical_kernel_jax + if method == "jax" + else s2fft.precompute_transforms.construct.spin_spherical_kernel + ) + kernel = kernel_function( + L=L, + spin=spin, + reality=reality, + sampling=sampling, + forward=True, + recursion=recursion, + ) + return {"f": f, "kernel": kernel} + + +@benchmark( + setup_forward, + method=METHOD_VALUES, + L=L_VALUES, + sampling=SAMPLING_VALUES, + spin=SPIN_VALUES, + reality=REALITY_VALUES, + recursion=RECURSION_VALUES, +) +def forward(f, kernel, method, L, sampling, spin, reality, recursion): + flm = s2fft.precompute_transforms.spherical.forward( + f=f, + L=L, + spin=spin, + kernel=kernel, + sampling=sampling, + reality=reality, + method=method, + ) + if method == "jax": + flm.block_until_ready() + + +def setup_inverse(method, L, sampling, spin, reality, recursion): + if reality and spin != 0: + skip("Reality only valid for scalar fields (spin=0).") + rng = np.random.default_rng() + flm = s2fft.utils.signal_generator.generate_flm(rng, L, spin=spin, reality=reality) + kernel_function = ( + s2fft.precompute_transforms.construct.spin_spherical_kernel_jax + if method == "jax" + else s2fft.precompute_transforms.construct.spin_spherical_kernel + ) + kernel = kernel_function( + L=L, + spin=spin, + reality=reality, + sampling=sampling, + forward=False, + recursion=recursion, + ) + return {"flm": flm, "kernel": kernel} + + +@benchmark( + setup_inverse, + method=METHOD_VALUES, + L=L_VALUES, + 
sampling=SAMPLING_VALUES, + spin=SPIN_VALUES, + reality=REALITY_VALUES, + recursion=RECURSION_VALUES, +) +def inverse(flm, kernel, method, L, sampling, spin, reality, recursion): + f = s2fft.precompute_transforms.spherical.inverse( + flm=flm, + L=L, + spin=spin, + kernel=kernel, + sampling=sampling, + reality=reality, + method=method, + ) + if method == "jax": + f.block_until_ready() + + +if __name__ == "__main__": + results = parse_args_collect_and_run_benchmarks() diff --git a/benchmarks/precompute_wigner.py b/benchmarks/precompute_wigner.py new file mode 100644 index 00000000..01918ca7 --- /dev/null +++ b/benchmarks/precompute_wigner.py @@ -0,0 +1,104 @@ +"""Benchmarks for precompute Wigner-d transforms.""" + +import numpy as np +from benchmarking import benchmark, parse_args_collect_and_run_benchmarks + +import s2fft +import s2fft.precompute_transforms +from s2fft.base_transforms import wigner as base_wigner + +L_VALUES = [16, 32, 64, 128, 256] +N_VALUES = [2] +L_LOWER_VALUES = [0] +SAMPLING_VALUES = ["mw"] +METHOD_VALUES = ["numpy", "jax"] +REALITY_VALUES = [True] +MODE_VALUES = ["auto"] + + +def setup_forward(method, L, N, L_lower, sampling, reality, mode): + rng = np.random.default_rng() + flmn = s2fft.utils.signal_generator.generate_flmn(rng, L, N, reality=reality) + f = base_wigner.inverse( + flmn, + L, + N, + L_lower=L_lower, + sampling=sampling, + reality=reality, + ) + kernel_function = ( + s2fft.precompute_transforms.construct.wigner_kernel_jax + if method == "jax" + else s2fft.precompute_transforms.construct.wigner_kernel + ) + kernel = kernel_function( + L=L, N=N, reality=reality, sampling=sampling, forward=True, mode=mode + ) + return {"f": f, "kernel": kernel} + + +@benchmark( + setup_forward, + method=METHOD_VALUES, + L=L_VALUES, + N=N_VALUES, + L_lower=L_LOWER_VALUES, + sampling=SAMPLING_VALUES, + reality=REALITY_VALUES, + mode=MODE_VALUES, +) +def forward(f, kernel, method, L, N, L_lower, sampling, reality, mode): + flmn = 
s2fft.precompute_transforms.wigner.forward( + f=f, + L=L, + N=N, + kernel=kernel, + sampling=sampling, + reality=reality, + method=method, + ) + if method == "jax": + flmn.block_until_ready() + + +def setup_inverse(method, L, N, L_lower, sampling, reality, mode): + rng = np.random.default_rng() + flmn = s2fft.utils.signal_generator.generate_flmn(rng, L, N, reality=reality) + kernel_function = ( + s2fft.precompute_transforms.construct.wigner_kernel_jax + if method == "jax" + else s2fft.precompute_transforms.construct.wigner_kernel + ) + kernel = kernel_function( + L=L, N=N, reality=reality, sampling=sampling, forward=False, mode=mode + ) + return {"flmn": flmn, "kernel": kernel} + + +@benchmark( + setup_inverse, + method=METHOD_VALUES, + L=L_VALUES, + N=N_VALUES, + L_lower=L_LOWER_VALUES, + sampling=SAMPLING_VALUES, + reality=REALITY_VALUES, + mode=MODE_VALUES, +) +def inverse(flmn, kernel, method, L, N, L_lower, sampling, reality, mode): + f = s2fft.precompute_transforms.wigner.inverse( + flmn=flmn, + L=L, + N=N, + kernel=kernel, + sampling=sampling, + reality=reality, + method=method, + ) + if method == "jax": + f.block_until_ready() + + +if __name__ == "__main__": + results = parse_args_collect_and_run_benchmarks() diff --git a/benchmarks/spherical.py b/benchmarks/spherical.py index c0f529e0..78231a43 100644 --- a/benchmarks/spherical.py +++ b/benchmarks/spherical.py @@ -1,11 +1,11 @@ -"""Benchmarks for spherical transforms.""" +"""Benchmarks for on-the-fly spherical transforms.""" import numpy as np import pyssht from benchmarking import benchmark, parse_args_collect_and_run_benchmarks, skip import s2fft -from s2fft.recursions.price_mcewen import generate_precomputes +from s2fft.recursions.price_mcewen import generate_precomputes_jax from s2fft.sampling import s2_samples as samples L_VALUES = [8, 16, 32, 64, 128, 256] @@ -17,6 +17,10 @@ SPMD_VALUES = [False] +def _jax_arrays_to_numpy(precomps): + return [np.asarray(p) for p in precomps] + + def 
setup_forward(method, L, L_lower, sampling, spin, reality, spmd): if reality and spin != 0: skip("Reality only valid for scalar fields (spin=0).") @@ -31,7 +35,11 @@ def setup_forward(method, L, L_lower, sampling, spin, reality, spmd): Spin=spin, Reality=reality, ) - precomps = generate_precomputes(L, spin, sampling, forward=True, L_lower=L_lower) + precomps = generate_precomputes_jax( + L, spin, sampling, forward=True, L_lower=L_lower + ) + if method == "numpy": + precomps = _jax_arrays_to_numpy(precomps) return {"f": f, "precomps": precomps} @@ -71,7 +79,11 @@ def setup_inverse(method, L, L_lower, sampling, spin, reality, spmd): skip("GPU distribution only valid for JAX.") rng = np.random.default_rng() flm = s2fft.utils.signal_generator.generate_flm(rng, L, spin=spin, reality=reality) - precomps = generate_precomputes(L, spin, sampling, forward=False, L_lower=L_lower) + precomps = generate_precomputes_jax( + L, spin, sampling, forward=False, L_lower=L_lower + ) + if method == "numpy": + precomps = _jax_arrays_to_numpy(precomps) return {"flm": flm, "precomps": precomps} diff --git a/benchmarks/wigner.py b/benchmarks/wigner.py index 717501e3..d6180961 100644 --- a/benchmarks/wigner.py +++ b/benchmarks/wigner.py @@ -1,7 +1,7 @@ -"""Benchmarks for Wigner transforms.""" +"""Benchmarks for on-the-fly Wigner-d transforms.""" import numpy as np -from benchmarking import benchmark, parse_args_collect_and_run_benchmarks, skip +from benchmarking import benchmark, parse_args_collect_and_run_benchmarks import s2fft from s2fft.base_transforms import wigner as base_wigner @@ -16,12 +16,9 @@ SAMPLING_VALUES = ["mw"] METHOD_VALUES = ["numpy", "jax"] REALITY_VALUES = [True] -SPMD_VALUES = [False] -def setup_forward(method, L, L_lower, N, sampling, reality, spmd): - if spmd and method != "jax": - skip("GPU distribution only valid for JAX.") +def setup_forward(method, L, L_lower, N, sampling, reality): rng = np.random.default_rng() flmn = 
s2fft.utils.signal_generator.generate_flmn(rng, L, N, reality=reality) f = base_wigner.inverse( @@ -51,27 +48,23 @@ def setup_forward(method, L, L_lower, N, sampling, reality, spmd): N=N_VALUES, sampling=SAMPLING_VALUES, reality=REALITY_VALUES, - spmd=SPMD_VALUES, ) -def forward(f, precomps, method, L, L_lower, N, sampling, reality, spmd): +def forward(f, precomps, method, L, L_lower, N, sampling, reality): flmn = s2fft.transforms.wigner.forward( f=f, L=L, - L_lower=L_lower, N=N, - precomps=precomps, sampling=sampling, - reality=reality, method=method, - spmd=spmd, + reality=reality, + precomps=precomps, + L_lower=L_lower, ) if method == "jax": flmn.block_until_ready() -def setup_inverse(method, L, L_lower, N, sampling, reality, spmd): - if spmd and method != "jax": - skip("GPU distribution only valid for JAX.") +def setup_inverse(method, L, L_lower, N, sampling, reality): rng = np.random.default_rng() flmn = s2fft.utils.signal_generator.generate_flmn(rng, L, N, reality=reality) generate_precomputes = ( @@ -93,19 +86,17 @@ def setup_inverse(method, L, L_lower, N, sampling, reality, spmd): N=N_VALUES, sampling=SAMPLING_VALUES, reality=REALITY_VALUES, - spmd=SPMD_VALUES, ) -def inverse(flmn, precomps, method, L, L_lower, N, sampling, reality, spmd): - f = s2fft.transforms.spherical.inverse( - flm=flmn, +def inverse(flmn, precomps, method, L, L_lower, N, sampling, reality): + f = s2fft.transforms.wigner.inverse( + flmn=flmn, L=L, - L_lower=L_lower, N=N, - precomps=precomps, sampling=sampling, reality=reality, method=method, - spmd=spmd, + precomps=precomps, + L_lower=L_lower, ) if method == "jax": f.block_until_ready()