From f9ce1b465f7cb0a7b9122e224c9211ee008e809f Mon Sep 17 00:00:00 2001 From: Matt Graham Date: Thu, 23 Jan 2025 08:58:54 +0000 Subject: [PATCH 01/22] Allow computing numerical error in benchmarks --- benchmarks/benchmarking.py | 32 +++++++++++++++++++++--------- benchmarks/precompute_spherical.py | 6 ++++-- benchmarks/precompute_wigner.py | 6 ++++-- benchmarks/spherical.py | 6 ++++-- benchmarks/wigner.py | 6 ++++-- 5 files changed, 39 insertions(+), 17 deletions(-) diff --git a/benchmarks/benchmarking.py b/benchmarks/benchmarking.py index faea7c1b..9fead05b 100644 --- a/benchmarks/benchmarking.py +++ b/benchmarks/benchmarking.py @@ -9,7 +9,7 @@ function is passed the union of any precomputed values outputted by the setup function and the parameters values as keyword arguments. -As a simple example, the following defines a benchmarkfor computing the mean of a list +As a simple example, the following defines a benchmark for computing the mean of a list of numbers. ```Python @@ -17,7 +17,7 @@ from benchmarking import benchmark def setup_mean(n): - return {"x": [random.random() for _ in range(n)]} + return {"x": [random.random() for _ in range(n)]}, None @benchmark(setup_computation, n=[1, 2, 3, 4]) def mean(x, n): @@ -32,7 +32,7 @@ def mean(x, n): from benchmarking import benchmark, skip def setup_mean(n): - return {"x": [random.random() for _ in range(n)]} + return {"x": [random.random() for _ in range(n)]}, None @benchmark(setup_computation, n=[0, 1, 2, 3, 4]) def mean(x, n): @@ -156,12 +156,16 @@ def skip(message): def benchmark(setup_=None, **parameters): - """Decorator for defining a function to be benchmarker + """Decorator for defining a function to be benchmark. Args: setup_: Function performing any necessary set up for benchmark, and the resource usage of which will not be tracked in benchmarking. The function should - return a dictionary of values to pass to the benchmark as keyword arguments. + return a 2-tuple, with first entry a dictionary of values to pass to the + benchmark as keyword arguments, corresponding to any precomputed values, + and the second entry optionally a reference value specifying the expected + 'true' numerical output of the behchmarked function to allow computing + numerical error, or `None` if there is no relevant reference value. Kwargs: Parameter names and associated lists of values over which to run benchmark. @@ -169,7 +173,10 @@ def benchmark(setup_=None, **parameters): Returns: Decorator which marks function as benchmark and sets setup function and - parameters attributes. + parameters attributes. To also record error of benchmarked function in terms of + maximum absolute elementwise difference between output and reference value + returned by `setup_` function, the decorated benchmark function should return + the numerical value to measure the error for. """ def decorator(function): @@ -313,10 +320,10 @@ def run_benchmarks( parameters.update(parameter_overrides) for parameter_set in _dict_product(parameters): try: - precomputes = benchmark.setup(**parameter_set) + precomputes, reference_output = benchmark.setup(**parameter_set) benchmark_function = partial(benchmark, **precomputes, **parameter_set) - if run_once_and_discard: - benchmark_function() + if run_once_and_discard or reference_output is not None: + output = benchmark_function() run_times = [ time / number_runs for time in timeit.repeat( @@ -324,6 +331,8 @@ def run_benchmarks( ) ] results_entry = {**parameter_set, "times / s": run_times} + if reference_output is not None and output is not None: + results_entry["error"] = abs(reference_output - output).max() if MEMORY_PROFILER_AVAILABLE: baseline_memory = memory_profiler.memory_usage(max_usage=True) peak_memory = ( @@ -352,6 +361,11 @@ def run_benchmarks( if MEMORY_PROFILER_AVAILABLE else "" ) + + ( + f", round-trip error: {results_entry['error']:#7.2g}" + if "error" in results_entry + else "" + ) ) except SkipBenchmarkException as e: if print_results: diff --git a/benchmarks/precompute_spherical.py b/benchmarks/precompute_spherical.py index 1af2a2aa..4fbd80a1 100644 --- a/benchmarks/precompute_spherical.py +++ b/benchmarks/precompute_spherical.py @@ -41,7 +41,7 @@ def setup_forward(method, L, sampling, spin, reality, recursion): forward=True, recursion=recursion, ) - return {"f": f, "kernel": kernel} + return {"f": f, "kernel": kernel}, flm @benchmark( @@ -65,6 +65,7 @@ def forward(f, kernel, method, L, sampling, spin, reality, recursion): ) if method == "jax": flm.block_until_ready() + return flm def setup_inverse(method, L, sampling, spin, reality, recursion): @@ -85,7 +86,7 @@ def setup_inverse(method, L, sampling, spin, reality, recursion): forward=False, recursion=recursion, ) - return {"flm": flm, "kernel": kernel} + return {"flm": flm, "kernel": kernel}, None @benchmark( @@ -109,6 +110,7 @@ def inverse(flm, kernel, method, L, sampling, spin, reality, recursion): ) if method == "jax": f.block_until_ready() + return f if __name__ == "__main__": diff --git a/benchmarks/precompute_wigner.py b/benchmarks/precompute_wigner.py index 01918ca7..60a401e7 100644 --- a/benchmarks/precompute_wigner.py +++ b/benchmarks/precompute_wigner.py @@ -35,7 +35,7 @@ def setup_forward(method, L, N, L_lower, sampling, reality, mode): kernel = kernel_function( L=L, N=N, reality=reality, sampling=sampling, forward=True, mode=mode ) - return {"f": f, "kernel": kernel} + return {"f": f, "kernel": kernel}, flmn @benchmark( @@ -60,6 +60,7 @@ def forward(f, kernel, method, L, N, L_lower, sampling, reality, mode): ) if method == "jax": flmn.block_until_ready() + return flmn def setup_inverse(method, L, N, L_lower, sampling, reality, mode): @@ -73,7 +74,7 @@ def setup_inverse(method, L, N, L_lower, sampling, reality, mode): kernel = kernel_function( L=L, N=N, reality=reality, sampling=sampling, forward=False, mode=mode ) - return {"flmn": flmn, "kernel": kernel} + return {"flmn": flmn, "kernel": kernel}, None @benchmark( @@ -98,6 +99,7 @@ def inverse(flmn, kernel, method, L, N, L_lower, sampling, reality, mode): ) if method == "jax": f.block_until_ready() + return f if __name__ == "__main__": diff --git a/benchmarks/spherical.py b/benchmarks/spherical.py index 78231a43..39d9ab11 100644 --- a/benchmarks/spherical.py +++ b/benchmarks/spherical.py @@ -40,7 +40,7 @@ def setup_forward(method, L, L_lower, sampling, spin, reality, spmd): ) if method == "numpy": precomps = _jax_arrays_to_numpy(precomps) - return {"f": f, "precomps": precomps} + return {"f": f, "precomps": precomps}, flm @benchmark( @@ -70,6 +70,7 @@ def forward(f, precomps, method, L, L_lower, sampling, spin, reality, spmd): ) if method == "jax": flm.block_until_ready() + return flm def setup_inverse(method, L, L_lower, sampling, spin, reality, spmd): @@ -84,7 +85,7 @@ def setup_inverse(method, L, L_lower, sampling, spin, reality, spmd): ) if method == "numpy": precomps = _jax_arrays_to_numpy(precomps) - return {"flm": flm, "precomps": precomps} + return {"flm": flm, "precomps": precomps}, None @benchmark( @@ -114,6 +115,7 @@ def inverse(flm, precomps, method, L, L_lower, sampling, spin, reality, spmd): ) if method == "jax": f.block_until_ready() + return f if __name__ == "__main__": diff --git a/benchmarks/wigner.py b/benchmarks/wigner.py index d6180961..af71033b 100644 --- a/benchmarks/wigner.py +++ b/benchmarks/wigner.py @@ -37,7 +37,7 @@ def setup_forward(method, L, L_lower, N, sampling, reality): precomps = generate_precomputes( L, N, sampling, forward=True, reality=reality, L_lower=L_lower ) - return {"f": f, "precomps": precomps} + return {"f": f, "precomps": precomps}, flmn @benchmark( @@ -62,6 +62,7 @@ def forward(f, precomps, method, L, L_lower, N, sampling, reality): ) if method == "jax": flmn.block_until_ready() + return flmn def setup_inverse(method, L, L_lower, N, sampling, reality): @@ -75,7 +76,7 @@ def setup_inverse(method, L, L_lower, N, sampling, reality): precomps = generate_precomputes( L, N, sampling, forward=False, reality=reality, L_lower=L_lower ) - return {"flmn": flmn, "precomps": precomps} + return {"flmn": flmn, "precomps": precomps}, None @benchmark( @@ -100,6 +101,7 @@ def inverse(flmn, precomps, method, L, L_lower, N, sampling, reality): ) if method == "jax": f.block_until_ready() + return f if __name__ == "__main__": From 2df7205e12ed99f624ecbdf2a7d1901fd56cb068 Mon Sep 17 00:00:00 2001 From: Matt Graham Date: Thu, 23 Jan 2025 08:59:42 +0000 Subject: [PATCH 02/22] Use internal functions for benchmark setup rather than pyssht --- benchmarks/precompute_spherical.py | 14 ++++++-------- benchmarks/spherical.py | 14 ++++++++------ 2 files changed, 14 insertions(+), 14 deletions(-) diff --git a/benchmarks/precompute_spherical.py b/benchmarks/precompute_spherical.py index 4fbd80a1..6202fe29 100644 --- a/benchmarks/precompute_spherical.py +++ b/benchmarks/precompute_spherical.py @@ -1,12 +1,10 @@ """Benchmarks for precompute spherical transforms.""" import numpy as np -import pyssht from benchmarking import benchmark, parse_args_collect_and_run_benchmarks, skip import s2fft import s2fft.precompute_transforms -from s2fft.sampling import s2_samples as samples L_VALUES = [8, 16, 32, 64, 128, 256] SPIN_VALUES = [0] @@ -21,12 +19,12 @@ def setup_forward(method, L, sampling, spin, reality, recursion): skip("Reality only valid for scalar fields (spin=0).") rng = np.random.default_rng() flm = s2fft.utils.signal_generator.generate_flm(rng, L, spin=spin, reality=reality) - f = pyssht.inverse( - samples.flm_2d_to_1d(flm, L), - L, - Method=sampling.upper(), - Spin=spin, - Reality=reality, + f = s2fft.transforms.spherical.inverse( + flm, + L=L, + spin=spin, + sampling=sampling, + reality=reality, ) kernel_function = ( s2fft.precompute_transforms.construct.spin_spherical_kernel_jax diff --git a/benchmarks/spherical.py b/benchmarks/spherical.py index 39d9ab11..2a15a99a 100644 --- a/benchmarks/spherical.py +++ b/benchmarks/spherical.py @@ -28,12 +28,14 @@ def setup_forward(method, L, L_lower, sampling, spin, reality, spmd): skip("GPU distribution only valid for JAX.") rng = np.random.default_rng() flm = s2fft.utils.signal_generator.generate_flm(rng, L, spin=spin, reality=reality) - f = pyssht.inverse( - samples.flm_2d_to_1d(flm, L), - L, - Method=sampling.upper(), - Spin=spin, - Reality=reality, + f = s2fft.transforms.spherical.inverse( + flm, + L=L, + spin=spin, + sampling=sampling, + reality=reality, + spmd=spmd, + L_lower=L_lower, ) precomps = generate_precomputes_jax( L, spin, sampling, forward=True, L_lower=L_lower From a778204714d14c51de8dbff5b7521eb2250916ff Mon Sep 17 00:00:00 2001 From: Matt Graham Date: Thu, 23 Jan 2025 09:21:05 +0000 Subject: [PATCH 03/22] Expose parameter in benchmarks to set HEALPix nside --- benchmarks/spherical.py | 32 ++++++++++++++++++++++++++------ 1 file changed, 26 insertions(+), 6 deletions(-) diff --git a/benchmarks/spherical.py b/benchmarks/spherical.py index 2a15a99a..31cdc70d 100644 --- a/benchmarks/spherical.py +++ b/benchmarks/spherical.py @@ -11,6 +11,7 @@ L_VALUES = [8, 16, 32, 64, 128, 256] L_LOWER_VALUES = [0] SPIN_VALUES = [0] +L_TO_NSIDE_RATIO_VALUES = [2] SAMPLING_VALUES = ["mw"] METHOD_VALUES = ["numpy", "jax"] REALITY_VALUES = [True] @@ -21,24 +22,30 @@ def _jax_arrays_to_numpy(precomps): return [np.asarray(p) for p in precomps] -def setup_forward(method, L, L_lower, sampling, spin, reality, spmd): +def _get_nside(sampling, L, L_to_nside_ratio): + return None if sampling != "healpix" else L // L_to_nside_ratio + + +def setup_forward(method, L, L_lower, sampling, spin, L_to_nside_ratio, reality, spmd): if reality and spin != 0: skip("Reality only valid for scalar fields (spin=0).") if spmd and method != "jax": skip("GPU distribution only valid for JAX.") rng = np.random.default_rng() flm = s2fft.utils.signal_generator.generate_flm(rng, L, spin=spin, reality=reality) + nside = _get_nside(sampling, L, L_to_nside_ratio) f = s2fft.transforms.spherical.inverse( flm, L=L, spin=spin, + nside=nside, sampling=sampling, reality=reality, spmd=spmd, L_lower=L_lower, ) precomps = generate_precomputes_jax( - L, spin, sampling, forward=True, L_lower=L_lower + L, spin, sampling, nside=nside, forward=True, L_lower=L_lower ) if method == "numpy": precomps = _jax_arrays_to_numpy(precomps) @@ -52,10 +59,13 @@ def setup_forward(method, L, L_lower, sampling, spin, reality, spmd): L_lower=L_LOWER_VALUES, sampling=SAMPLING_VALUES, spin=SPIN_VALUES, + L_to_nside_ratio=L_TO_NSIDE_RATIO_VALUES, reality=REALITY_VALUES, spmd=SPMD_VALUES, ) -def forward(f, precomps, method, L, L_lower, sampling, spin, reality, spmd): +def forward( + f, precomps, method, L, L_lower, sampling, spin, L_to_nside_ratio, reality, spmd +): if method == "pyssht": flm = pyssht.forward(f, L, spin, sampling.upper()) else: @@ -65,6 +75,7 @@ def forward(f, precomps, method, L, L_lower, sampling, spin, reality, spmd): L_lower=L_lower, precomps=precomps, spin=spin, + nside=_get_nside(sampling, L, L_to_nside_ratio), sampling=sampling, reality=reality, method=method, @@ -75,7 +86,7 @@ def forward(f, precomps, method, L, L_lower, sampling, spin, reality, spmd): return flm -def setup_inverse(method, L, L_lower, sampling, spin, reality, spmd): +def setup_inverse(method, L, L_lower, sampling, spin, L_to_nside_ratio, reality, spmd): if reality and spin != 0: skip("Reality only valid for scalar fields (spin=0).") if spmd and method != "jax": @@ -83,7 +94,12 @@ def setup_inverse(method, L, L_lower, sampling, spin, reality, spmd): rng = np.random.default_rng() flm = s2fft.utils.signal_generator.generate_flm(rng, L, spin=spin, reality=reality) precomps = generate_precomputes_jax( - L, spin, sampling, forward=False, L_lower=L_lower + L, + spin, + sampling, + nside=_get_nside(sampling, L, L_to_nside_ratio), + forward=False, + L_lower=L_lower, ) if method == "numpy": precomps = _jax_arrays_to_numpy(precomps) @@ -97,10 +113,13 @@ def setup_inverse(method, L, L_lower, sampling, spin, reality, spmd): L_lower=L_LOWER_VALUES, sampling=SAMPLING_VALUES, spin=SPIN_VALUES, + L_to_nside_ratio=L_TO_NSIDE_RATIO_VALUES, reality=REALITY_VALUES, spmd=SPMD_VALUES, ) -def inverse(flm, precomps, method, L, L_lower, sampling, spin, reality, spmd): +def inverse( + flm, precomps, method, L, L_lower, sampling, spin, L_to_nside_ratio, reality, spmd +): if method == "pyssht": f = pyssht.inverse(samples.flm_2d_to_1d(flm, L), L, spin, sampling.upper()) else: @@ -110,6 +129,7 @@ def inverse(flm, precomps, method, L, L_lower, sampling, spin, reality, spmd): L_lower=L_lower, precomps=precomps, spin=spin, + nside=_get_nside(sampling, L, L_to_nside_ratio), sampling=sampling, reality=reality, method=method, From 9a8d721b50ed6a808de20259884e8a3e116d7ca3 Mon Sep 17 00:00:00 2001 From: Matt Graham Date: Thu, 23 Jan 2025 09:37:13 +0000 Subject: [PATCH 04/22] Expose parameter to control number of iterations in spherical benchmarks --- benchmarks/spherical.py | 19 +++++++++++++++++-- 1 file changed, 17 insertions(+), 2 deletions(-) diff --git a/benchmarks/spherical.py b/benchmarks/spherical.py index 31cdc70d..d2bab4f6 100644 --- a/benchmarks/spherical.py +++ b/benchmarks/spherical.py @@ -16,6 +16,7 @@ METHOD_VALUES = ["numpy", "jax"] REALITY_VALUES = [True] SPMD_VALUES = [False] +N_ITER_VALUES = [None] def _jax_arrays_to_numpy(precomps): @@ -26,7 +27,9 @@ def _get_nside(sampling, L, L_to_nside_ratio): return None if sampling != "healpix" else L // L_to_nside_ratio -def setup_forward(method, L, L_lower, sampling, spin, L_to_nside_ratio, reality, spmd): +def setup_forward( + method, L, L_lower, sampling, spin, L_to_nside_ratio, reality, spmd, n_iter +): if reality and spin != 0: skip("Reality only valid for scalar fields (spin=0).") if spmd and method != "jax": @@ -62,9 +65,20 @@ def setup_forward(method, L, L_lower, sampling, spin, L_to_nside_ratio, reality, L_to_nside_ratio=L_TO_NSIDE_RATIO_VALUES, reality=REALITY_VALUES, spmd=SPMD_VALUES, + n_iter=N_ITER_VALUES, ) def forward( - f, precomps, method, L, L_lower, sampling, spin, L_to_nside_ratio, reality, spmd + f, + precomps, + method, + L, + L_lower, + sampling, + spin, + L_to_nside_ratio, + reality, + spmd, + n_iter, ): if method == "pyssht": flm = pyssht.forward(f, L, spin, sampling.upper()) @@ -80,6 +94,7 @@ def forward( reality=reality, method=method, spmd=spmd, + iter=n_iter, ) if method == "jax": flm.block_until_ready() From b22111baf98a33d8c821c4e019f48b9f1a45b030 Mon Sep 17 00:00:00 2001 From: Matt Graham Date: Thu, 23 Jan 2025 10:12:47 +0000 Subject: [PATCH 05/22] Only override parameters defined for benchmark --- benchmarks/benchmarking.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/benchmarks/benchmarking.py b/benchmarks/benchmarking.py index 9fead05b..d43060fb 100644 --- a/benchmarks/benchmarking.py +++ b/benchmarks/benchmarking.py @@ -317,7 +317,9 @@ def run_benchmarks( print(benchmark.__name__) parameters = benchmark.parameters.copy() if parameter_overrides is not None: - parameters.update(parameter_overrides) + for parameter_name, parameter_values in parameter_overrides.items(): + if parameter_name in parameters: + parameters[parameter_name] = parameter_values for parameter_set in _dict_product(parameters): try: precomputes, reference_output = benchmark.setup(**parameter_set) From fdcba06c2481e08ac924ee56893576d1667fbd48 Mon Sep 17 00:00:00 2001 From: Matt Graham Date: Thu, 23 Jan 2025 10:15:07 +0000 Subject: [PATCH 06/22] Remove pyssht dependency in benchmarks --- benchmarks/spherical.py | 67 +++++++++++++++++------------------------ 1 file changed, 28 insertions(+), 39 deletions(-) diff --git a/benchmarks/spherical.py b/benchmarks/spherical.py index d2bab4f6..cd61bec0 100644 --- a/benchmarks/spherical.py +++ b/benchmarks/spherical.py @@ -1,12 +1,11 @@ """Benchmarks for on-the-fly spherical transforms.""" +import jax import numpy as np -import pyssht from benchmarking import benchmark, parse_args_collect_and_run_benchmarks, skip import s2fft from s2fft.recursions.price_mcewen import generate_precomputes_jax -from s2fft.sampling import s2_samples as samples L_VALUES = [8, 16, 32, 64, 128, 256] L_LOWER_VALUES = [0] @@ -80,25 +79,20 @@ def forward( spmd, n_iter, ): - if method == "pyssht": - flm = pyssht.forward(f, L, spin, sampling.upper()) - else: - flm = s2fft.transforms.spherical.forward( - f=f, - L=L, - L_lower=L_lower, - precomps=precomps, - spin=spin, - nside=_get_nside(sampling, L, L_to_nside_ratio), - sampling=sampling, - reality=reality, - method=method, - spmd=spmd, - iter=n_iter, - ) - if method == "jax": - flm.block_until_ready() - return flm + flm = s2fft.transforms.spherical.forward( + f=f, + L=L, + L_lower=L_lower, + precomps=precomps, + spin=spin, + nside=_get_nside(sampling, L, L_to_nside_ratio), + sampling=sampling, + reality=reality, + method=method, + spmd=spmd, + iter=n_iter, + ) + return flm.block_until_ready() if isinstance(flm, jax.Array) else flm def setup_inverse(method, L, L_lower, sampling, spin, L_to_nside_ratio, reality, spmd): @@ -135,24 +129,19 @@ def setup_inverse(method, L, L_lower, sampling, spin, L_to_nside_ratio, reality, def inverse( flm, precomps, method, L, L_lower, sampling, spin, L_to_nside_ratio, reality, spmd ): - if method == "pyssht": - f = pyssht.inverse(samples.flm_2d_to_1d(flm, L), L, spin, sampling.upper()) - else: - f = s2fft.transforms.spherical.inverse( - flm=flm, - L=L, - L_lower=L_lower, - precomps=precomps, - spin=spin, - nside=_get_nside(sampling, L, L_to_nside_ratio), - sampling=sampling, - reality=reality, - method=method, - spmd=spmd, - ) - if method == "jax": - f.block_until_ready() - return f + f = s2fft.transforms.spherical.inverse( + flm=flm, + L=L, + L_lower=L_lower, + precomps=precomps, + spin=spin, + nside=_get_nside(sampling, L, L_to_nside_ratio), + sampling=sampling, + reality=reality, + method=method, + spmd=spmd, + ) + return f.block_until_ready() if isinstance(f, jax.Array) else f if __name__ == "__main__": From ebe51daaf70d4bf4b8b334236b3926df58e11643 Mon Sep 17 00:00:00 2001 From: Matt Graham Date: Thu, 23 Jan 2025 10:18:46 +0000 Subject: [PATCH 07/22] Pass through method to inverse transform in forward setup --- benchmarks/spherical.py | 1 + 1 file changed, 1 insertion(+) diff --git a/benchmarks/spherical.py b/benchmarks/spherical.py index cd61bec0..e9e83d52 100644 --- a/benchmarks/spherical.py +++ b/benchmarks/spherical.py @@ -42,6 +42,7 @@ def setup_forward( spin=spin, nside=nside, sampling=sampling, + method=method, reality=reality, spmd=spmd, L_lower=L_lower, From d5bbc1e2b8f19b7cd7f1a9040f3876824f48121c Mon Sep 17 00:00:00 2001 From: Matt Graham Date: Thu, 23 Jan 2025 10:38:15 +0000 Subject: [PATCH 08/22] Remove flag for run once and discard in benchmarks --- benchmarks/benchmarking.py | 22 +++++----------------- 1 file changed, 5 insertions(+), 17 deletions(-) diff --git a/benchmarks/benchmarking.py b/benchmarks/benchmarking.py index d43060fb..48a0b4f5 100644 --- a/benchmarks/benchmarking.py +++ b/benchmarks/benchmarking.py @@ -249,15 +249,6 @@ def _parse_cli_arguments(description): parser.add_argument( "-output-file", type=Path, help="File path to write JSON formatted results to." ) - parser.add_argument( - "--run-once-and-discard", - action="store_true", - help=( - "Run benchmark function once first without recording time to " - "ignore the effect of any initial one-off costs such as just-in-time " - "compilation." - ), - ) return parser.parse_args() @@ -288,7 +279,6 @@ def run_benchmarks( number_repeats, print_results=True, parameter_overrides=None, - run_once_and_discard=False, ): """Run a set of benchmarks. @@ -302,9 +292,6 @@ def run_benchmarks( print_results: Whether to print benchmark results to stdout. parameter_overrides: Dictionary specifying any overrides for parameter values set in `benchmark` decorator. - run_once_and_discard: Whether to run benchmark function once first without - recording time to ignore the effect of any initial one-off costs such as - just-in-time compilation. Returns: Dictionary containing timing (and potentially memory usage) results for each @@ -324,8 +311,10 @@ def run_benchmarks( try: precomputes, reference_output = benchmark.setup(**parameter_set) benchmark_function = partial(benchmark, **precomputes, **parameter_set) - if run_once_and_discard or reference_output is not None: - output = benchmark_function() + # Run benchmark once without timing to record output for potentially + # computing numerical error and to remove effect of any one-off costs + # such as just-in-time compilation when timing + output = benchmark_function() run_times = [ time / number_runs for time in timeit.repeat( @@ -364,7 +353,7 @@ def run_benchmarks( else "" ) + ( - f", round-trip error: {results_entry['error']:#7.2g}" + f", max(abs(error)): {results_entry['error']:#7.2g}" if "error" in results_entry else "" ) @@ -398,7 +387,6 @@ def parse_args_collect_and_run_benchmarks(module=None): number_runs=args.number_runs, number_repeats=args.repeats, parameter_overrides=parameter_overrides, - run_once_and_discard=args.run_once_and_discard, ) if args.output_file is not None: package_versions = { From 58e16d0c5d751eebeae8fbc3a62b80d4a29252f6 Mon Sep 17 00:00:00 2001 From: Matt Graham Date: Thu, 23 Jan 2025 10:52:17 +0000 Subject: [PATCH 09/22] Allow filtering which benchmarks in module are run --- benchmarks/benchmarking.py | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/benchmarks/benchmarking.py b/benchmarks/benchmarking.py index 48a0b4f5..1a9aaf11 100644 --- a/benchmarks/benchmarking.py +++ b/benchmarks/benchmarking.py @@ -249,6 +249,11 @@ def _parse_cli_arguments(description): parser.add_argument( "-output-file", type=Path, help="File path to write JSON formatted results to." ) + parser.add_argument( + "-benchmarks", + nargs="+", + help="Names of benchmark functions to run. All benchmarks are run if omitted.", + ) return parser.parse_args() @@ -261,16 +266,22 @@ def _is_benchmark(object): ) -def collect_benchmarks(module): +def collect_benchmarks(module, benchmark_names): """Collect all benchmark functions from a module. Args: module: Python module containing benchmark functions. + benchmark_names: List of benchmark names to collect or `None` if all benchmarks + in module to be collected. Returns: List of functions in module with `is_benchmark` attribute set to `True`. """ - return [function for name, function in inspect.getmembers(module, _is_benchmark)] + return [ + function + for name, function in inspect.getmembers(module, _is_benchmark) + if benchmark_names is None or name in benchmark_names + ] def run_benchmarks( @@ -383,7 +394,7 @@ def parse_args_collect_and_run_benchmarks(module=None): args = _parse_cli_arguments(module.__doc__) parameter_overrides = _parse_parameter_overrides(args.parameter_overrides) results = run_benchmarks( - benchmarks=collect_benchmarks(module), + benchmarks=collect_benchmarks(module, args.benchmarks), number_runs=args.number_runs, number_repeats=args.repeats, parameter_overrides=parameter_overrides, From 4fb58ad168708d86df8aee0fd27b70a203a2022c Mon Sep 17 00:00:00 2001 From: Matt Graham Date: Thu, 23 Jan 2025 12:05:04 +0000 Subject: [PATCH 10/22] Update benchmarking usage example in README --- benchmarks/README.md | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/benchmarks/README.md b/benchmarks/README.md index ab449949..2ef7000f 100644 --- a/benchmarks/README.md +++ b/benchmarks/README.md @@ -12,7 +12,6 @@ If the [`py-cpuinfo` package](https://pypi.org/project/py-cpuinfo/) is installed additional information about CPU of system benchmarks are run on will be recorded in JSON output. - ## Description The benchmark scripts are as follows: @@ -40,7 +39,7 @@ display the usage message: ``` usage: spherical.py [-h] [-number-runs NUMBER_RUNS] [-repeats REPEATS] [-parameter-overrides [PARAMETER_OVERRIDES ...]] [-output-file OUTPUT_FILE] - [--run-once-and-discard] + [-benchmarks BENCHMARKS [BENCHMARKS ...]] Benchmarks for on-the-fly spherical transforms. @@ -55,9 +54,8 @@ options: parameters. (default: None) -output-file OUTPUT_FILE File path to write JSON formatted results to. (default: None) - --run-once-and-discard - Run benchmark function once first without recording time to ignore the effect of any initial - one-off costs such as just-in-time compilation. (default: False) + -benchmarks BENCHMARKS [BENCHMARKS ...] + Names of benchmark functions to run. All benchmarks are run if omitted. (default: None) ``` For example to run the spherical transform benchmarks using only the JAX implementations, @@ -65,7 +63,7 @@ running on a CPU (in double-precision) for `L` values 64, 128, 256, 512 and 1024 would run from the root of the repository: ```sh -JAX_PLATFORM_NAME=cpu JAX_ENABLE_X64=1 python benchmarks/spherical.py --run-once-and-discard -p L 64 128 256 512 1024 -p method jax +JAX_PLATFORM_NAME=cpu JAX_ENABLE_X64=1 python benchmarks/spherical.py -p L 64 128 256 512 1024 -p method jax ``` Note the usage of environment variables `JAX_PLATFORM_NAME` and `JAX_ENABLE_X64` to From 5534f6ad3c4ad11b950e1ec63d30779609371219 Mon Sep 17 00:00:00 2001 From: Matt Graham Date: Thu, 23 Jan 2025 12:05:14 +0000 Subject: [PATCH 11/22] Split up longer functions --- benchmarks/benchmarking.py | 162 +++++++++++++++++++++++-------------- 1 file changed, 101 insertions(+), 61 deletions(-) diff --git a/benchmarks/benchmarking.py b/benchmarks/benchmarking.py index 1a9aaf11..5805f969 100644 --- a/benchmarks/benchmarking.py +++ b/benchmarks/benchmarking.py @@ -193,6 +193,29 @@ def _parameters_string(parameters): return "(" + ", ".join(f"{name}: {val}" for name, val in parameters.items()) + ")" +def _format_results_entry(results_entry): + """Format benchmark results entry as a string for printing.""" + return ( + ( + f"{_parameters_string(results_entry['parameters']):>40}: \n " + if len(results_entry["parameters"]) != 0 + else " " + ) + + f"min(time): {min(results_entry['times / s']):>#7.2g}s, " + + f"max(time): {max(results_entry['times / s']):>#7.2g}s" + + ( + f", peak mem.: {results_entry['peak_memory / MiB']:>#7.2g}MiB" + if "peak_memory / MiB" in results_entry + else "" + ) + + ( + f", max(abs(error)): {results_entry['error']:#7.2g}" + if "error" in results_entry + else "" + ) + ) + + def _dict_product(dicts): """Generator corresponding to Cartesian product of dictionaries.""" return (dict(zip(dicts.keys(), values)) for values in product(*dicts.values())) @@ -284,6 +307,33 @@ def collect_benchmarks(module, benchmark_names): ] +def measure_peak_memory_usage(benchmark_function, interval): + """Measure peak memory usage in mebibytes (MiB) of a function using memory_profiler. + + Args: + benchmark_function: Function to benchmark peak memory usage of. + interval: Interval in seconds at which memory measurements are collected. + + Returns: + Peak memory usage measure in mebibytes (MiB). + """ + baseline_memory = memory_profiler.memory_usage( + lambda: None, + max_usage=True, + include_children=True, + ) + return max( + memory_profiler.memory_usage( + benchmark_function, + interval=interval, + max_usage=True, + include_children=True, + ) + - baseline_memory, + 0, + ) + + def run_benchmarks( benchmarks, number_runs, @@ -332,43 +382,17 @@ def run_benchmarks( benchmark_function, number=number_runs, repeat=number_repeats ) ] - results_entry = {**parameter_set, "times / s": run_times} + results_entry = {"parameters": parameter_set, "times / s": run_times} if reference_output is not None and output is not None: results_entry["error"] = abs(reference_output - output).max() if MEMORY_PROFILER_AVAILABLE: - baseline_memory = memory_profiler.memory_usage(max_usage=True) - peak_memory = ( - memory_profiler.memory_usage( - benchmark_function, - interval=max(run_times) * number_repeats, - max_usage=True, - max_iterations=number_repeats, - include_children=True, - ) - - baseline_memory + results_entry["peak_memory / MiB"] = measure_peak_memory_usage( + benchmark_function, + interval=min(run_times) / 20, ) - results_entry["peak_memory / MiB"] = peak_memory results[benchmark.__name__].append(results_entry) if print_results: - print( - ( - f"{_parameters_string(parameter_set):>40}: \n " - if len(parameter_set) != 0 - else " " - ) - + f"min(time): {min(run_times):>#7.2g}s, " - + f"max(time): {max(run_times):>#7.2g}s" - + ( - f", peak mem.: {peak_memory:>#7.2g}MiB" - if MEMORY_PROFILER_AVAILABLE - else "" - ) - + ( - f", max(abs(error)): {results_entry['error']:#7.2g}" - if "error" in results_entry - else "" - ) - ) + print(_format_results_entry(results_entry)) except SkipBenchmarkException as e: if print_results: print( @@ -377,16 +401,56 @@ def run_benchmarks( return results +def get_system_info(): + """Get dictionary of metadata about system. + + Returns: + Dictionary with information about system, CPU and GPU devices (if present) and + Python environment and package versions. + """ + package_versions = { + f"{package}_version": _get_version_or_none(package) + for package in ("s2fft", "jax", "numpy") + } + return { + "architecture": platform.architecture(), + "machine": platform.machine(), + "node": platform.node(), + "processor": platform.processor(), + "python_version": platform.python_version(), + "release": platform.release(), + "system": platform.system(), + "cpu_info": _get_cpu_info(), + "gpu_info": _get_gpu_info(), + "cuda_info": _get_cuda_info(), + **package_versions, + } + + +def write_json_results_file(output_file, results, benchmark_module): + """Write benchmark results and system information to a file in JSON format. + + Args: + output_file: Path to file to write results to. + results: Dictionary of benchmark results from `run_benchmarks`. + benchmarks_module: Name of module containing benchmarks. + """ + with open(output_file, "w") as f: + output = { + "date_time": datetime.datetime.now().isoformat(), + "benchmark_module": benchmark_module, + "system_info": get_system_info(), + "results": results, + } + json.dump(output, f, indent=True) + + def parse_args_collect_and_run_benchmarks(module=None): """Collect and run all benchmarks in a module and parse command line arguments. Args: module: Module containing benchmarks to run. Defaults to module from which this function was called if not specified (set to `None`). - - Returns: - Dictionary containing timing (and potentially memory usage) results for each - parameters set of each benchmark function. """ if module is None: frame = inspect.stack()[1] @@ -397,32 +461,8 @@ def parse_args_collect_and_run_benchmarks(module=None): benchmarks=collect_benchmarks(module, args.benchmarks), number_runs=args.number_runs, number_repeats=args.repeats, + print_results=True, parameter_overrides=parameter_overrides, ) if args.output_file is not None: - package_versions = { - f"{package}_version": _get_version_or_none(package) - for package in ("s2fft", "jax", "numpy") - } - system_info = { - "architecture": platform.architecture(), - "machine": platform.machine(), - "node": platform.node(), - "processor": platform.processor(), - "python_version": platform.python_version(), - "release": platform.release(), - "system": platform.system(), - "cpu_info": _get_cpu_info(), - "gpu_info": _get_gpu_info(), - "cuda_info": _get_cuda_info(), - **package_versions, - } - with open(args.output_file, "w") as f: - output = { - "date_time": datetime.datetime.now().isoformat(), - "benchmark_module": module.__name__, - "system_info": system_info, - "results": results, - } - json.dump(output, f, indent=True) - return results + write_json_results_file(args.output_file, results, module.__name__) From 4e01158dfb9a1d1201a48a0bd77e767b853defdf Mon Sep 17 00:00:00 2001 From: Matt Graham Date: Thu, 23 Jan 2025 16:12:52 +0000 Subject: [PATCH 12/22] Use JAX AOT compilation to record cost + memory analysis estimates --- benchmarks/benchmarking.py | 122 ++++++++++++++++++++++------- benchmarks/precompute_spherical.py | 21 +++-- benchmarks/precompute_wigner.py | 22 +++--- benchmarks/spherical.py | 18 +++-- benchmarks/wigner.py | 24 +++--- 5 files changed, 134 insertions(+), 73 deletions(-) diff --git a/benchmarks/benchmarking.py b/benchmarks/benchmarking.py index 5805f969..9f33f93f 100644 --- a/benchmarks/benchmarking.py +++ b/benchmarks/benchmarking.py @@ -1,25 +1,40 @@ """Helper functions and classes for benchmarking. -Functions to be benchmarked in a module should be decorated with `benchmark` which -takes one positional argument corresponding to a function to peform any necessary -set up for the benchmarked function (returning a dictionary, potentially empty with -any precomputed values to pass to benchmark function as keyword arguments) and zero -or more keyword arguments specifying parameter names and values lists to benchmark -over (the Cartesian product of all specified parameter values is used). The benchmark -function is passed the union of any precomputed values outputted by the setup function -and the parameters values as keyword arguments. +Functions to be benchmarked in a module should be decorated with `benchmark` which takes +one positional argument corresponding to a function to perform any necessary set up for +the benchmarked function and zero or more keyword arguments specifying parameter names +and values lists to benchmark over (the Cartesian product of all specified parameter +values is used). + +The benchmark setup function should return an instance of the named tuple type +`BenchmarkSetup` which consists of a required dictionary entry, potentially empty with +any precomputed values to pass to benchmark function as keyword arguments, and two +further optional entries: the first a reference value for the output of the benchmarked +function to use to compute numerical error if applicable, defaulting to `None` +indicating no applicable reference value; and the second a flag indicating whether to +just-in-time compile the benchmark function using JAX's `jit` transform, defaulting to +`False`. + +The benchmark function is passed the union of any precomputed values returned by the +setup function and the parameters values as keyword arguments. If a reference output +value is set by the setup function the benchmark function should output the value to +compare to this reference value by computing the maximum absolute elementwise +difference. If the function is to be just-in-time compiled using JAX the value returned +by the benchmark function should be a JAX array on which the `block_until_ready` method +may be called to ensure the function only exits once the relevant computation has +completed (necessary due to JAX's asynchronous dispatch computation model). As a simple example, the following defines a benchmark for computing the mean of a list of numbers. ```Python import random -from benchmarking import benchmark +from benchmarking import BenchmarkSetup, benchmark def setup_mean(n): - return {"x": [random.random() for _ in range(n)]}, None + return BenchmarkSetup({"x": [random.random() for _ in range(n)]}) -@benchmark(setup_computation, n=[1, 2, 3, 4]) +@benchmark(setup_mean, n=[1, 2, 3, 4]) def mean(x, n): return sum(x) / n ``` @@ -29,12 +44,12 @@ def mean(x, n): ```Python import random -from benchmarking import benchmark, skip +from benchmarking import BenchmarkSetup, benchmark, skip def setup_mean(n): - return {"x": [random.random() for _ in range(n)]}, None + return BenchmarkSetup({"x": [random.random() for _ in range(n)]}) -@benchmark(setup_computation, n=[0, 1, 2, 3, 4]) +@benchmark(setup_mean, n=[0, 1, 2, 3, 4]) def mean(x, n): if n == 0: skip("number of items must be positive") @@ -48,17 +63,17 @@ def mean(x, n): to allow it to be executed as a script for runnning the benchmarks: ```Python - from benchmarking import benchmark, parse_args_collect_and_run_benchmarks ... if __name__ == "__main__": parse_args_collect_and_run_benchmarks() - ``` """ +from __future__ import annotations + import argparse import datetime import inspect @@ -70,6 +85,10 @@ def mean(x, n): from importlib.metadata import PackageNotFoundError, version from itertools import product from pathlib import Path +from typing import Any, NamedTuple + +import jax +import numpy as np try: import memory_profiler @@ -155,11 +174,19 @@ def skip(message): raise SkipBenchmarkException(message) -def benchmark(setup_=None, **parameters): +class BenchmarkSetup(NamedTuple): + """Structure containing data for setting up a benchmark function.""" + + arguments: dict[str, Any] + reference_output: None | jax.Array | np.ndarray = None + jit_benchmark: bool = False + + +def benchmark(setup=None, **parameters): """Decorator for defining a function to be benchmark. Args: - setup_: Function performing any necessary set up for benchmark, and the resource + setup: Function performing any necessary set up for benchmark, and the resource usage of which will not be tracked in benchmarking. The function should return a 2-tuple, with first entry a dictionary of values to pass to the benchmark as keyword arguments, corresponding to any precomputed values, @@ -175,13 +202,13 @@ def benchmark(setup_=None, **parameters): Decorator which marks function as benchmark and sets setup function and parameters attributes. To also record error of benchmarked function in terms of maximum absolute elementwise difference between output and reference value - returned by `setup_` function, the decorated benchmark function should return + returned by `setup` function, the decorated benchmark function should return the numerical value to measure the error for. """ def decorator(function): function.is_benchmark = True - function.setup = setup_ if setup_ is not None else lambda: {} + function.setup = setup if setup is not None else lambda: {} function.parameters = parameters return function @@ -204,15 +231,21 @@ def _format_results_entry(results_entry): + f"min(time): {min(results_entry['times / s']):>#7.2g}s, " + f"max(time): {max(results_entry['times / s']):>#7.2g}s" + ( - f", peak mem.: {results_entry['peak_memory / MiB']:>#7.2g}MiB" + f", peak memory: {results_entry['peak_memory / MiB']:>#7.2g}MiB" if "peak_memory / MiB" in results_entry else "" ) + ( - f", max(abs(error)): {results_entry['error']:#7.2g}" + f", max(abs(error)): {results_entry['error']:>#7.2g}" if "error" in results_entry else "" ) + + ( + f", floating point ops: {results_entry['cost_analysis']['flops']:>#7.2g}" + f", mem access: {results_entry['cost_analysis']['bytes_accessed']:>#7.2g}B" + if "cost_analysis" in results_entry + else "" + ) ) @@ -334,6 +367,33 @@ def measure_peak_memory_usage(benchmark_function, interval): ) +def _compile_jax_benchmark_and_analyse(benchmark_function, results_entry): + """Compile a JAX benchmark function and extract cost estimates if available.""" + compiled_benchmark_function = jax.jit(benchmark_function).lower().compile() + cost_analysis = compiled_benchmark_function.cost_analysis() + if cost_analysis is not None and isinstance(cost_analysis, list): + results_entry["cost_analysis"] = { + "flops": cost_analysis[0].get("flops"), + "bytes_accessed": cost_analysis[0].get("bytes accessed"), + } + memory_analysis = compiled_benchmark_function.memory_analysis() + if memory_analysis is not None: + results_entry["memory_analysis"] = { + prefix + base_key: getattr(memory_analysis, prefix + base_key, None) + for prefix in ("", "host_") + for base_key in ( + "alias_size_in_bytes", + "argument_size_in_bytes", + "generated_code_size_in_bytes", + "output_size_in_bytes", + "temp_size_in_bytes", + ) + } + # Ensure block_until_ready called on benchmark output due to JAX's asynchronous + # dispatch model: https://jax.readthedocs.io/en/latest/async_dispatch.html + return lambda: compiled_benchmark_function().block_until_ready() + + def run_benchmarks( benchmarks, number_runs, @@ -370,21 +430,25 @@ def run_benchmarks( parameters[parameter_name] = parameter_values for parameter_set in _dict_product(parameters): try: - precomputes, reference_output = benchmark.setup(**parameter_set) - benchmark_function = partial(benchmark, **precomputes, **parameter_set) + args, reference_output, jit_benchmark = benchmark.setup(**parameter_set) + benchmark_function = partial(benchmark, **args, **parameter_set) + results_entry = {"parameters": parameter_set} + if jit_benchmark: + benchmark_function = _compile_jax_benchmark_and_analyse( + benchmark_function, results_entry + ) # Run benchmark once without timing to record output for potentially - # computing numerical error and to remove effect of any one-off costs - # such as just-in-time compilation when timing + # computing numerical error output = benchmark_function() + if reference_output is not None and output is not None: + results_entry["error"] = abs(reference_output - output).max() run_times = [ time / number_runs for time in timeit.repeat( benchmark_function, number=number_runs, repeat=number_repeats ) ] - results_entry = {"parameters": parameter_set, "times / s": run_times} - if reference_output is not None and output is not None: - results_entry["error"] = abs(reference_output - output).max() + results_entry["times / s"] = run_times if MEMORY_PROFILER_AVAILABLE: results_entry["peak_memory / MiB"] = measure_peak_memory_usage( benchmark_function, diff --git a/benchmarks/precompute_spherical.py b/benchmarks/precompute_spherical.py index 6202fe29..bec1dd7c 100644 --- a/benchmarks/precompute_spherical.py +++ b/benchmarks/precompute_spherical.py @@ -1,7 +1,12 @@ """Benchmarks for precompute spherical transforms.""" import numpy as np -from benchmarking import benchmark, parse_args_collect_and_run_benchmarks, skip +from benchmarking import ( + BenchmarkSetup, + benchmark, + parse_args_collect_and_run_benchmarks, + skip, +) import s2fft import s2fft.precompute_transforms @@ -39,7 +44,7 @@ def setup_forward(method, L, sampling, spin, reality, recursion): forward=True, recursion=recursion, ) - return {"f": f, "kernel": kernel}, flm + return BenchmarkSetup({"f": f, "kernel": kernel}, flm, "jax" in method) @benchmark( @@ -52,7 +57,7 @@ def setup_forward(method, L, sampling, spin, reality, recursion): recursion=RECURSION_VALUES, ) def forward(f, kernel, method, L, sampling, spin, reality, recursion): - flm = s2fft.precompute_transforms.spherical.forward( + return s2fft.precompute_transforms.spherical.forward( f=f, L=L, spin=spin, @@ -61,9 +66,6 @@ def forward(f, kernel, method, L, sampling, spin, reality, recursion): reality=reality, method=method, ) - if method == "jax": - flm.block_until_ready() - return flm def setup_inverse(method, L, sampling, spin, reality, recursion): @@ -84,7 +86,7 @@ def setup_inverse(method, L, sampling, spin, reality, recursion): forward=False, recursion=recursion, ) - return {"flm": flm, "kernel": kernel}, None + return BenchmarkSetup({"flm": flm, "kernel": kernel}, None, "jax" in method) @benchmark( @@ -97,7 +99,7 @@ def setup_inverse(method, L, sampling, spin, reality, recursion): recursion=RECURSION_VALUES, ) def inverse(flm, kernel, method, L, sampling, spin, reality, recursion): - f = s2fft.precompute_transforms.spherical.inverse( + return s2fft.precompute_transforms.spherical.inverse( flm=flm, L=L, spin=spin, @@ -106,9 +108,6 @@ def inverse(flm, kernel, method, L, sampling, spin, reality, recursion): reality=reality, method=method, ) - if method == "jax": - f.block_until_ready() - return f if __name__ == "__main__": diff --git a/benchmarks/precompute_wigner.py b/benchmarks/precompute_wigner.py index 60a401e7..73971d43 100644 --- a/benchmarks/precompute_wigner.py +++ b/benchmarks/precompute_wigner.py @@ -1,7 +1,11 @@ """Benchmarks for precompute Wigner-d transforms.""" import numpy as np -from benchmarking import benchmark, parse_args_collect_and_run_benchmarks +from benchmarking import ( + BenchmarkSetup, + benchmark, + parse_args_collect_and_run_benchmarks, +) import s2fft import s2fft.precompute_transforms @@ -29,13 +33,13 @@ def setup_forward(method, L, N, L_lower, sampling, reality, mode): ) kernel_function = ( s2fft.precompute_transforms.construct.wigner_kernel_jax - if method == "jax" + if "jax" in method else s2fft.precompute_transforms.construct.wigner_kernel ) kernel = kernel_function( L=L, N=N, reality=reality, sampling=sampling, forward=True, mode=mode ) - return {"f": f, "kernel": kernel}, flmn + return BenchmarkSetup({"f": f, "kernel": kernel}, flmn, "jax" in method) @benchmark( @@ -49,7 +53,7 @@ def setup_forward(method, L, N, L_lower, sampling, reality, mode): mode=MODE_VALUES, ) def forward(f, kernel, method, L, N, L_lower, sampling, reality, mode): - flmn = s2fft.precompute_transforms.wigner.forward( + return s2fft.precompute_transforms.wigner.forward( f=f, L=L, N=N, @@ -58,9 +62,6 @@ def forward(f, kernel, method, L, N, L_lower, sampling, reality, mode): reality=reality, method=method, ) - if method == "jax": - flmn.block_until_ready() - return flmn def setup_inverse(method, L, N, L_lower, sampling, reality, mode): @@ -74,7 +75,7 @@ def setup_inverse(method, L, N, L_lower, sampling, reality, mode): kernel = kernel_function( L=L, N=N, reality=reality, sampling=sampling, forward=False, mode=mode ) - return {"flmn": flmn, "kernel": kernel}, None + return BenchmarkSetup({"flmn": flmn, "kernel": kernel}, None, "jax" in method) @benchmark( @@ -88,7 +89,7 @@ def setup_inverse(method, L, N, L_lower, sampling, reality, mode): mode=MODE_VALUES, ) def inverse(flmn, kernel, method, L, N, L_lower, sampling, reality, mode): - f = s2fft.precompute_transforms.wigner.inverse( + return s2fft.precompute_transforms.wigner.inverse( flmn=flmn, L=L, N=N, @@ -97,9 +98,6 @@ def inverse(flmn, kernel, method, L, N, L_lower, sampling, reality, mode): reality=reality, method=method, ) - if method == "jax": - f.block_until_ready() - return f if __name__ == "__main__": diff --git a/benchmarks/spherical.py b/benchmarks/spherical.py index e9e83d52..d0ea078d 100644 --- a/benchmarks/spherical.py +++ b/benchmarks/spherical.py @@ -1,8 +1,12 @@ """Benchmarks for on-the-fly spherical transforms.""" -import jax import numpy as np -from benchmarking import benchmark, parse_args_collect_and_run_benchmarks, skip +from benchmarking import ( + BenchmarkSetup, + benchmark, + parse_args_collect_and_run_benchmarks, + skip, +) import s2fft from s2fft.recursions.price_mcewen import generate_precomputes_jax @@ -52,7 +56,7 @@ def setup_forward( ) if method == "numpy": precomps = _jax_arrays_to_numpy(precomps) - return {"f": f, "precomps": precomps}, flm + return BenchmarkSetup({"f": f, "precomps": precomps}, flm, "jax" in method) @benchmark( @@ -80,7 +84,7 @@ def forward( spmd, n_iter, ): - flm = s2fft.transforms.spherical.forward( + return s2fft.transforms.spherical.forward( f=f, L=L, L_lower=L_lower, @@ -93,7 +97,6 @@ def forward( spmd=spmd, iter=n_iter, ) - return flm.block_until_ready() if isinstance(flm, jax.Array) else flm def setup_inverse(method, L, L_lower, sampling, spin, L_to_nside_ratio, reality, spmd): @@ -113,7 +116,7 @@ def setup_inverse(method, L, L_lower, sampling, spin, L_to_nside_ratio, reality, ) if method == "numpy": precomps = _jax_arrays_to_numpy(precomps) - return {"flm": flm, "precomps": precomps}, None + return BenchmarkSetup({"flm": flm, "precomps": precomps}, None, "jax" in method) @benchmark( @@ -130,7 +133,7 @@ def setup_inverse(method, L, L_lower, sampling, spin, L_to_nside_ratio, reality, def inverse( flm, precomps, method, L, L_lower, sampling, spin, L_to_nside_ratio, reality, spmd ): - f = s2fft.transforms.spherical.inverse( + return s2fft.transforms.spherical.inverse( flm=flm, L=L, L_lower=L_lower, @@ -142,7 +145,6 @@ def inverse( method=method, spmd=spmd, ) - return f.block_until_ready() if isinstance(f, jax.Array) else f if __name__ == "__main__": diff --git a/benchmarks/wigner.py b/benchmarks/wigner.py index af71033b..89655c79 100644 --- a/benchmarks/wigner.py +++ b/benchmarks/wigner.py @@ -1,7 +1,11 @@ """Benchmarks for on-the-fly Wigner-d transforms.""" import numpy as np -from benchmarking import benchmark, parse_args_collect_and_run_benchmarks +from benchmarking import ( + BenchmarkSetup, + benchmark, + parse_args_collect_and_run_benchmarks, +) import s2fft from s2fft.base_transforms import wigner as base_wigner @@ -31,13 +35,13 @@ def setup_forward(method, L, L_lower, N, sampling, reality): ) generate_precomputes = ( generate_precomputes_wigner_jax - if method == "jax" + if "jax" in method else generate_precomputes_wigner ) precomps = generate_precomputes( L, N, sampling, forward=True, reality=reality, L_lower=L_lower ) - return {"f": f, "precomps": precomps}, flmn + return BenchmarkSetup({"f": f, "precomps": precomps}, flmn, "jax" in method) @benchmark( @@ -50,7 +54,7 @@ def setup_forward(method, L, L_lower, N, sampling, reality): reality=REALITY_VALUES, ) def forward(f, precomps, method, L, L_lower, N, sampling, reality): - flmn = s2fft.transforms.wigner.forward( + return s2fft.transforms.wigner.forward( f=f, L=L, N=N, @@ -60,9 +64,6 @@ def forward(f, precomps, method, L, L_lower, N, sampling, reality): precomps=precomps, L_lower=L_lower, ) - if method == "jax": - flmn.block_until_ready() - return flmn def setup_inverse(method, L, L_lower, N, sampling, reality): @@ -70,13 +71,13 @@ def setup_inverse(method, L, L_lower, N, sampling, reality): flmn = s2fft.utils.signal_generator.generate_flmn(rng, L, N, reality=reality) generate_precomputes = ( generate_precomputes_wigner_jax - if method == "jax" + if "jax" in method else generate_precomputes_wigner ) precomps = generate_precomputes( L, N, sampling, forward=False, reality=reality, L_lower=L_lower ) - return {"flmn": flmn, "precomps": precomps}, None + return BenchmarkSetup({"flmn": flmn, "precomps": precomps}, None, "jax" in method) @benchmark( @@ -89,7 +90,7 @@ def setup_inverse(method, L, L_lower, N, sampling, reality): reality=REALITY_VALUES, ) def inverse(flmn, precomps, method, L, L_lower, N, sampling, reality): - f = s2fft.transforms.wigner.inverse( + return s2fft.transforms.wigner.inverse( flmn=flmn, L=L, N=N, @@ -99,9 +100,6 @@ def inverse(flmn, precomps, method, L, L_lower, N, sampling, reality): precomps=precomps, L_lower=L_lower, ) - if method == "jax": - f.block_until_ready() - return f if __name__ == "__main__": From 17918286ba2a08d6320763c8664d0204e103b391 Mon Sep 17 00:00:00 2001 From: Matt Graham Date: Thu, 23 Jan 2025 16:18:46 +0000 Subject: [PATCH 13/22] Normalize benchmark results key naming --- benchmarks/benchmarking.py | 29 +++++++++++++++-------------- 1 file changed, 15 insertions(+), 14 deletions(-) diff --git a/benchmarks/benchmarking.py b/benchmarks/benchmarking.py index 9f33f93f..e1f62d6d 100644 --- a/benchmarks/benchmarking.py +++ b/benchmarks/benchmarking.py @@ -120,13 +120,12 @@ def _get_cpu_info(): return None -def _get_gpu_memory_mebibytes(device): - """Try to get GPU memory available in mebibytes (MiB).""" +def _get_gpu_memory_in_bytes(device): + """Try to get GPU memory available in bytes.""" memory_stats = device.memory_stats() if memory_stats is None: return None - bytes_limit = memory_stats.get("bytes_limit") - return bytes_limit // 2**20 if bytes_limit is not None else None + return memory_stats.get("bytes_limit") def _get_gpu_info(): @@ -137,7 +136,7 @@ def _get_gpu_info(): return [ { "kind": d.device_kind, - "memory_available / MiB": _get_gpu_memory_mebibytes(d), + "memory_available_in_bytes": _get_gpu_memory_in_bytes(d), } for d in jax.devices() if d.platform == "gpu" @@ -228,15 +227,15 @@ def _format_results_entry(results_entry): if len(results_entry["parameters"]) != 0 else " " ) - + f"min(time): {min(results_entry['times / s']):>#7.2g}s, " - + f"max(time): {max(results_entry['times / s']):>#7.2g}s" + + f"min(run times): {min(results_entry['run_times_in_seconds']):>#7.2g}s, " + + f"max(run times): {max(results_entry['run_times_in_seconds']):>#7.2g}s" + ( - f", peak memory: {results_entry['peak_memory / MiB']:>#7.2g}MiB" - if "peak_memory / MiB" in results_entry + f", peak memory: {results_entry['peak_memory_in_bytes']:>#7.2g}B" + if "peak_memory_in_bytes" in results_entry else "" ) + ( - f", max(abs(error)): {results_entry['error']:>#7.2g}" + f", max(abs(error)): {results_entry['max_abs_error']:>#7.2g}" if "error" in results_entry else "" ) @@ -441,19 +440,21 @@ def run_benchmarks( # computing numerical error output = benchmark_function() if reference_output is not None and output is not None: - results_entry["error"] = abs(reference_output - output).max() + results_entry["max_abs_error"] = abs( + reference_output - output + ).max() run_times = [ time / number_runs for time in timeit.repeat( benchmark_function, number=number_runs, repeat=number_repeats ) ] - results_entry["times / s"] = run_times + results_entry["run_times_in_seconds"] = run_times if MEMORY_PROFILER_AVAILABLE: - results_entry["peak_memory / MiB"] = measure_peak_memory_usage( + results_entry["peak_memory_in_bytes"] = measure_peak_memory_usage( benchmark_function, interval=min(run_times) / 20, - ) + ) * (2**20) results[benchmark.__name__].append(results_entry) if print_results: print(_format_results_entry(results_entry)) From 651c14f86099676fde14a25923cc1ba440292bab Mon Sep 17 00:00:00 2001 From: Matt Graham Date: Thu, 23 Jan 2025 16:48:21 +0000 Subject: [PATCH 14/22] Use tracemalloc instead of memory_profiler for measuring CPU memory --- benchmarks/README.md | 7 ++--- benchmarks/benchmarking.py | 62 ++++++++++++++++---------------------- 2 files changed, 28 insertions(+), 41 deletions(-) diff --git a/benchmarks/README.md b/benchmarks/README.md index 2ef7000f..bb50c0be 100644 --- a/benchmarks/README.md +++ b/benchmarks/README.md @@ -1,13 +1,10 @@ # Benchmarks for `s2fft` -Scripts for benchmarking `s2fft` with `timeit` (and optionally `memory_profiler`). +Scripts for benchmarking `s2fft` transforms. Measures time to compute transforms for grids of parameters settings, optionally outputting the results to a JSON file to allow comparing performance over versions -and/or systems. -If the [`memory_profiler` package](https://github.com/pythonprofilers/memory_profiler) -is installed an estimate of the peak (main) memory usage of the benchmarked functions -will also be recorded. +and/or systems. If the [`py-cpuinfo` package](https://pypi.org/project/py-cpuinfo/) is installed additional information about CPU of system benchmarks are run on will be recorded in JSON output. diff --git a/benchmarks/benchmarking.py b/benchmarks/benchmarking.py index e1f62d6d..7eb6cd84 100644 --- a/benchmarks/benchmarking.py +++ b/benchmarks/benchmarking.py @@ -75,11 +75,13 @@ def mean(x, n): from __future__ import annotations import argparse +import contextlib import datetime import inspect import json import platform import timeit +import tracemalloc from ast import literal_eval from functools import partial from importlib.metadata import PackageNotFoundError, version @@ -90,13 +92,6 @@ def mean(x, n): import jax import numpy as np -try: - import memory_profiler - - MEMORY_PROFILER_AVAILABLE = True -except ImportError: - MEMORY_PROFILER_AVAILABLE = False - class SkipBenchmarkException(Exception): """Exception to be raised to skip benchmark for some parameter set.""" @@ -230,8 +225,8 @@ def _format_results_entry(results_entry): + f"min(run times): {min(results_entry['run_times_in_seconds']):>#7.2g}s, " + f"max(run times): {max(results_entry['run_times_in_seconds']):>#7.2g}s" + ( - f", peak memory: {results_entry['peak_memory_in_bytes']:>#7.2g}B" - if "peak_memory_in_bytes" in results_entry + f", peak memory: {results_entry['traced_memory_peak_in_bytes']:>#7.2g}B" + if "traced_memory_peak_in_bytes" in results_entry else "" ) + ( @@ -339,31 +334,27 @@ def collect_benchmarks(module, benchmark_names): ] -def measure_peak_memory_usage(benchmark_function, interval): - """Measure peak memory usage in mebibytes (MiB) of a function using memory_profiler. +@contextlib.contextmanager +def trace_memory_allocations(n_frames=1): + """Context manager for tracing memory allocations in managed with block. + + Returns a thunk (zero-argument function) which can be called on exit from with block + to get tuple of current size and peak size of memory blocks traced in bytes. Args: - benchmark_function: Function to benchmark peak memory usage of. - interval: Interval in seconds at which memory measurements are collected. + n_frames: Limit on depth of frames to trace memory allocations in. Returns: - Peak memory usage measure in mebibytes (MiB). + A thunk (zero-argument function) which can be called on exit from with block to + get tuple of current size and peak size of memory blocks traced in bytes. """ - baseline_memory = memory_profiler.memory_usage( - lambda: None, - max_usage=True, - include_children=True, - ) - return max( - memory_profiler.memory_usage( - benchmark_function, - interval=interval, - max_usage=True, - include_children=True, - ) - - baseline_memory, - 0, - ) + tracemalloc.start(n_frames) + current_size, peak_size = None, None + try: + yield lambda: (current_size, peak_size) + current_size, peak_size = tracemalloc.get_traced_memory() + finally: + tracemalloc.stop() def _compile_jax_benchmark_and_analyse(benchmark_function, results_entry): @@ -437,8 +428,12 @@ def run_benchmarks( benchmark_function, results_entry ) # Run benchmark once without timing to record output for potentially - # computing numerical error - output = benchmark_function() + # computing numerical error and trace memory usage + with trace_memory_allocations() as traced_memory: + output = benchmark_function() + current_size, peak_size = traced_memory() + results_entry["traced_memory_final_in_bytes"] = current_size + results_entry["traced_memory_peak_in_bytes"] = peak_size if reference_output is not None and output is not None: results_entry["max_abs_error"] = abs( reference_output - output @@ -450,11 +445,6 @@ def run_benchmarks( ) ] results_entry["run_times_in_seconds"] = run_times - if MEMORY_PROFILER_AVAILABLE: - results_entry["peak_memory_in_bytes"] = measure_peak_memory_usage( - benchmark_function, - interval=min(run_times) / 20, - ) * (2**20) results[benchmark.__name__].append(results_entry) if print_results: print(_format_results_entry(results_entry)) From 1a6cce075529f2fbd2f185f249c63d61ba44422e Mon Sep 17 00:00:00 2001 From: Matt Graham Date: Thu, 23 Jan 2025 17:06:40 +0000 Subject: [PATCH 15/22] Add type hints to benchmarking module --- benchmarks/benchmarking.py | 73 ++++++++++++++++++++++---------------- 1 file changed, 43 insertions(+), 30 deletions(-) diff --git a/benchmarks/benchmarking.py b/benchmarks/benchmarking.py index 7eb6cd84..628e7889 100644 --- a/benchmarks/benchmarking.py +++ b/benchmarks/benchmarking.py @@ -83,10 +83,12 @@ def mean(x, n): import timeit import tracemalloc from ast import literal_eval +from collections.abc import Callable, Iterable from functools import partial from importlib.metadata import PackageNotFoundError, version from itertools import product from pathlib import Path +from types import ModuleType from typing import Any, NamedTuple import jax @@ -97,7 +99,7 @@ class SkipBenchmarkException(Exception): """Exception to be raised to skip benchmark for some parameter set.""" -def _get_version_or_none(package_name): +def _get_version_or_none(package_name: str) -> str | None: """Get installed version of package or `None` if package not found.""" try: return version(package_name) @@ -115,7 +117,7 @@ def _get_cpu_info(): return None -def _get_gpu_memory_in_bytes(device): +def _get_gpu_memory_in_bytes(device: jax.Device) -> int | None: """Try to get GPU memory available in bytes.""" memory_stats = device.memory_stats() if memory_stats is None: @@ -123,7 +125,7 @@ def _get_gpu_memory_in_bytes(device): return memory_stats.get("bytes_limit") -def _get_gpu_info(): +def _get_gpu_info() -> dict[str, str | int]: """Get details of GPU devices available from JAX or None if JAX not available.""" try: import jax @@ -140,7 +142,7 @@ def _get_gpu_info(): return None -def _get_cuda_info(): +def _get_cuda_info() -> dict[str, str]: """Try to get information on versions of CUDA libraries.""" try: from jax._src.lib import cuda_versions @@ -159,7 +161,7 @@ def _get_cuda_info(): return None -def skip(message): +def skip(message: str) -> None: """Skip benchmark for a particular parameter set with explanatory message. Args: @@ -176,17 +178,22 @@ class BenchmarkSetup(NamedTuple): jit_benchmark: bool = False -def benchmark(setup=None, **parameters): +def benchmark( + setup: Callable[..., BenchmarkSetup] | None = None, **parameters +) -> Callable: """Decorator for defining a function to be benchmark. Args: setup: Function performing any necessary set up for benchmark, and the resource usage of which will not be tracked in benchmarking. The function should - return a 2-tuple, with first entry a dictionary of values to pass to the - benchmark as keyword arguments, corresponding to any precomputed values, - and the second entry optionally a reference value specifying the expected - 'true' numerical output of the behchmarked function to allow computing - numerical error, or `None` if there is no relevant reference value. + return an instance of `BenchmarkSetup` named tuple, with first entry a + dictionary of values to pass to the benchmark as keyword arguments, + corresponding to any precomputed values, the second entry optionally a + reference value specifying the expected 'true' numerical output of the + behchmarked function to allow computing numerical error, or `None` if there + is no relevant reference value and third entry a boolean flag indicating + whether to use JAX's just-in-time compilation transform to benchmark + function. Kwargs: Parameter names and associated lists of values over which to run benchmark. @@ -209,12 +216,12 @@ def decorator(function): return decorator -def _parameters_string(parameters): +def _parameters_string(parameters: dict) -> str: """Format parameter values as string for printing benchmark results.""" return "(" + ", ".join(f"{name}: {val}" for name, val in parameters.items()) + ")" -def _format_results_entry(results_entry): +def _format_results_entry(results_entry: dict) -> str: """Format benchmark results entry as a string for printing.""" return ( ( @@ -243,12 +250,12 @@ def _format_results_entry(results_entry): ) -def _dict_product(dicts): +def _dict_product(dicts: dict[str, Iterable[Any]]) -> Iterable[dict[str, Any]]: """Generator corresponding to Cartesian product of dictionaries.""" return (dict(zip(dicts.keys(), values)) for values in product(*dicts.values())) -def _parse_value(value): +def _parse_value(value: str) -> Any: """Parse a value passed at command line as a Python literal or string as fallback""" try: return literal_eval(value) @@ -256,7 +263,7 @@ def _parse_value(value): return str(value) -def _parse_parameter_overrides(parameter_overrides): +def _parse_parameter_overrides(parameter_overrides: list[str]) -> dict[str, Any]: """Parse any parameter override values passed as command line arguments""" return ( { @@ -268,7 +275,7 @@ def _parse_parameter_overrides(parameter_overrides): ) -def _parse_cli_arguments(description): +def _parse_cli_arguments(description: str) -> argparse.Namespace: """Parse command line arguments passed for controlling benchmark runs""" parser = argparse.ArgumentParser( description=description, formatter_class=argparse.ArgumentDefaultsHelpFormatter @@ -307,7 +314,7 @@ def _parse_cli_arguments(description): return parser.parse_args() -def _is_benchmark(object): +def _is_benchmark(object: Any) -> bool: """Predicate for testing whether an object is a benchmark function or not.""" return ( inspect.isfunction(object) @@ -316,7 +323,9 @@ def _is_benchmark(object): ) -def collect_benchmarks(module, benchmark_names): +def collect_benchmarks( + module: ModuleType, benchmark_names: list[str] +) -> list[Callable]: """Collect all benchmark functions from a module. Args: @@ -335,7 +344,7 @@ def collect_benchmarks(module, benchmark_names): @contextlib.contextmanager -def trace_memory_allocations(n_frames=1): +def trace_memory_allocations(n_frames: int = 1) -> Callable[[], tuple[int, int]]: """Context manager for tracing memory allocations in managed with block. Returns a thunk (zero-argument function) which can be called on exit from with block @@ -357,7 +366,9 @@ def trace_memory_allocations(n_frames=1): tracemalloc.stop() -def _compile_jax_benchmark_and_analyse(benchmark_function, results_entry): +def _compile_jax_benchmark_and_analyse( + benchmark_function: Callable, results_entry: dict +) -> Callable: """Compile a JAX benchmark function and extract cost estimates if available.""" compiled_benchmark_function = jax.jit(benchmark_function).lower().compile() cost_analysis = compiled_benchmark_function.cost_analysis() @@ -385,12 +396,12 @@ def _compile_jax_benchmark_and_analyse(benchmark_function, results_entry): def run_benchmarks( - benchmarks, - number_runs, - number_repeats, - print_results=True, - parameter_overrides=None, -): + benchmarks: list[Callable], + number_runs: int, + number_repeats: int, + print_results: bool = True, + parameter_overrides: dict[str, Any] | None = None, +) -> dict[str, Any]: """Run a set of benchmarks. Args: @@ -456,7 +467,7 @@ def run_benchmarks( return results -def get_system_info(): +def get_system_info() -> dict[str, Any]: """Get dictionary of metadata about system. Returns: @@ -482,7 +493,9 @@ def get_system_info(): } -def write_json_results_file(output_file, results, benchmark_module): +def write_json_results_file( + output_file: Path, results: dict[str, Any], benchmark_module: str +) -> None: """Write benchmark results and system information to a file in JSON format. Args: @@ -500,7 +513,7 @@ def write_json_results_file(output_file, results, benchmark_module): json.dump(output, f, indent=True) -def parse_args_collect_and_run_benchmarks(module=None): +def parse_args_collect_and_run_benchmarks(module: ModuleType | None = None) -> None: """Collect and run all benchmarks in a module and parse command line arguments. Args: From 38dfbf4c43c774fcdf5a58d9ca87fa6defd1e5cb Mon Sep 17 00:00:00 2001 From: Matt Graham Date: Fri, 24 Jan 2025 09:17:09 +0000 Subject: [PATCH 16/22] Fix typo in docstring --- benchmarks/benchmarking.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/benchmarks/benchmarking.py b/benchmarks/benchmarking.py index 628e7889..fa372df9 100644 --- a/benchmarks/benchmarking.py +++ b/benchmarks/benchmarking.py @@ -190,7 +190,7 @@ def benchmark( dictionary of values to pass to the benchmark as keyword arguments, corresponding to any precomputed values, the second entry optionally a reference value specifying the expected 'true' numerical output of the - behchmarked function to allow computing numerical error, or `None` if there + benchmarked function to allow computing numerical error, or `None` if there is no relevant reference value and third entry a boolean flag indicating whether to use JAX's just-in-time compilation transform to benchmark function. From ff0baf0a1b62f832e881d44d5f4d56d079765d88 Mon Sep 17 00:00:00 2001 From: Matt Graham Date: Fri, 24 Jan 2025 09:17:33 +0000 Subject: [PATCH 17/22] Make robust to change in cost_analysis return type in recent JAX versions --- benchmarks/benchmarking.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/benchmarks/benchmarking.py b/benchmarks/benchmarking.py index fa372df9..1b5f0b2f 100644 --- a/benchmarks/benchmarking.py +++ b/benchmarks/benchmarking.py @@ -372,10 +372,12 @@ def _compile_jax_benchmark_and_analyse( """Compile a JAX benchmark function and extract cost estimates if available.""" compiled_benchmark_function = jax.jit(benchmark_function).lower().compile() cost_analysis = compiled_benchmark_function.cost_analysis() - if cost_analysis is not None and isinstance(cost_analysis, list): + if cost_analysis is not None: + if isinstance(cost_analysis, list): + cost_analysis = cost_analysis[0] results_entry["cost_analysis"] = { - "flops": cost_analysis[0].get("flops"), - "bytes_accessed": cost_analysis[0].get("bytes accessed"), + "flops": cost_analysis.get("flops"), + "bytes_accessed": cost_analysis.get("bytes accessed"), } memory_analysis = compiled_benchmark_function.memory_analysis() if memory_analysis is not None: From 6443203955a43cf7562c7fc4d4a976109a2128f5 Mon Sep 17 00:00:00 2001 From: Matt Graham Date: Mon, 27 Jan 2025 14:50:53 +0000 Subject: [PATCH 18/22] Correct key name for printing error --- benchmarks/benchmarking.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/benchmarks/benchmarking.py b/benchmarks/benchmarking.py index 1b5f0b2f..85d282d6 100644 --- a/benchmarks/benchmarking.py +++ b/benchmarks/benchmarking.py @@ -238,7 +238,7 @@ def _format_results_entry(results_entry: dict) -> str: ) + ( f", max(abs(error)): {results_entry['max_abs_error']:>#7.2g}" - if "error" in results_entry + if "max_abs_error" in results_entry else "" ) + ( From c663e88366fa7a265153139e3851b5cb020d047f Mon Sep 17 00:00:00 2001 From: Matt Graham Date: Mon, 27 Jan 2025 14:51:08 +0000 Subject: [PATCH 19/22] Record mean abs error as well as max --- benchmarks/benchmarking.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/benchmarks/benchmarking.py b/benchmarks/benchmarking.py index 85d282d6..3648cd54 100644 --- a/benchmarks/benchmarking.py +++ b/benchmarks/benchmarking.py @@ -451,6 +451,9 @@ def run_benchmarks( results_entry["max_abs_error"] = abs( reference_output - output ).max() + results_entry["mean_abs_error"] = abs( + reference_output - output + ).mean() run_times = [ time / number_runs for time in timeit.repeat( From b57cdc60232f50a5120c0e236319d90d8b5feaf2 Mon Sep 17 00:00:00 2001 From: Matt Graham Date: Tue, 28 Jan 2025 16:56:18 +0000 Subject: [PATCH 20/22] Ensure error values of float type --- benchmarks/benchmarking.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/benchmarks/benchmarking.py b/benchmarks/benchmarking.py index 3648cd54..bbf4a312 100644 --- a/benchmarks/benchmarking.py +++ b/benchmarks/benchmarking.py @@ -448,12 +448,12 @@ def run_benchmarks( results_entry["traced_memory_final_in_bytes"] = current_size results_entry["traced_memory_peak_in_bytes"] = peak_size if reference_output is not None and output is not None: - results_entry["max_abs_error"] = abs( - reference_output - output - ).max() - results_entry["mean_abs_error"] = abs( - reference_output - output - ).mean() + results_entry["max_abs_error"] = float( + abs(reference_output - output).max() + ) + results_entry["mean_abs_error"] = float( + abs(reference_output - output).mean() + ) run_times = [ time / number_runs for time in timeit.repeat( From c30d0b67018841d37b87486971f0f4179cefb825 Mon Sep 17 00:00:00 2001 From: Matt Graham Date: Wed, 29 Jan 2025 18:25:09 +0000 Subject: [PATCH 21/22] Add benchmark plotting module --- benchmarks/plotting.py | 209 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 209 insertions(+) create mode 100644 benchmarks/plotting.py diff --git a/benchmarks/plotting.py b/benchmarks/plotting.py new file mode 100644 index 00000000..f13989b2 --- /dev/null +++ b/benchmarks/plotting.py @@ -0,0 +1,209 @@ +"""Utilities for plotting benchmark results.""" + +import argparse +import json +from pathlib import Path + +import matplotlib.pyplot as plt +import numpy as np + + +def _set_axis_properties( + ax: plt.Axes, + parameter_values: np.ndarray, + parameter_label: str, + measurement_label: str, +) -> None: + ax.set( + xlabel=parameter_label, + ylabel=measurement_label, + xscale="log", + yscale="log", + xticks=parameter_values, + xticklabels=parameter_values, + ) + ax.minorticks_off() + + +def _plot_scaling_guide( + ax: plt.Axes, + parameter_symbol: str, + parameter_values: np.ndarray, + measurement_values: np.ndarray, + order: int, +) -> None: + n = np.argsort(parameter_values)[len(parameter_values) // 2] + coefficient = measurement_values[n] / parameter_values[n] ** order + ax.plot( + parameter_values, + coefficient * parameter_values**order, + "k:", + label=f"$\\mathcal{{O}}({parameter_symbol}^{order})$", + ) + + +def plot_times( + ax: plt.Axes, parameter_symbol: str, parameter_values: np.ndarray, results: dict +) -> None: + min_times = np.array([min(r["run_times_in_seconds"]) for r in results]) + mid_times = np.array([np.median(r["run_times_in_seconds"]) for r in results]) + max_times = np.array([max(r["run_times_in_seconds"]) for r in results]) + ax.plot(parameter_values, mid_times, label="Measured") + ax.fill_between(parameter_values, min_times, max_times, alpha=0.5) + _plot_scaling_guide(ax, parameter_symbol, parameter_values, mid_times, 3) + ax.legend() + + +def plot_flops( + ax: plt.Axes, parameter_symbol: str, parameter_values: np.ndarray, results: dict +) -> None: + flops = np.array([r["cost_analysis"]["flops"] for r in results]) + ax.plot(parameter_values, flops, label="Measured") + _plot_scaling_guide(ax, parameter_symbol, parameter_values, flops, 2) + ax.legend() + + +def plot_error( + ax: plt.Axes, parameter_symbol: str, parameter_values: np.ndarray, results: dict +) -> None: + max_abs_errors = np.array([r["max_abs_error"] for r in results]) + mean_abs_errors = np.array([r["mean_abs_error"] for r in results]) + ax.plot(parameter_values, max_abs_errors, label="max(abs(error))") + ax.plot(parameter_values, mean_abs_errors, label="mean(abs(error))") + _plot_scaling_guide( + ax, + parameter_symbol, + parameter_values, + (max_abs_errors + mean_abs_errors) / 2, + 2, + ) + ax.legend() + + +def plot_memory( + ax: plt.Axes, parameter_symbol: str, parameter_values: np.ndarray, results: dict +) -> None: + bytes_accessed = np.array([r["cost_analysis"]["bytes_accessed"] for r in results]) + temp_size_in_bytes = np.array( + [r["memory_analysis"]["temp_size_in_bytes"] for r in results] + ) + output_size_in_bytes = np.array( + [r["memory_analysis"]["output_size_in_bytes"] for r in results] + ) + generated_code_size_in_bytes = np.array( + [r["memory_analysis"]["generated_code_size_in_bytes"] for r in results] + ) + ax.plot(parameter_values, bytes_accessed, label="Accesses") + ax.plot(parameter_values, temp_size_in_bytes, label="Temporary allocations") + ax.plot(parameter_values, output_size_in_bytes, label="Output size") + ax.plot(parameter_values, generated_code_size_in_bytes, label="Generated code size") + _plot_scaling_guide( + ax, + parameter_symbol, + parameter_values, + (bytes_accessed + output_size_in_bytes) / 2, + 2, + ) + ax.legend() + + +_measurement_plot_functions_and_labels = { + "times": (plot_times, "Run time / s"), + "flops": (plot_flops, "Floating point operations"), + "memory": (plot_memory, "Memory / B"), + "error": (plot_error, "Numerical error"), +} + + +def plot_results_against_bandlimit( + benchmark_results_path: str | Path, + functions: tuple[str] = ("forward", "inverse"), + measurements: tuple[str] = ("times", "flops", "memory", "error"), + axis_size: float = 3.0, + fig_dpi: int = 100, +) -> tuple[plt.Figure, plt.Axes]: + benchmark_results_path = Path(benchmark_results_path) + with benchmark_results_path.open("r") as f: + benchmark_results = json.load(f) + n_functions = len(functions) + n_measurements = len(measurements) + fig, axes = plt.subplots( + n_functions, + n_measurements, + figsize=(axis_size * n_measurements, axis_size * n_functions), + dpi=fig_dpi, + squeeze=False, + ) + for axes_row, function in zip(axes, functions): + results = benchmark_results["results"][function] + l_values = np.array([r["parameters"]["L"] for r in results]) + for ax, measurement in zip(axes_row, measurements): + plot_function, label = _measurement_plot_functions_and_labels[measurement] + try: + plot_function(ax, "L", l_values, results) + ax.set(title=function) + except KeyError: + ax.axis("off") + _set_axis_properties(ax, l_values, "Bandlimit $L$", label) + return fig, ax + + +def _parse_cli_arguments() -> argparse.Namespace: + """Parse rguments passed for plotting command line interface""" + parser = argparse.ArgumentParser( + description="Generate plot from benchmark results file.", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument( + "-results-path", + type=Path, + help="Path to JSON file containing benchmark results to plot.", + ) + parser.add_argument( + "-output-path", + type=Path, + help="Path to write figure to.", + ) + parser.add_argument( + "-functions", + nargs="+", + help="Names of functions to plot. forward and inverse are plotted if omitted.", + ) + parser.add_argument( + "-measurements", + nargs="+", + help="Names of measurements to plot. All functions are plotted if omitted.", + ) + parser.add_argument( + "-axis-size", type=float, default=5.0, help="Size of each plot axis in inches." + ) + parser.add_argument( + "-dpi", type=int, default=100, help="Figure resolution in dots per inch." + ) + parser.add_argument( + "-title", type=str, help="Title for figure. No title added if omitted." + ) + return parser.parse_args() + + +if __name__ == "__main__": + args = _parse_cli_arguments() + functions = ( + ("forward", "inverse") if args.functions is None else tuple(args.functions) + ) + measurements = ( + ("times", "flops", "memory", "error") + if args.measurements is None + else tuple(args.measurements) + ) + fig, _ = plot_results_against_bandlimit( + args.results_path, + functions=functions, + measurements=measurements, + axis_size=args.axis_size, + fig_dpi=args.dpi, + ) + if args.title is not None: + fig.suptitle(args.title) + fig.tight_layout() + fig.savefig(args.output_path) From f2562ab20dc4428d80744e90d2d3cdd224a8d17c Mon Sep 17 00:00:00 2001 From: Matt Graham Date: Mon, 3 Feb 2025 15:34:41 +0000 Subject: [PATCH 22/22] Force floats when plotting scaling guides --- benchmarks/plotting.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/benchmarks/plotting.py b/benchmarks/plotting.py index f13989b2..75c9aa33 100644 --- a/benchmarks/plotting.py +++ b/benchmarks/plotting.py @@ -33,10 +33,10 @@ def _plot_scaling_guide( order: int, ) -> None: n = np.argsort(parameter_values)[len(parameter_values) // 2] - coefficient = measurement_values[n] / parameter_values[n] ** order + coefficient = measurement_values[n] / float(parameter_values[n]) ** order ax.plot( parameter_values, - coefficient * parameter_values**order, + coefficient * parameter_values.astype(float) ** order, "k:", label=f"$\\mathcal{{O}}({parameter_symbol}^{order})$", )