Skip to content

Commit 3969110

Browse files
authored
Various fixes for the check_shapes benchmarks. (#33)
* Fixed bug where code actually was run with `check_shapes` which shouldn't have been. * Removed `check_shapes` imports in the `without` code. * Store the `with` and `without` temporary files at separate paths. * Refactored some statistics code into a separate module. * Print estimated overhead when running `run_benchmark`. * Added plot of absolute overhead.
1 parent 5f6c187 commit 3969110

File tree

6 files changed

+137
-81
lines changed

6 files changed

+137
-81
lines changed

.github/workflows/benchmark.yml

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ on:
88

99
env:
1010
penv: './poetryenv poetryenvs'
11-
run: 'run python benchmark'
11+
run: 'run python -m benchmark'
1212
bex: 'benchmark/examples'
1313
res: 'gh-pages/benchmark_results'
1414
plots: 'gh-pages/docs/benchmark_plots/'
@@ -35,15 +35,15 @@ jobs:
3535
${penv}/torch_max install
3636
- name: Run benchmarks
3737
run: |
38-
${penv}/np_max ${run}/run_benchmark.py ${bex}/np_example.py ${res}
39-
${penv}/tf_max ${run}/run_benchmark.py ${bex}/tf_example.py ${res}
40-
${penv}/tf_max ${run}/run_benchmark.py ${bex}/tf_example.py --modifiers=no_compile ${res}
41-
${penv}/jax_max ${run}/run_benchmark.py ${bex}/jax_example.py ${res}
42-
${penv}/jax_max ${run}/run_benchmark.py ${bex}/jax_example.py --modifiers=no_jit ${res}
43-
${penv}/torch_max ${run}/run_benchmark.py ${bex}/torch_example.py ${res}
38+
${penv}/np_max ${run}.run_benchmark ${bex}/np_example.py ${res}
39+
${penv}/tf_max ${run}.run_benchmark ${bex}/tf_example.py ${res}
40+
${penv}/tf_max ${run}.run_benchmark ${bex}/tf_example.py --modifiers=no_compile ${res}
41+
${penv}/jax_max ${run}.run_benchmark ${bex}/jax_example.py ${res}
42+
${penv}/jax_max ${run}.run_benchmark ${bex}/jax_example.py --modifiers=no_jit ${res}
43+
${penv}/torch_max ${run}.run_benchmark ${bex}/torch_example.py ${res}
4444
- name: Plot benchmarks
4545
run: |
46-
${penv}/np_max ${run}/plot_benchmarks.py ${res}
46+
${penv}/np_max ${run}.plot_benchmarks ${res}
4747
mkdir -p ${plots}
4848
mv ${res}/overhead.png ${plots}
4949
- name: Commit new benchmark results

benchmark/README.md

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ Most recent results are shown in our
99
To run a benchmark use:
1010

1111
```bash
12-
python benchmark/run_benchmark.py \
12+
python -m benchmark.run_benchmark \
1313
<path to example script> \
1414
[--modifiers=<other modification to the script>] \
1515
<output_directory>
@@ -18,7 +18,7 @@ python benchmark/run_benchmark.py \
1818
Then plot the results with:
1919

2020
```bash
21-
python benchmark/plot_benchmarks.py <output_directory>
21+
python -m benchmark.plot_benchmarks <output_directory>
2222
```
2323

2424
The plotter will plot all results found in the output directory, so your can run `run_benchmark.py`
@@ -30,13 +30,13 @@ poetry install
3030
./poetryenv -r poetryenvs install
3131

3232
# Run all benchmarks:
33-
./poetryenv poetryenvs/np_max run python benchmark/run_benchmark.py benchmark/examples/np_example.py benchmark_results
34-
./poetryenv poetryenvs/tf_max run python benchmark/run_benchmark.py benchmark/examples/tf_example.py benchmark_results
35-
./poetryenv poetryenvs/tf_max run python benchmark/run_benchmark.py benchmark/examples/tf_example.py --modifiers=no_compile benchmark_results
36-
./poetryenv poetryenvs/jax_max run python benchmark/run_benchmark.py benchmark/examples/jax_example.py benchmark_results
37-
./poetryenv poetryenvs/jax_max run python benchmark/run_benchmark.py benchmark/examples/jax_example.py --modifiers=no_jit benchmark_results
38-
./poetryenv poetryenvs/torch_max run python benchmark/run_benchmark.py benchmark/examples/torch_example.py benchmark_results
33+
./poetryenv poetryenvs/np_max run python -m benchmark.run_benchmark benchmark/examples/np_example.py benchmark_results
34+
./poetryenv poetryenvs/tf_max run python -m benchmark.run_benchmark benchmark/examples/tf_example.py benchmark_results
35+
./poetryenv poetryenvs/tf_max run python -m benchmark.run_benchmark benchmark/examples/tf_example.py --modifiers=no_compile benchmark_results
36+
./poetryenv poetryenvs/jax_max run python -m benchmark.run_benchmark benchmark/examples/jax_example.py benchmark_results
37+
./poetryenv poetryenvs/jax_max run python -m benchmark.run_benchmark benchmark/examples/jax_example.py --modifiers=no_jit benchmark_results
38+
./poetryenv poetryenvs/torch_max run python -m benchmark.run_benchmark benchmark/examples/torch_example.py benchmark_results
3939

4040
# Plot results:
41-
poetry run python benchmark/plot_benchmarks.py benchmark_results
41+
poetry run python -m benchmark.plot_benchmarks benchmark_results
4242
```

benchmark/plot_benchmarks.py

Lines changed: 37 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -13,17 +13,23 @@
1313
# limitations under the License.
1414
import argparse
1515
from pathlib import Path
16+
from typing import Any, Optional
1617

1718
import matplotlib.pyplot as plt
1819
import numpy as np
1920
import pandas as pd
21+
from matplotlib.axes import Axes
22+
23+
from .stats import Stats
24+
25+
NDArray = Any
2026

2127

2228
def plot(output_dir: Path) -> None:
2329
result_dfs = [pd.read_csv(f) for f in output_dir.glob("results_*.csv")]
2430
results_df = pd.concat(result_dfs, axis="index", ignore_index=True)
2531

26-
n_columns = 2
32+
n_columns = 3
2733
n_rows = len(results_df.name.unique())
2834
width = 6 * n_columns
2935
height = 4 * n_rows
@@ -37,74 +43,46 @@ def plot(output_dir: Path) -> None:
3743

3844
for i, (ax_name, ax_df) in enumerate(results_df.groupby("name")):
3945
line_xs = []
40-
line_y_with_means = []
41-
line_y_with_uppers = []
42-
line_y_with_lowers = []
43-
line_y_without_means = []
44-
line_y_without_uppers = []
45-
line_y_without_lowers = []
46-
line_y_overhead_means = []
47-
line_y_overhead_uppers = []
48-
line_y_overhead_lowers = []
46+
line_ys = []
4947

5048
for timestamp, timestamp_df in ax_df.groupby("timestamp"):
51-
by_cs = timestamp_df.groupby("check_shapes")
52-
mean_by_cs = by_cs.time_s.mean()
53-
std_by_cs = by_cs.time_s.std().fillna(0.0)
54-
var_by_cs = by_cs.time_s.var().fillna(0.0)
55-
56-
with_mean = mean_by_cs[True]
57-
with_mean_sq = with_mean ** 2
58-
with_std = std_by_cs[True]
59-
with_var = var_by_cs[True]
60-
without_mean = mean_by_cs[False]
61-
without_mean_sq = without_mean ** 2
62-
without_std = std_by_cs[False]
63-
without_var = var_by_cs[False]
64-
65-
overhead_mean = (with_mean / without_mean) - 1
66-
# https://en.wikipedia.org/wiki/Ratio_distribution#Uncorrelated_noncentral_normal_ratio
67-
overhead_var = (with_mean_sq / without_mean_sq) * (
68-
(with_var / with_mean_sq) + (without_var / without_mean_sq)
69-
)
70-
overhead_std = np.sqrt(overhead_var)
71-
7249
line_xs.append(timestamp)
73-
line_y_with_means.append(with_mean)
74-
line_y_with_uppers.append(with_mean + 1.96 * with_std)
75-
line_y_with_lowers.append(with_mean - 1.96 * with_std)
76-
line_y_without_means.append(without_mean)
77-
line_y_without_uppers.append(without_mean + 1.96 * without_std)
78-
line_y_without_lowers.append(without_mean - 1.96 * without_std)
79-
line_y_overhead_means.append(100 * overhead_mean)
80-
line_y_overhead_uppers.append(100 * (overhead_mean + 1.96 * overhead_std))
81-
line_y_overhead_lowers.append(100 * (overhead_mean - 1.96 * overhead_std))
50+
line_ys.append(Stats.new(timestamp_df))
51+
52+
def plot_mean_and_std(
53+
ax: Axes, prefix: str, *, label: Optional[str] = None, scale: float = 1.0
54+
) -> None:
55+
mean_name = f"{prefix}_mean"
56+
std_name = f"{prefix}_std"
57+
58+
# pylint: disable=cell-var-from-loop
59+
mean: NDArray = np.array([getattr(y, mean_name) for y in line_ys]) * scale
60+
std: NDArray = np.array([getattr(y, std_name) for y in line_ys]) * scale
61+
lower: NDArray = mean - 1.96 * std
62+
upper: NDArray = mean + 1.96 * std
63+
64+
(mean_line,) = ax.plot(line_xs, mean, label=label)
65+
color = mean_line.get_color()
66+
ax.fill_between(line_xs, lower, upper, color=color, alpha=0.3)
67+
68+
ax.set_title(ax_name)
69+
ax.tick_params(axis="x", labelrotation=30)
70+
if np.min(lower) > 0:
71+
ax.set_ylim(bottom=0.0)
8272

8373
ax = axes[i][0]
84-
(mean_line,) = ax.plot(line_xs, line_y_with_means, label="with check_shapes")
85-
color = mean_line.get_color()
86-
ax.fill_between(line_xs, line_y_with_lowers, line_y_with_uppers, color=color, alpha=0.3)
87-
(mean_line,) = ax.plot(line_xs, line_y_without_means, label="without check_shapes")
88-
color = mean_line.get_color()
89-
ax.fill_between(
90-
line_xs, line_y_without_lowers, line_y_without_uppers, color=color, alpha=0.3
91-
)
92-
ax.set_title(ax_name)
74+
plot_mean_and_std(ax, "with", label="with check_shapes")
75+
plot_mean_and_std(ax, "without", label="without check_shapes")
9376
ax.set_ylabel("time / s")
94-
ax.tick_params(axis="x", labelrotation=30)
9577
ax.legend()
9678

9779
ax = axes[i][1]
98-
(mean_line,) = ax.plot(line_xs, line_y_overhead_means)
99-
color = mean_line.get_color()
100-
ax.fill_between(
101-
line_xs, line_y_overhead_lowers, line_y_overhead_uppers, color=color, alpha=0.3
102-
)
103-
ax.set_title(ax_name)
80+
plot_mean_and_std(ax, "abs_overhead")
81+
ax.set_ylabel("overhead / s")
82+
83+
ax = axes[i][2]
84+
plot_mean_and_std(ax, "rel_overhead", scale=100.0)
10485
ax.set_ylabel("% overhead")
105-
if np.min(line_y_overhead_lowers) >= 0:
106-
ax.set_ylim(bottom=0.0)
107-
ax.tick_params(axis="x", labelrotation=30)
10886

10987
fig.tight_layout()
11088
fig.savefig(output_dir / "overhead.png")

benchmark/run_benchmark.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@
2121

2222
import pandas as pd
2323

24+
from .stats import Stats
25+
2426
TIMESTAMP_FORMAT = "%Y%m%d_%H%M%S.%f"
2527

2628

@@ -42,6 +44,8 @@ class Modifier(NamedTuple):
4244
Modifier(r"@inherit_check_shapes", ""),
4345
Modifier(r"@check_shapes\(.*?\)", ""),
4446
Modifier(r"cs\((.*?), \".*?\"\)", r"\1"),
47+
Modifier(r"from check_shapes import \(.*?\)", ""),
48+
Modifier(r"from check_shapes[^(]*?^", ""),
4549
)
4650

4751

@@ -52,12 +56,12 @@ class Modifier(NamedTuple):
5256

5357

5458
def run_modified_script(
55-
script: Path, modifiers: Modifiers, reps: int, keep: bool, output_dir: Path
59+
script: Path, tmp_name: str, modifiers: Modifiers, reps: int, keep: bool, output_dir: Path
5660
) -> Sequence[float]:
57-
modified = output_dir / "tmp.py"
61+
modified = output_dir / f"{tmp_name}.py"
5862
src = script.read_text()
5963
for modifier in modifiers:
60-
src = re.sub(modifier.pattern, modifier.repl, src)
64+
src = re.sub(modifier.pattern, modifier.repl, src, flags=re.MULTILINE + re.DOTALL)
6165
modified.write_text(src)
6266

6367
timings = []
@@ -94,7 +98,7 @@ def run_benchmark(
9498
}
9599

96100
modifiers = tuple(m for ms in modifier_strs for m in _MODIFIERS[ms])
97-
with_timings = run_modified_script(script, modifiers, reps, keep, output_dir)
101+
with_timings = run_modified_script(script, "with", modifiers, reps, keep, output_dir)
98102
with_df = pd.DataFrame(
99103
{
100104
**shared_data,
@@ -104,7 +108,7 @@ def run_benchmark(
104108
)
105109

106110
modifiers = _CHECK_SHAPES_MODIFIER + modifiers
107-
without_timings = run_modified_script(script, modifiers, reps, keep, output_dir)
111+
without_timings = run_modified_script(script, "without", modifiers, reps, keep, output_dir)
108112
without_df = pd.DataFrame(
109113
{
110114
**shared_data,
@@ -117,6 +121,10 @@ def run_benchmark(
117121
csv_path = output_dir / f"results_{name}_{timestamp_str}.csv"
118122
df.to_csv(csv_path, index=False)
119123

124+
stats = Stats.new(df)
125+
print(f"Relative overhead: {stats.rel_overhead_mean:.2%} +/- {stats.rel_overhead_std:.2%}")
126+
print(f"Absolute overhead: {stats.abs_overhead_mean:.2}s +/- {stats.abs_overhead_std:.2}s")
127+
120128

121129
def main() -> None:
122130
parser = argparse.ArgumentParser(description="Modify a script, then times its execution.")

benchmark/stats.py

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
# Copyright 2022 The GPflow Contributors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
from dataclasses import dataclass
15+
16+
import numpy as np
17+
import pandas as pd
18+
19+
20+
@dataclass
21+
class Stats:
22+
23+
with_mean: float
24+
with_std: float
25+
without_mean: float
26+
without_std: float
27+
rel_overhead_mean: float
28+
rel_overhead_std: float
29+
abs_overhead_mean: float
30+
abs_overhead_std: float
31+
32+
@staticmethod
33+
def new(df: pd.DataFrame) -> "Stats":
34+
by_cs = df.groupby("check_shapes")
35+
mean_by_cs = by_cs.time_s.mean()
36+
std_by_cs = by_cs.time_s.std().fillna(0.0)
37+
var_by_cs = by_cs.time_s.var().fillna(0.0)
38+
39+
with_mean = mean_by_cs[True]
40+
with_mean_sq = with_mean ** 2
41+
with_std = std_by_cs[True]
42+
with_var = var_by_cs[True]
43+
44+
without_mean = mean_by_cs[False]
45+
without_mean_sq = without_mean ** 2
46+
without_std = std_by_cs[False]
47+
without_var = var_by_cs[False]
48+
49+
rel_overhead_mean = (with_mean / without_mean) - 1
50+
# https://en.wikipedia.org/wiki/Ratio_distribution#Uncorrelated_noncentral_normal_ratio
51+
rel_overhead_var = (with_mean_sq / without_mean_sq) * (
52+
(with_var / with_mean_sq) + (without_var / without_mean_sq)
53+
)
54+
rel_overhead_std = np.sqrt(rel_overhead_var)
55+
56+
abs_overhead_mean = with_mean - without_mean
57+
abs_overhead_var = with_var + without_var
58+
abs_overhead_std = np.sqrt(abs_overhead_var)
59+
60+
return Stats(
61+
with_mean=with_mean,
62+
with_std=with_std,
63+
without_mean=without_mean,
64+
without_std=without_std,
65+
rel_overhead_mean=rel_overhead_mean,
66+
rel_overhead_std=rel_overhead_std,
67+
abs_overhead_mean=abs_overhead_mean,
68+
abs_overhead_std=abs_overhead_std,
69+
)

pylintrc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,7 @@ disable=comparison-with-callable,
8585
too-few-public-methods,
8686
too-many-arguments,
8787
too-many-branches,
88+
too-many-instance-attributes,
8889
too-many-lines,
8990
too-many-locals,
9091
too-many-public-methods,

0 commit comments

Comments
 (0)