diff --git a/benchmarks/README.md b/benchmarks/README.md index ab449949..bb50c0be 100644 --- a/benchmarks/README.md +++ b/benchmarks/README.md @@ -1,18 +1,14 @@ # Benchmarks for `s2fft` -Scripts for benchmarking `s2fft` with `timeit` (and optionally `memory_profiler`). +Scripts for benchmarking `s2fft` transforms. Measures time to compute transforms for grids of parameters settings, optionally outputting the results to a JSON file to allow comparing performance over versions -and/or systems. -If the [`memory_profiler` package](https://github.com/pythonprofilers/memory_profiler) -is installed an estimate of the peak (main) memory usage of the benchmarked functions -will also be recorded. +and/or systems. If the [`py-cpuinfo` package](https://pypi.org/project/py-cpuinfo/) is installed additional information about CPU of system benchmarks are run on will be recorded in JSON output. - ## Description The benchmark scripts are as follows: @@ -40,7 +36,7 @@ display the usage message: ``` usage: spherical.py [-h] [-number-runs NUMBER_RUNS] [-repeats REPEATS] [-parameter-overrides [PARAMETER_OVERRIDES ...]] [-output-file OUTPUT_FILE] - [--run-once-and-discard] + [-benchmarks BENCHMARKS [BENCHMARKS ...]] Benchmarks for on-the-fly spherical transforms. @@ -55,9 +51,8 @@ options: parameters. (default: None) -output-file OUTPUT_FILE File path to write JSON formatted results to. (default: None) - --run-once-and-discard - Run benchmark function once first without recording time to ignore the effect of any initial - one-off costs such as just-in-time compilation. (default: False) + -benchmarks BENCHMARKS [BENCHMARKS ...] + Names of benchmark functions to run. All benchmarks are run if omitted. (default: None) ``` For example to run the spherical transform benchmarks using only the JAX implementations, @@ -65,7 +60,7 @@ running on a CPU (in double-precision) for `L` values 64, 128, 256, 512 and 1024 would run from the root of the repository: ```sh -JAX_PLATFORM_NAME=cpu JAX_ENABLE_X64=1 python benchmarks/spherical.py --run-once-and-discard -p L 64 128 256 512 1024 -p method jax +JAX_PLATFORM_NAME=cpu JAX_ENABLE_X64=1 python benchmarks/spherical.py -p L 64 128 256 512 1024 -p method jax ``` Note the usage of environment variables `JAX_PLATFORM_NAME` and `JAX_ENABLE_X64` to diff --git a/benchmarks/benchmarking.py b/benchmarks/benchmarking.py index faea7c1b..bbf4a312 100644 --- a/benchmarks/benchmarking.py +++ b/benchmarks/benchmarking.py @@ -1,25 +1,40 @@ """Helper functions and classes for benchmarking. -Functions to be benchmarked in a module should be decorated with `benchmark` which -takes one positional argument corresponding to a function to peform any necessary -set up for the benchmarked function (returning a dictionary, potentially empty with -any precomputed values to pass to benchmark function as keyword arguments) and zero -or more keyword arguments specifying parameter names and values lists to benchmark -over (the Cartesian product of all specified parameter values is used). The benchmark -function is passed the union of any precomputed values outputted by the setup function -and the parameters values as keyword arguments. 
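The Cartesian-product expansion of the parameter value lists mentioned above works as in the following sketch (illustrative only; it mirrors the `_dict_product` helper defined later in this module):

```Python
from itertools import product

# Every combination of the per-parameter value lists yields one benchmark run,
# e.g. two bandlimits crossed with two methods gives four parameter sets.
parameters = {"L": [64, 128], "method": ["numpy", "jax"]}
parameter_sets = [
    dict(zip(parameters.keys(), values)) for values in product(*parameters.values())
]
# -> [{'L': 64, 'method': 'numpy'}, {'L': 64, 'method': 'jax'},
#     {'L': 128, 'method': 'numpy'}, {'L': 128, 'method': 'jax'}]
```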
- -As a simple example, the following defines a benchmarkfor computing the mean of a list +Functions to be benchmarked in a module should be decorated with `benchmark` which takes +one positional argument corresponding to a function to perform any necessary set up for +the benchmarked function and zero or more keyword arguments specifying parameter names +and values lists to benchmark over (the Cartesian product of all specified parameter +values is used). + +The benchmark setup function should return an instance of the named tuple type +`BenchmarkSetup` which consists of a required dictionary entry, potentially empty with +any precomputed values to pass to benchmark function as keyword arguments, and two +further optional entries: the first a reference value for the output of the benchmarked +function to use to compute numerical error if applicable, defaulting to `None` +indicating no applicable reference value; and the second a flag indicating whether to +just-in-time compile the benchmark function using JAX's `jit` transform, defaulting to +`False`. + +The benchmark function is passed the union of any precomputed values returned by the +setup function and the parameters values as keyword arguments. If a reference output +value is set by the setup function the benchmark function should output the value to +compare to this reference value by computing the maximum absolute elementwise +difference. If the function is to be just-in-time compiled using JAX the value returned +by the benchmark function should be a JAX array on which the `block_until_ready` method +may be called to ensure the function only exits once the relevant computation has +completed (necessary due to JAX's asynchronous dispatch computation model). + +As a simple example, the following defines a benchmark for computing the mean of a list of numbers. ```Python import random -from benchmarking import benchmark +from benchmarking import BenchmarkSetup, benchmark def setup_mean(n): - return {"x": [random.random() for _ in range(n)]} + return BenchmarkSetup({"x": [random.random() for _ in range(n)]}) -@benchmark(setup_computation, n=[1, 2, 3, 4]) +@benchmark(setup_mean, n=[1, 2, 3, 4]) def mean(x, n): return sum(x) / n ``` @@ -29,12 +44,12 @@ def mean(x, n): ```Python import random -from benchmarking import benchmark, skip +from benchmarking import BenchmarkSetup, benchmark, skip def setup_mean(n): - return {"x": [random.random() for _ in range(n)]} + return BenchmarkSetup({"x": [random.random() for _ in range(n)]}) -@benchmark(setup_computation, n=[0, 1, 2, 3, 4]) +@benchmark(setup_mean, n=[0, 1, 2, 3, 4]) def mean(x, n): if n == 0: skip("number of items must be positive") @@ -48,42 +63,43 @@ def mean(x, n): to allow it to be executed as a script for runnning the benchmarks: ```Python - from benchmarking import benchmark, parse_args_collect_and_run_benchmarks ... 
if __name__ == "__main__": parse_args_collect_and_run_benchmarks() - ``` """ +from __future__ import annotations + import argparse +import contextlib import datetime import inspect import json import platform import timeit +import tracemalloc from ast import literal_eval +from collections.abc import Callable, Iterable from functools import partial from importlib.metadata import PackageNotFoundError, version from itertools import product from pathlib import Path +from types import ModuleType +from typing import Any, NamedTuple -try: - import memory_profiler - - MEMORY_PROFILER_AVAILABLE = True -except ImportError: - MEMORY_PROFILER_AVAILABLE = False +import jax +import numpy as np class SkipBenchmarkException(Exception): """Exception to be raised to skip benchmark for some parameter set.""" -def _get_version_or_none(package_name): +def _get_version_or_none(package_name: str) -> str | None: """Get installed version of package or `None` if package not found.""" try: return version(package_name) @@ -101,16 +117,15 @@ def _get_cpu_info(): return None -def _get_gpu_memory_mebibytes(device): - """Try to get GPU memory available in mebibytes (MiB).""" +def _get_gpu_memory_in_bytes(device: jax.Device) -> int | None: + """Try to get GPU memory available in bytes.""" memory_stats = device.memory_stats() if memory_stats is None: return None - bytes_limit = memory_stats.get("bytes_limit") - return bytes_limit // 2**20 if bytes_limit is not None else None + return memory_stats.get("bytes_limit") -def _get_gpu_info(): +def _get_gpu_info() -> dict[str, str | int]: """Get details of GPU devices available from JAX or None if JAX not available.""" try: import jax @@ -118,7 +133,7 @@ def _get_gpu_info(): return [ { "kind": d.device_kind, - "memory_available / MiB": _get_gpu_memory_mebibytes(d), + "memory_available_in_bytes": _get_gpu_memory_in_bytes(d), } for d in jax.devices() if d.platform == "gpu" @@ -127,7 +142,7 @@ def _get_gpu_info(): return None -def _get_cuda_info(): +def _get_cuda_info() -> dict[str, str]: """Try to get information on versions of CUDA libraries.""" try: from jax._src.lib import cuda_versions @@ -146,7 +161,7 @@ def _get_cuda_info(): return None -def skip(message): +def skip(message: str) -> None: """Skip benchmark for a particular parameter set with explanatory message. Args: @@ -155,13 +170,30 @@ def skip(message): raise SkipBenchmarkException(message) -def benchmark(setup_=None, **parameters): - """Decorator for defining a function to be benchmarker +class BenchmarkSetup(NamedTuple): + """Structure containing data for setting up a benchmark function.""" + + arguments: dict[str, Any] + reference_output: None | jax.Array | np.ndarray = None + jit_benchmark: bool = False + + +def benchmark( + setup: Callable[..., BenchmarkSetup] | None = None, **parameters +) -> Callable: + """Decorator for defining a function to be benchmark. Args: - setup_: Function performing any necessary set up for benchmark, and the resource + setup: Function performing any necessary set up for benchmark, and the resource usage of which will not be tracked in benchmarking. The function should - return a dictionary of values to pass to the benchmark as keyword arguments. 
+ return an instance of `BenchmarkSetup` named tuple, with first entry a + dictionary of values to pass to the benchmark as keyword arguments, + corresponding to any precomputed values, the second entry optionally a + reference value specifying the expected 'true' numerical output of the + benchmarked function to allow computing numerical error, or `None` if there + is no relevant reference value and third entry a boolean flag indicating + whether to use JAX's just-in-time compilation transform to benchmark + function. Kwargs: Parameter names and associated lists of values over which to run benchmark. @@ -169,29 +201,61 @@ def benchmark(setup_=None, **parameters): Returns: Decorator which marks function as benchmark and sets setup function and - parameters attributes. + parameters attributes. To also record error of benchmarked function in terms of + maximum absolute elementwise difference between output and reference value + returned by `setup` function, the decorated benchmark function should return + the numerical value to measure the error for. """ def decorator(function): function.is_benchmark = True - function.setup = setup_ if setup_ is not None else lambda: {} + function.setup = setup if setup is not None else lambda: {} function.parameters = parameters return function return decorator -def _parameters_string(parameters): +def _parameters_string(parameters: dict) -> str: """Format parameter values as string for printing benchmark results.""" return "(" + ", ".join(f"{name}: {val}" for name, val in parameters.items()) + ")" -def _dict_product(dicts): +def _format_results_entry(results_entry: dict) -> str: + """Format benchmark results entry as a string for printing.""" + return ( + ( + f"{_parameters_string(results_entry['parameters']):>40}: \n " + if len(results_entry["parameters"]) != 0 + else " " + ) + + f"min(run times): {min(results_entry['run_times_in_seconds']):>#7.2g}s, " + + f"max(run times): {max(results_entry['run_times_in_seconds']):>#7.2g}s" + + ( + f", peak memory: {results_entry['traced_memory_peak_in_bytes']:>#7.2g}B" + if "traced_memory_peak_in_bytes" in results_entry + else "" + ) + + ( + f", max(abs(error)): {results_entry['max_abs_error']:>#7.2g}" + if "max_abs_error" in results_entry + else "" + ) + + ( + f", floating point ops: {results_entry['cost_analysis']['flops']:>#7.2g}" + f", mem access: {results_entry['cost_analysis']['bytes_accessed']:>#7.2g}B" + if "cost_analysis" in results_entry + else "" + ) + ) + + +def _dict_product(dicts: dict[str, Iterable[Any]]) -> Iterable[dict[str, Any]]: """Generator corresponding to Cartesian product of dictionaries.""" return (dict(zip(dicts.keys(), values)) for values in product(*dicts.values())) -def _parse_value(value): +def _parse_value(value: str) -> Any: """Parse a value passed at command line as a Python literal or string as fallback""" try: return literal_eval(value) @@ -199,7 +263,7 @@ def _parse_value(value): return str(value) -def _parse_parameter_overrides(parameter_overrides): +def _parse_parameter_overrides(parameter_overrides: list[str]) -> dict[str, Any]: """Parse any parameter override values passed as command line arguments""" return ( { @@ -211,7 +275,7 @@ def _parse_parameter_overrides(parameter_overrides): ) -def _parse_cli_arguments(description): +def _parse_cli_arguments(description: str) -> argparse.Namespace: """Parse command line arguments passed for controlling benchmark runs""" parser = argparse.ArgumentParser( description=description, formatter_class=argparse.ArgumentDefaultsHelpFormatter @@ 
-243,18 +307,14 @@ def _parse_cli_arguments(description): "-output-file", type=Path, help="File path to write JSON formatted results to." ) parser.add_argument( - "--run-once-and-discard", - action="store_true", - help=( - "Run benchmark function once first without recording time to " - "ignore the effect of any initial one-off costs such as just-in-time " - "compilation." - ), + "-benchmarks", + nargs="+", + help="Names of benchmark functions to run. All benchmarks are run if omitted.", ) return parser.parse_args() -def _is_benchmark(object): +def _is_benchmark(object: Any) -> bool: """Predicate for testing whether an object is a benchmark function or not.""" return ( inspect.isfunction(object) @@ -263,26 +323,87 @@ def _is_benchmark(object): ) -def collect_benchmarks(module): +def collect_benchmarks( + module: ModuleType, benchmark_names: list[str] +) -> list[Callable]: """Collect all benchmark functions from a module. Args: module: Python module containing benchmark functions. + benchmark_names: List of benchmark names to collect or `None` if all benchmarks + in module to be collected. Returns: List of functions in module with `is_benchmark` attribute set to `True`. """ - return [function for name, function in inspect.getmembers(module, _is_benchmark)] + return [ + function + for name, function in inspect.getmembers(module, _is_benchmark) + if benchmark_names is None or name in benchmark_names + ] + + +@contextlib.contextmanager +def trace_memory_allocations(n_frames: int = 1) -> Callable[[], tuple[int, int]]: + """Context manager for tracing memory allocations in managed with block. + + Returns a thunk (zero-argument function) which can be called on exit from with block + to get tuple of current size and peak size of memory blocks traced in bytes. + + Args: + n_frames: Limit on depth of frames to trace memory allocations in. + + Returns: + A thunk (zero-argument function) which can be called on exit from with block to + get tuple of current size and peak size of memory blocks traced in bytes. 
+ """ + tracemalloc.start(n_frames) + current_size, peak_size = None, None + try: + yield lambda: (current_size, peak_size) + current_size, peak_size = tracemalloc.get_traced_memory() + finally: + tracemalloc.stop() + + +def _compile_jax_benchmark_and_analyse( + benchmark_function: Callable, results_entry: dict +) -> Callable: + """Compile a JAX benchmark function and extract cost estimates if available.""" + compiled_benchmark_function = jax.jit(benchmark_function).lower().compile() + cost_analysis = compiled_benchmark_function.cost_analysis() + if cost_analysis is not None: + if isinstance(cost_analysis, list): + cost_analysis = cost_analysis[0] + results_entry["cost_analysis"] = { + "flops": cost_analysis.get("flops"), + "bytes_accessed": cost_analysis.get("bytes accessed"), + } + memory_analysis = compiled_benchmark_function.memory_analysis() + if memory_analysis is not None: + results_entry["memory_analysis"] = { + prefix + base_key: getattr(memory_analysis, prefix + base_key, None) + for prefix in ("", "host_") + for base_key in ( + "alias_size_in_bytes", + "argument_size_in_bytes", + "generated_code_size_in_bytes", + "output_size_in_bytes", + "temp_size_in_bytes", + ) + } + # Ensure block_until_ready called on benchmark output due to JAX's asynchronous + # dispatch model: https://jax.readthedocs.io/en/latest/async_dispatch.html + return lambda: compiled_benchmark_function().block_until_ready() def run_benchmarks( - benchmarks, - number_runs, - number_repeats, - print_results=True, - parameter_overrides=None, - run_once_and_discard=False, -): + benchmarks: list[Callable], + number_runs: int, + number_repeats: int, + print_results: bool = True, + parameter_overrides: dict[str, Any] | None = None, +) -> dict[str, Any]: """Run a set of benchmarks. Args: @@ -295,9 +416,6 @@ def run_benchmarks( print_results: Whether to print benchmark results to stdout. parameter_overrides: Dictionary specifying any overrides for parameter values set in `benchmark` decorator. - run_once_and_discard: Whether to run benchmark function once first without - recording time to ignore the effect of any initial one-off costs such as - just-in-time compilation. 
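As a standalone usage sketch of the `trace_memory_allocations` context manager defined above (the list allocation is purely illustrative, not taken from the benchmarks), the returned thunk gives the current and peak traced sizes in bytes once the `with` block has exited:

```Python
from benchmarking import trace_memory_allocations

with trace_memory_allocations() as traced_memory:
    data = [float(i) for i in range(1_000_000)]  # hypothetical workload to trace
final_size_in_bytes, peak_size_in_bytes = traced_memory()  # populated on exit from the block
```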
Returns: Dictionary containing timing (and potentially memory usage) results for each @@ -310,49 +428,42 @@ def run_benchmarks( print(benchmark.__name__) parameters = benchmark.parameters.copy() if parameter_overrides is not None: - parameters.update(parameter_overrides) + for parameter_name, parameter_values in parameter_overrides.items(): + if parameter_name in parameters: + parameters[parameter_name] = parameter_values for parameter_set in _dict_product(parameters): try: - precomputes = benchmark.setup(**parameter_set) - benchmark_function = partial(benchmark, **precomputes, **parameter_set) - if run_once_and_discard: - benchmark_function() + args, reference_output, jit_benchmark = benchmark.setup(**parameter_set) + benchmark_function = partial(benchmark, **args, **parameter_set) + results_entry = {"parameters": parameter_set} + if jit_benchmark: + benchmark_function = _compile_jax_benchmark_and_analyse( + benchmark_function, results_entry + ) + # Run benchmark once without timing to record output for potentially + # computing numerical error and trace memory usage + with trace_memory_allocations() as traced_memory: + output = benchmark_function() + current_size, peak_size = traced_memory() + results_entry["traced_memory_final_in_bytes"] = current_size + results_entry["traced_memory_peak_in_bytes"] = peak_size + if reference_output is not None and output is not None: + results_entry["max_abs_error"] = float( + abs(reference_output - output).max() + ) + results_entry["mean_abs_error"] = float( + abs(reference_output - output).mean() + ) run_times = [ time / number_runs for time in timeit.repeat( benchmark_function, number=number_runs, repeat=number_repeats ) ] - results_entry = {**parameter_set, "times / s": run_times} - if MEMORY_PROFILER_AVAILABLE: - baseline_memory = memory_profiler.memory_usage(max_usage=True) - peak_memory = ( - memory_profiler.memory_usage( - benchmark_function, - interval=max(run_times) * number_repeats, - max_usage=True, - max_iterations=number_repeats, - include_children=True, - ) - - baseline_memory - ) - results_entry["peak_memory / MiB"] = peak_memory + results_entry["run_times_in_seconds"] = run_times results[benchmark.__name__].append(results_entry) if print_results: - print( - ( - f"{_parameters_string(parameter_set):>40}: \n " - if len(parameter_set) != 0 - else " " - ) - + f"min(time): {min(run_times):>#7.2g}s, " - + f"max(time): {max(run_times):>#7.2g}s" - + ( - f", peak mem.: {peak_memory:>#7.2g}MiB" - if MEMORY_PROFILER_AVAILABLE - else "" - ) - ) + print(_format_results_entry(results_entry)) except SkipBenchmarkException as e: if print_results: print( @@ -361,16 +472,58 @@ def run_benchmarks( return results -def parse_args_collect_and_run_benchmarks(module=None): +def get_system_info() -> dict[str, Any]: + """Get dictionary of metadata about system. + + Returns: + Dictionary with information about system, CPU and GPU devices (if present) and + Python environment and package versions. 
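Roughly, each entry appended to the results dictionary by `run_benchmarks` now has the shape sketched below; the field names follow the code above, while the values shown are made up for illustration (`max_abs_error`/`mean_abs_error` appear only when the setup function supplies a reference output, and `cost_analysis`/`memory_analysis` only for JIT-compiled benchmarks):

```Python
results = {
    "forward": [
        {
            "parameters": {"L": 64, "method": "jax"},
            "traced_memory_final_in_bytes": 1024,
            "traced_memory_peak_in_bytes": 4096,
            "max_abs_error": 1.2e-13,
            "mean_abs_error": 3.4e-14,
            "cost_analysis": {"flops": 1.0e9, "bytes_accessed": 2.0e8},
            "run_times_in_seconds": [0.012, 0.011, 0.011],
        },
    ],
}
```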
+ """ + package_versions = { + f"{package}_version": _get_version_or_none(package) + for package in ("s2fft", "jax", "numpy") + } + return { + "architecture": platform.architecture(), + "machine": platform.machine(), + "node": platform.node(), + "processor": platform.processor(), + "python_version": platform.python_version(), + "release": platform.release(), + "system": platform.system(), + "cpu_info": _get_cpu_info(), + "gpu_info": _get_gpu_info(), + "cuda_info": _get_cuda_info(), + **package_versions, + } + + +def write_json_results_file( + output_file: Path, results: dict[str, Any], benchmark_module: str +) -> None: + """Write benchmark results and system information to a file in JSON format. + + Args: + output_file: Path to file to write results to. + results: Dictionary of benchmark results from `run_benchmarks`. + benchmarks_module: Name of module containing benchmarks. + """ + with open(output_file, "w") as f: + output = { + "date_time": datetime.datetime.now().isoformat(), + "benchmark_module": benchmark_module, + "system_info": get_system_info(), + "results": results, + } + json.dump(output, f, indent=True) + + +def parse_args_collect_and_run_benchmarks(module: ModuleType | None = None) -> None: """Collect and run all benchmarks in a module and parse command line arguments. Args: module: Module containing benchmarks to run. Defaults to module from which this function was called if not specified (set to `None`). - - Returns: - Dictionary containing timing (and potentially memory usage) results for each - parameters set of each benchmark function. """ if module is None: frame = inspect.stack()[1] @@ -378,36 +531,11 @@ def parse_args_collect_and_run_benchmarks(module=None): args = _parse_cli_arguments(module.__doc__) parameter_overrides = _parse_parameter_overrides(args.parameter_overrides) results = run_benchmarks( - benchmarks=collect_benchmarks(module), + benchmarks=collect_benchmarks(module, args.benchmarks), number_runs=args.number_runs, number_repeats=args.repeats, + print_results=True, parameter_overrides=parameter_overrides, - run_once_and_discard=args.run_once_and_discard, ) if args.output_file is not None: - package_versions = { - f"{package}_version": _get_version_or_none(package) - for package in ("s2fft", "jax", "numpy") - } - system_info = { - "architecture": platform.architecture(), - "machine": platform.machine(), - "node": platform.node(), - "processor": platform.processor(), - "python_version": platform.python_version(), - "release": platform.release(), - "system": platform.system(), - "cpu_info": _get_cpu_info(), - "gpu_info": _get_gpu_info(), - "cuda_info": _get_cuda_info(), - **package_versions, - } - with open(args.output_file, "w") as f: - output = { - "date_time": datetime.datetime.now().isoformat(), - "benchmark_module": module.__name__, - "system_info": system_info, - "results": results, - } - json.dump(output, f, indent=True) - return results + write_json_results_file(args.output_file, results, module.__name__) diff --git a/benchmarks/plotting.py b/benchmarks/plotting.py new file mode 100644 index 00000000..75c9aa33 --- /dev/null +++ b/benchmarks/plotting.py @@ -0,0 +1,209 @@ +"""Utilities for plotting benchmark results.""" + +import argparse +import json +from pathlib import Path + +import matplotlib.pyplot as plt +import numpy as np + + +def _set_axis_properties( + ax: plt.Axes, + parameter_values: np.ndarray, + parameter_label: str, + measurement_label: str, +) -> None: + ax.set( + xlabel=parameter_label, + ylabel=measurement_label, + xscale="log", + 
yscale="log", + xticks=parameter_values, + xticklabels=parameter_values, + ) + ax.minorticks_off() + + +def _plot_scaling_guide( + ax: plt.Axes, + parameter_symbol: str, + parameter_values: np.ndarray, + measurement_values: np.ndarray, + order: int, +) -> None: + n = np.argsort(parameter_values)[len(parameter_values) // 2] + coefficient = measurement_values[n] / float(parameter_values[n]) ** order + ax.plot( + parameter_values, + coefficient * parameter_values.astype(float) ** order, + "k:", + label=f"$\\mathcal{{O}}({parameter_symbol}^{order})$", + ) + + +def plot_times( + ax: plt.Axes, parameter_symbol: str, parameter_values: np.ndarray, results: dict +) -> None: + min_times = np.array([min(r["run_times_in_seconds"]) for r in results]) + mid_times = np.array([np.median(r["run_times_in_seconds"]) for r in results]) + max_times = np.array([max(r["run_times_in_seconds"]) for r in results]) + ax.plot(parameter_values, mid_times, label="Measured") + ax.fill_between(parameter_values, min_times, max_times, alpha=0.5) + _plot_scaling_guide(ax, parameter_symbol, parameter_values, mid_times, 3) + ax.legend() + + +def plot_flops( + ax: plt.Axes, parameter_symbol: str, parameter_values: np.ndarray, results: dict +) -> None: + flops = np.array([r["cost_analysis"]["flops"] for r in results]) + ax.plot(parameter_values, flops, label="Measured") + _plot_scaling_guide(ax, parameter_symbol, parameter_values, flops, 2) + ax.legend() + + +def plot_error( + ax: plt.Axes, parameter_symbol: str, parameter_values: np.ndarray, results: dict +) -> None: + max_abs_errors = np.array([r["max_abs_error"] for r in results]) + mean_abs_errors = np.array([r["mean_abs_error"] for r in results]) + ax.plot(parameter_values, max_abs_errors, label="max(abs(error))") + ax.plot(parameter_values, mean_abs_errors, label="mean(abs(error))") + _plot_scaling_guide( + ax, + parameter_symbol, + parameter_values, + (max_abs_errors + mean_abs_errors) / 2, + 2, + ) + ax.legend() + + +def plot_memory( + ax: plt.Axes, parameter_symbol: str, parameter_values: np.ndarray, results: dict +) -> None: + bytes_accessed = np.array([r["cost_analysis"]["bytes_accessed"] for r in results]) + temp_size_in_bytes = np.array( + [r["memory_analysis"]["temp_size_in_bytes"] for r in results] + ) + output_size_in_bytes = np.array( + [r["memory_analysis"]["output_size_in_bytes"] for r in results] + ) + generated_code_size_in_bytes = np.array( + [r["memory_analysis"]["generated_code_size_in_bytes"] for r in results] + ) + ax.plot(parameter_values, bytes_accessed, label="Accesses") + ax.plot(parameter_values, temp_size_in_bytes, label="Temporary allocations") + ax.plot(parameter_values, output_size_in_bytes, label="Output size") + ax.plot(parameter_values, generated_code_size_in_bytes, label="Generated code size") + _plot_scaling_guide( + ax, + parameter_symbol, + parameter_values, + (bytes_accessed + output_size_in_bytes) / 2, + 2, + ) + ax.legend() + + +_measurement_plot_functions_and_labels = { + "times": (plot_times, "Run time / s"), + "flops": (plot_flops, "Floating point operations"), + "memory": (plot_memory, "Memory / B"), + "error": (plot_error, "Numerical error"), +} + + +def plot_results_against_bandlimit( + benchmark_results_path: str | Path, + functions: tuple[str] = ("forward", "inverse"), + measurements: tuple[str] = ("times", "flops", "memory", "error"), + axis_size: float = 3.0, + fig_dpi: int = 100, +) -> tuple[plt.Figure, plt.Axes]: + benchmark_results_path = Path(benchmark_results_path) + with benchmark_results_path.open("r") as f: + 
benchmark_results = json.load(f) + n_functions = len(functions) + n_measurements = len(measurements) + fig, axes = plt.subplots( + n_functions, + n_measurements, + figsize=(axis_size * n_measurements, axis_size * n_functions), + dpi=fig_dpi, + squeeze=False, + ) + for axes_row, function in zip(axes, functions): + results = benchmark_results["results"][function] + l_values = np.array([r["parameters"]["L"] for r in results]) + for ax, measurement in zip(axes_row, measurements): + plot_function, label = _measurement_plot_functions_and_labels[measurement] + try: + plot_function(ax, "L", l_values, results) + ax.set(title=function) + except KeyError: + ax.axis("off") + _set_axis_properties(ax, l_values, "Bandlimit $L$", label) + return fig, ax + + +def _parse_cli_arguments() -> argparse.Namespace: + """Parse rguments passed for plotting command line interface""" + parser = argparse.ArgumentParser( + description="Generate plot from benchmark results file.", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument( + "-results-path", + type=Path, + help="Path to JSON file containing benchmark results to plot.", + ) + parser.add_argument( + "-output-path", + type=Path, + help="Path to write figure to.", + ) + parser.add_argument( + "-functions", + nargs="+", + help="Names of functions to plot. forward and inverse are plotted if omitted.", + ) + parser.add_argument( + "-measurements", + nargs="+", + help="Names of measurements to plot. All functions are plotted if omitted.", + ) + parser.add_argument( + "-axis-size", type=float, default=5.0, help="Size of each plot axis in inches." + ) + parser.add_argument( + "-dpi", type=int, default=100, help="Figure resolution in dots per inch." + ) + parser.add_argument( + "-title", type=str, help="Title for figure. No title added if omitted." 
+ ) + return parser.parse_args() + + +if __name__ == "__main__": + args = _parse_cli_arguments() + functions = ( + ("forward", "inverse") if args.functions is None else tuple(args.functions) + ) + measurements = ( + ("times", "flops", "memory", "error") + if args.measurements is None + else tuple(args.measurements) + ) + fig, _ = plot_results_against_bandlimit( + args.results_path, + functions=functions, + measurements=measurements, + axis_size=args.axis_size, + fig_dpi=args.dpi, + ) + if args.title is not None: + fig.suptitle(args.title) + fig.tight_layout() + fig.savefig(args.output_path) diff --git a/benchmarks/precompute_spherical.py b/benchmarks/precompute_spherical.py index 1af2a2aa..bec1dd7c 100644 --- a/benchmarks/precompute_spherical.py +++ b/benchmarks/precompute_spherical.py @@ -1,12 +1,15 @@ """Benchmarks for precompute spherical transforms.""" import numpy as np -import pyssht -from benchmarking import benchmark, parse_args_collect_and_run_benchmarks, skip +from benchmarking import ( + BenchmarkSetup, + benchmark, + parse_args_collect_and_run_benchmarks, + skip, +) import s2fft import s2fft.precompute_transforms -from s2fft.sampling import s2_samples as samples L_VALUES = [8, 16, 32, 64, 128, 256] SPIN_VALUES = [0] @@ -21,12 +24,12 @@ def setup_forward(method, L, sampling, spin, reality, recursion): skip("Reality only valid for scalar fields (spin=0).") rng = np.random.default_rng() flm = s2fft.utils.signal_generator.generate_flm(rng, L, spin=spin, reality=reality) - f = pyssht.inverse( - samples.flm_2d_to_1d(flm, L), - L, - Method=sampling.upper(), - Spin=spin, - Reality=reality, + f = s2fft.transforms.spherical.inverse( + flm, + L=L, + spin=spin, + sampling=sampling, + reality=reality, ) kernel_function = ( s2fft.precompute_transforms.construct.spin_spherical_kernel_jax @@ -41,7 +44,7 @@ def setup_forward(method, L, sampling, spin, reality, recursion): forward=True, recursion=recursion, ) - return {"f": f, "kernel": kernel} + return BenchmarkSetup({"f": f, "kernel": kernel}, flm, "jax" in method) @benchmark( @@ -54,7 +57,7 @@ def setup_forward(method, L, sampling, spin, reality, recursion): recursion=RECURSION_VALUES, ) def forward(f, kernel, method, L, sampling, spin, reality, recursion): - flm = s2fft.precompute_transforms.spherical.forward( + return s2fft.precompute_transforms.spherical.forward( f=f, L=L, spin=spin, @@ -63,8 +66,6 @@ def forward(f, kernel, method, L, sampling, spin, reality, recursion): reality=reality, method=method, ) - if method == "jax": - flm.block_until_ready() def setup_inverse(method, L, sampling, spin, reality, recursion): @@ -85,7 +86,7 @@ def setup_inverse(method, L, sampling, spin, reality, recursion): forward=False, recursion=recursion, ) - return {"flm": flm, "kernel": kernel} + return BenchmarkSetup({"flm": flm, "kernel": kernel}, None, "jax" in method) @benchmark( @@ -98,7 +99,7 @@ def setup_inverse(method, L, sampling, spin, reality, recursion): recursion=RECURSION_VALUES, ) def inverse(flm, kernel, method, L, sampling, spin, reality, recursion): - f = s2fft.precompute_transforms.spherical.inverse( + return s2fft.precompute_transforms.spherical.inverse( flm=flm, L=L, spin=spin, @@ -107,8 +108,6 @@ def inverse(flm, kernel, method, L, sampling, spin, reality, recursion): reality=reality, method=method, ) - if method == "jax": - f.block_until_ready() if __name__ == "__main__": diff --git a/benchmarks/precompute_wigner.py b/benchmarks/precompute_wigner.py index 01918ca7..73971d43 100644 --- a/benchmarks/precompute_wigner.py +++ 
b/benchmarks/precompute_wigner.py @@ -1,7 +1,11 @@ """Benchmarks for precompute Wigner-d transforms.""" import numpy as np -from benchmarking import benchmark, parse_args_collect_and_run_benchmarks +from benchmarking import ( + BenchmarkSetup, + benchmark, + parse_args_collect_and_run_benchmarks, +) import s2fft import s2fft.precompute_transforms @@ -29,13 +33,13 @@ def setup_forward(method, L, N, L_lower, sampling, reality, mode): ) kernel_function = ( s2fft.precompute_transforms.construct.wigner_kernel_jax - if method == "jax" + if "jax" in method else s2fft.precompute_transforms.construct.wigner_kernel ) kernel = kernel_function( L=L, N=N, reality=reality, sampling=sampling, forward=True, mode=mode ) - return {"f": f, "kernel": kernel} + return BenchmarkSetup({"f": f, "kernel": kernel}, flmn, "jax" in method) @benchmark( @@ -49,7 +53,7 @@ def setup_forward(method, L, N, L_lower, sampling, reality, mode): mode=MODE_VALUES, ) def forward(f, kernel, method, L, N, L_lower, sampling, reality, mode): - flmn = s2fft.precompute_transforms.wigner.forward( + return s2fft.precompute_transforms.wigner.forward( f=f, L=L, N=N, @@ -58,8 +62,6 @@ def forward(f, kernel, method, L, N, L_lower, sampling, reality, mode): reality=reality, method=method, ) - if method == "jax": - flmn.block_until_ready() def setup_inverse(method, L, N, L_lower, sampling, reality, mode): @@ -73,7 +75,7 @@ def setup_inverse(method, L, N, L_lower, sampling, reality, mode): kernel = kernel_function( L=L, N=N, reality=reality, sampling=sampling, forward=False, mode=mode ) - return {"flmn": flmn, "kernel": kernel} + return BenchmarkSetup({"flmn": flmn, "kernel": kernel}, None, "jax" in method) @benchmark( @@ -87,7 +89,7 @@ def setup_inverse(method, L, N, L_lower, sampling, reality, mode): mode=MODE_VALUES, ) def inverse(flmn, kernel, method, L, N, L_lower, sampling, reality, mode): - f = s2fft.precompute_transforms.wigner.inverse( + return s2fft.precompute_transforms.wigner.inverse( flmn=flmn, L=L, N=N, @@ -96,8 +98,6 @@ def inverse(flmn, kernel, method, L, N, L_lower, sampling, reality, mode): reality=reality, method=method, ) - if method == "jax": - f.block_until_ready() if __name__ == "__main__": diff --git a/benchmarks/spherical.py b/benchmarks/spherical.py index 78231a43..d0ea078d 100644 --- a/benchmarks/spherical.py +++ b/benchmarks/spherical.py @@ -1,46 +1,62 @@ """Benchmarks for on-the-fly spherical transforms.""" import numpy as np -import pyssht -from benchmarking import benchmark, parse_args_collect_and_run_benchmarks, skip +from benchmarking import ( + BenchmarkSetup, + benchmark, + parse_args_collect_and_run_benchmarks, + skip, +) import s2fft from s2fft.recursions.price_mcewen import generate_precomputes_jax -from s2fft.sampling import s2_samples as samples L_VALUES = [8, 16, 32, 64, 128, 256] L_LOWER_VALUES = [0] SPIN_VALUES = [0] +L_TO_NSIDE_RATIO_VALUES = [2] SAMPLING_VALUES = ["mw"] METHOD_VALUES = ["numpy", "jax"] REALITY_VALUES = [True] SPMD_VALUES = [False] +N_ITER_VALUES = [None] def _jax_arrays_to_numpy(precomps): return [np.asarray(p) for p in precomps] -def setup_forward(method, L, L_lower, sampling, spin, reality, spmd): +def _get_nside(sampling, L, L_to_nside_ratio): + return None if sampling != "healpix" else L // L_to_nside_ratio + + +def setup_forward( + method, L, L_lower, sampling, spin, L_to_nside_ratio, reality, spmd, n_iter +): if reality and spin != 0: skip("Reality only valid for scalar fields (spin=0).") if spmd and method != "jax": skip("GPU distribution only valid for JAX.") rng = 
np.random.default_rng() flm = s2fft.utils.signal_generator.generate_flm(rng, L, spin=spin, reality=reality) - f = pyssht.inverse( - samples.flm_2d_to_1d(flm, L), - L, - Method=sampling.upper(), - Spin=spin, - Reality=reality, + nside = _get_nside(sampling, L, L_to_nside_ratio) + f = s2fft.transforms.spherical.inverse( + flm, + L=L, + spin=spin, + nside=nside, + sampling=sampling, + method=method, + reality=reality, + spmd=spmd, + L_lower=L_lower, ) precomps = generate_precomputes_jax( - L, spin, sampling, forward=True, L_lower=L_lower + L, spin, sampling, nside=nside, forward=True, L_lower=L_lower ) if method == "numpy": precomps = _jax_arrays_to_numpy(precomps) - return {"f": f, "precomps": precomps} + return BenchmarkSetup({"f": f, "precomps": precomps}, flm, "jax" in method) @benchmark( @@ -50,29 +66,40 @@ def setup_forward(method, L, L_lower, sampling, spin, reality, spmd): L_lower=L_LOWER_VALUES, sampling=SAMPLING_VALUES, spin=SPIN_VALUES, + L_to_nside_ratio=L_TO_NSIDE_RATIO_VALUES, reality=REALITY_VALUES, spmd=SPMD_VALUES, + n_iter=N_ITER_VALUES, ) -def forward(f, precomps, method, L, L_lower, sampling, spin, reality, spmd): - if method == "pyssht": - flm = pyssht.forward(f, L, spin, sampling.upper()) - else: - flm = s2fft.transforms.spherical.forward( - f=f, - L=L, - L_lower=L_lower, - precomps=precomps, - spin=spin, - sampling=sampling, - reality=reality, - method=method, - spmd=spmd, - ) - if method == "jax": - flm.block_until_ready() - - -def setup_inverse(method, L, L_lower, sampling, spin, reality, spmd): +def forward( + f, + precomps, + method, + L, + L_lower, + sampling, + spin, + L_to_nside_ratio, + reality, + spmd, + n_iter, +): + return s2fft.transforms.spherical.forward( + f=f, + L=L, + L_lower=L_lower, + precomps=precomps, + spin=spin, + nside=_get_nside(sampling, L, L_to_nside_ratio), + sampling=sampling, + reality=reality, + method=method, + spmd=spmd, + iter=n_iter, + ) + + +def setup_inverse(method, L, L_lower, sampling, spin, L_to_nside_ratio, reality, spmd): if reality and spin != 0: skip("Reality only valid for scalar fields (spin=0).") if spmd and method != "jax": @@ -80,11 +107,16 @@ def setup_inverse(method, L, L_lower, sampling, spin, reality, spmd): rng = np.random.default_rng() flm = s2fft.utils.signal_generator.generate_flm(rng, L, spin=spin, reality=reality) precomps = generate_precomputes_jax( - L, spin, sampling, forward=False, L_lower=L_lower + L, + spin, + sampling, + nside=_get_nside(sampling, L, L_to_nside_ratio), + forward=False, + L_lower=L_lower, ) if method == "numpy": precomps = _jax_arrays_to_numpy(precomps) - return {"flm": flm, "precomps": precomps} + return BenchmarkSetup({"flm": flm, "precomps": precomps}, None, "jax" in method) @benchmark( @@ -94,26 +126,25 @@ def setup_inverse(method, L, L_lower, sampling, spin, reality, spmd): L_lower=L_LOWER_VALUES, sampling=SAMPLING_VALUES, spin=SPIN_VALUES, + L_to_nside_ratio=L_TO_NSIDE_RATIO_VALUES, reality=REALITY_VALUES, spmd=SPMD_VALUES, ) -def inverse(flm, precomps, method, L, L_lower, sampling, spin, reality, spmd): - if method == "pyssht": - f = pyssht.inverse(samples.flm_2d_to_1d(flm, L), L, spin, sampling.upper()) - else: - f = s2fft.transforms.spherical.inverse( - flm=flm, - L=L, - L_lower=L_lower, - precomps=precomps, - spin=spin, - sampling=sampling, - reality=reality, - method=method, - spmd=spmd, - ) - if method == "jax": - f.block_until_ready() +def inverse( + flm, precomps, method, L, L_lower, sampling, spin, L_to_nside_ratio, reality, spmd +): + return 
s2fft.transforms.spherical.inverse( + flm=flm, + L=L, + L_lower=L_lower, + precomps=precomps, + spin=spin, + nside=_get_nside(sampling, L, L_to_nside_ratio), + sampling=sampling, + reality=reality, + method=method, + spmd=spmd, + ) if __name__ == "__main__": diff --git a/benchmarks/wigner.py b/benchmarks/wigner.py index d6180961..89655c79 100644 --- a/benchmarks/wigner.py +++ b/benchmarks/wigner.py @@ -1,7 +1,11 @@ """Benchmarks for on-the-fly Wigner-d transforms.""" import numpy as np -from benchmarking import benchmark, parse_args_collect_and_run_benchmarks +from benchmarking import ( + BenchmarkSetup, + benchmark, + parse_args_collect_and_run_benchmarks, +) import s2fft from s2fft.base_transforms import wigner as base_wigner @@ -31,13 +35,13 @@ def setup_forward(method, L, L_lower, N, sampling, reality): ) generate_precomputes = ( generate_precomputes_wigner_jax - if method == "jax" + if "jax" in method else generate_precomputes_wigner ) precomps = generate_precomputes( L, N, sampling, forward=True, reality=reality, L_lower=L_lower ) - return {"f": f, "precomps": precomps} + return BenchmarkSetup({"f": f, "precomps": precomps}, flmn, "jax" in method) @benchmark( @@ -50,7 +54,7 @@ def setup_forward(method, L, L_lower, N, sampling, reality): reality=REALITY_VALUES, ) def forward(f, precomps, method, L, L_lower, N, sampling, reality): - flmn = s2fft.transforms.wigner.forward( + return s2fft.transforms.wigner.forward( f=f, L=L, N=N, @@ -60,8 +64,6 @@ def forward(f, precomps, method, L, L_lower, N, sampling, reality): precomps=precomps, L_lower=L_lower, ) - if method == "jax": - flmn.block_until_ready() def setup_inverse(method, L, L_lower, N, sampling, reality): @@ -69,13 +71,13 @@ def setup_inverse(method, L, L_lower, N, sampling, reality): flmn = s2fft.utils.signal_generator.generate_flmn(rng, L, N, reality=reality) generate_precomputes = ( generate_precomputes_wigner_jax - if method == "jax" + if "jax" in method else generate_precomputes_wigner ) precomps = generate_precomputes( L, N, sampling, forward=False, reality=reality, L_lower=L_lower ) - return {"flmn": flmn, "precomps": precomps} + return BenchmarkSetup({"flmn": flmn, "precomps": precomps}, None, "jax" in method) @benchmark( @@ -88,7 +90,7 @@ def setup_inverse(method, L, L_lower, N, sampling, reality): reality=REALITY_VALUES, ) def inverse(flmn, precomps, method, L, L_lower, N, sampling, reality): - f = s2fft.transforms.wigner.inverse( + return s2fft.transforms.wigner.inverse( flmn=flmn, L=L, N=N, @@ -98,8 +100,6 @@ def inverse(flmn, precomps, method, L, L_lower, N, sampling, reality): precomps=precomps, L_lower=L_lower, ) - if method == "jax": - f.block_until_ready() if __name__ == "__main__":
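Putting the pieces together, a minimal benchmark module written against the updated helpers might look like the sketch below; the one-dimensional FFT workload, the module name and the reference output are purely illustrative and not part of `s2fft`. The setup function returns a `BenchmarkSetup` whose reference output enables error reporting and whose final flag requests JIT compilation for the JAX implementation:

```Python
"""Example benchmarks for one-dimensional FFTs (illustrative only)."""

import jax.numpy as jnp
import numpy as np

from benchmarking import (
    BenchmarkSetup,
    benchmark,
    parse_args_collect_and_run_benchmarks,
)


def setup_fft(n, method):
    rng = np.random.default_rng(1234)
    x = rng.standard_normal(n)
    # Use the NumPy result as the reference output so absolute errors are recorded,
    # and JIT-compile the benchmark function when the JAX implementation is used.
    return BenchmarkSetup({"x": x}, np.fft.fft(x), method == "jax")


@benchmark(setup_fft, n=[256, 1024, 4096], method=["numpy", "jax"])
def fft(x, n, method):
    return jnp.fft.fft(x) if method == "jax" else np.fft.fft(x)


if __name__ == "__main__":
    parse_args_collect_and_run_benchmarks()
```

Such a module then supports the same command line interface as the benchmark scripts in this directory, for example `python example_fft.py -benchmarks fft -p n 1024 4096 -output-file fft_results.json` (hypothetical file names) to run only the `fft` benchmark with overridden `n` values and write JSON-formatted results.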