Skip to content

Commit c30d0b6

Browse files
committed
Add benchmark plotting module
1 parent b57cdc6 commit c30d0b6

File tree

1 file changed

+209
-0
lines changed

1 file changed

+209
-0
lines changed

benchmarks/plotting.py

Lines changed: 209 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,209 @@
1+
"""Utilities for plotting benchmark results."""
2+
3+
import argparse
4+
import json
5+
from pathlib import Path
6+
7+
import matplotlib.pyplot as plt
8+
import numpy as np
9+
10+
11+
def _set_axis_properties(
12+
ax: plt.Axes,
13+
parameter_values: np.ndarray,
14+
parameter_label: str,
15+
measurement_label: str,
16+
) -> None:
17+
ax.set(
18+
xlabel=parameter_label,
19+
ylabel=measurement_label,
20+
xscale="log",
21+
yscale="log",
22+
xticks=parameter_values,
23+
xticklabels=parameter_values,
24+
)
25+
ax.minorticks_off()
26+
27+
28+
def _plot_scaling_guide(
29+
ax: plt.Axes,
30+
parameter_symbol: str,
31+
parameter_values: np.ndarray,
32+
measurement_values: np.ndarray,
33+
order: int,
34+
) -> None:
35+
n = np.argsort(parameter_values)[len(parameter_values) // 2]
36+
coefficient = measurement_values[n] / parameter_values[n] ** order
37+
ax.plot(
38+
parameter_values,
39+
coefficient * parameter_values**order,
40+
"k:",
41+
label=f"$\\mathcal{{O}}({parameter_symbol}^{order})$",
42+
)
43+
44+
45+
def plot_times(
46+
ax: plt.Axes, parameter_symbol: str, parameter_values: np.ndarray, results: dict
47+
) -> None:
48+
min_times = np.array([min(r["run_times_in_seconds"]) for r in results])
49+
mid_times = np.array([np.median(r["run_times_in_seconds"]) for r in results])
50+
max_times = np.array([max(r["run_times_in_seconds"]) for r in results])
51+
ax.plot(parameter_values, mid_times, label="Measured")
52+
ax.fill_between(parameter_values, min_times, max_times, alpha=0.5)
53+
_plot_scaling_guide(ax, parameter_symbol, parameter_values, mid_times, 3)
54+
ax.legend()
55+
56+
57+
def plot_flops(
58+
ax: plt.Axes, parameter_symbol: str, parameter_values: np.ndarray, results: dict
59+
) -> None:
60+
flops = np.array([r["cost_analysis"]["flops"] for r in results])
61+
ax.plot(parameter_values, flops, label="Measured")
62+
_plot_scaling_guide(ax, parameter_symbol, parameter_values, flops, 2)
63+
ax.legend()
64+
65+
66+
def plot_error(
67+
ax: plt.Axes, parameter_symbol: str, parameter_values: np.ndarray, results: dict
68+
) -> None:
69+
max_abs_errors = np.array([r["max_abs_error"] for r in results])
70+
mean_abs_errors = np.array([r["mean_abs_error"] for r in results])
71+
ax.plot(parameter_values, max_abs_errors, label="max(abs(error))")
72+
ax.plot(parameter_values, mean_abs_errors, label="mean(abs(error))")
73+
_plot_scaling_guide(
74+
ax,
75+
parameter_symbol,
76+
parameter_values,
77+
(max_abs_errors + mean_abs_errors) / 2,
78+
2,
79+
)
80+
ax.legend()
81+
82+
83+
def plot_memory(
84+
ax: plt.Axes, parameter_symbol: str, parameter_values: np.ndarray, results: dict
85+
) -> None:
86+
bytes_accessed = np.array([r["cost_analysis"]["bytes_accessed"] for r in results])
87+
temp_size_in_bytes = np.array(
88+
[r["memory_analysis"]["temp_size_in_bytes"] for r in results]
89+
)
90+
output_size_in_bytes = np.array(
91+
[r["memory_analysis"]["output_size_in_bytes"] for r in results]
92+
)
93+
generated_code_size_in_bytes = np.array(
94+
[r["memory_analysis"]["generated_code_size_in_bytes"] for r in results]
95+
)
96+
ax.plot(parameter_values, bytes_accessed, label="Accesses")
97+
ax.plot(parameter_values, temp_size_in_bytes, label="Temporary allocations")
98+
ax.plot(parameter_values, output_size_in_bytes, label="Output size")
99+
ax.plot(parameter_values, generated_code_size_in_bytes, label="Generated code size")
100+
_plot_scaling_guide(
101+
ax,
102+
parameter_symbol,
103+
parameter_values,
104+
(bytes_accessed + output_size_in_bytes) / 2,
105+
2,
106+
)
107+
ax.legend()
108+
109+
110+
_measurement_plot_functions_and_labels = {
111+
"times": (plot_times, "Run time / s"),
112+
"flops": (plot_flops, "Floating point operations"),
113+
"memory": (plot_memory, "Memory / B"),
114+
"error": (plot_error, "Numerical error"),
115+
}
116+
117+
118+
def plot_results_against_bandlimit(
119+
benchmark_results_path: str | Path,
120+
functions: tuple[str] = ("forward", "inverse"),
121+
measurements: tuple[str] = ("times", "flops", "memory", "error"),
122+
axis_size: float = 3.0,
123+
fig_dpi: int = 100,
124+
) -> tuple[plt.Figure, plt.Axes]:
125+
benchmark_results_path = Path(benchmark_results_path)
126+
with benchmark_results_path.open("r") as f:
127+
benchmark_results = json.load(f)
128+
n_functions = len(functions)
129+
n_measurements = len(measurements)
130+
fig, axes = plt.subplots(
131+
n_functions,
132+
n_measurements,
133+
figsize=(axis_size * n_measurements, axis_size * n_functions),
134+
dpi=fig_dpi,
135+
squeeze=False,
136+
)
137+
for axes_row, function in zip(axes, functions):
138+
results = benchmark_results["results"][function]
139+
l_values = np.array([r["parameters"]["L"] for r in results])
140+
for ax, measurement in zip(axes_row, measurements):
141+
plot_function, label = _measurement_plot_functions_and_labels[measurement]
142+
try:
143+
plot_function(ax, "L", l_values, results)
144+
ax.set(title=function)
145+
except KeyError:
146+
ax.axis("off")
147+
_set_axis_properties(ax, l_values, "Bandlimit $L$", label)
148+
return fig, ax
149+
150+
151+
def _parse_cli_arguments() -> argparse.Namespace:
152+
"""Parse rguments passed for plotting command line interface"""
153+
parser = argparse.ArgumentParser(
154+
description="Generate plot from benchmark results file.",
155+
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
156+
)
157+
parser.add_argument(
158+
"-results-path",
159+
type=Path,
160+
help="Path to JSON file containing benchmark results to plot.",
161+
)
162+
parser.add_argument(
163+
"-output-path",
164+
type=Path,
165+
help="Path to write figure to.",
166+
)
167+
parser.add_argument(
168+
"-functions",
169+
nargs="+",
170+
help="Names of functions to plot. forward and inverse are plotted if omitted.",
171+
)
172+
parser.add_argument(
173+
"-measurements",
174+
nargs="+",
175+
help="Names of measurements to plot. All functions are plotted if omitted.",
176+
)
177+
parser.add_argument(
178+
"-axis-size", type=float, default=5.0, help="Size of each plot axis in inches."
179+
)
180+
parser.add_argument(
181+
"-dpi", type=int, default=100, help="Figure resolution in dots per inch."
182+
)
183+
parser.add_argument(
184+
"-title", type=str, help="Title for figure. No title added if omitted."
185+
)
186+
return parser.parse_args()
187+
188+
189+
if __name__ == "__main__":
190+
args = _parse_cli_arguments()
191+
functions = (
192+
("forward", "inverse") if args.functions is None else tuple(args.functions)
193+
)
194+
measurements = (
195+
("times", "flops", "memory", "error")
196+
if args.measurements is None
197+
else tuple(args.measurements)
198+
)
199+
fig, _ = plot_results_against_bandlimit(
200+
args.results_path,
201+
functions=functions,
202+
measurements=measurements,
203+
axis_size=args.axis_size,
204+
fig_dpi=args.dpi,
205+
)
206+
if args.title is not None:
207+
fig.suptitle(args.title)
208+
fig.tight_layout()
209+
fig.savefig(args.output_path)

0 commit comments

Comments
 (0)