Skip to content

Commit d331d7e

Browse files
committed
benchmark: changed how plotting works
1 parent 1182d12 commit d331d7e

File tree

1 file changed

+80
-21
lines changed

1 file changed

+80
-21
lines changed

model/distributions/sphere/watson/benchmark_fib_starts.py

Lines changed: 80 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,9 @@
55
PYTHONPATH=$PWD poetry run python model/distributions/sphere/watson/benchmark_fib_starts.py
66
'''
77

8+
import json
9+
from pathlib import Path
10+
import plotly.express as px
811
from model.distributions.sphere.watson.fibonachi import WatsonFibonachiSampling
912
from util.selectors.slider_float import FloatSlider
1013
import pyperf
@@ -30,7 +33,7 @@ def bench_single_kappa(kappa, sample_count, id):
3033
results = {}
3134
for method_name, method in methods.items():
3235
bench_name = f"Watson Fibonacci Sampling: {method_name} (kappa={kappa}) [{id}]"
33-
benchmark = runner.bench_func(bench_name, benchmark_kappa, method, kappa, sample_count, 5)
36+
benchmark = runner.bench_func(bench_name, benchmark_kappa, method, kappa, sample_count, 1)
3437
results[method_name] = benchmark
3538

3639
return results
@@ -68,29 +71,86 @@ def bench_multiple_sample_counts_log(kappa):
6871
return all_results
6972

7073

71-
def plot_benches(results, title, filename, x_label, log_x=False, log_y=False):
72-
import plotly.express as px
73-
try:
74-
if x_label == "sample_count":
75-
rows = [dict(name=n, sample_count=k, time=t.mean()) for n, pts in results.items() for k, t in pts]
76-
else:
77-
rows = [dict(name=n, kappa=k, time=t.mean()) for n, pts in results.items() for k, t in pts]
78-
fig = px.line(
79-
rows,
80-
x=x_label,
81-
y="time",
82-
color="name",
83-
markers=True,
84-
title=title,
85-
log_x=log_x,
86-
log_y=log_y,
74+
def _sanitize_filename(name):
75+
return name.replace(" ", "_").replace(":", "")
76+
77+
def _rows_from_results(results, x_label):
78+
if x_label == "sample_count":
79+
return [dict(name=n, sample_count=k, time=t.mean()) for n, pts in results.items() for k, t in pts]
80+
if x_label == "kappa":
81+
return [dict(name=n, kappa=k, time=t.mean()) for n, pts in results.items() for k, t in pts]
82+
raise ValueError(f"Unsupported x_label: {x_label}")
83+
84+
def _plot_rows(rows, title, filename, x_label, log_x=False, log_y=False):
    """Render *rows* as a plotly line chart and write it to disk.

    The figure is saved as ``<sanitized filename>.svg``.  If static image
    export fails (e.g. no image-export backend available), the row data is
    dumped to stdout and an HTML file named after *title* is written as a
    fallback.  NOTE(review): *title* is only used for the fallback file
    name here — the figure itself carries no title.
    """
    # Horizontal legend placed just above the plot area.
    legend_above_plot = {
        "orientation": "h",
        "yanchor": "bottom",
        "y": 1.02,
        "xanchor": "left",
        "x": 0,
    }
    fig = px.line(
        rows,
        x=x_label,
        y="time",
        color="name",
        markers=True,
        log_x=log_x,
        log_y=log_y,
    )
    fig.update_layout(legend=legend_above_plot)
    try:
        fig.write_image(f"{_sanitize_filename(filename)}.svg")
    except Exception as e:
        print("Generating plot failed, dumping data:", e)
        print(rows)
        print("Trying to save html as fallback")
        fig.write_html(f"{_sanitize_filename(title)}.html", include_plotlyjs="cdn", full_html=True)
110+
111+
def _replot_from_json(json_path, title, filename, x_label, log_x, log_y):
    """Reload a payload written by plot_benches and replot it.

    Explicitly supplied arguments win over the values stored in the JSON
    payload; missing metadata falls back to the file stem / False.
    Raises ValueError if no x_label can be determined.
    """
    with json_path.open("r", encoding="utf-8") as handle:
        payload = json.load(handle)
    rows = payload["rows"]
    if title is None:
        title = payload.get("title", json_path.stem)
    if filename is None:
        filename = payload.get("filename", json_path.stem)
    if x_label is None:
        x_label = payload.get("x_label")
    if log_x is None:
        log_x = payload.get("log_x", False)
    if log_y is None:
        log_y = payload.get("log_y", False)
    if x_label is None:
        raise ValueError("x_label is required when replotting from JSON")
    _plot_rows(rows, title, filename, x_label, log_x=log_x, log_y=log_y)


def plot_benches(results, title=None, filename=None, x_label=None, log_x=None, log_y=None, json_filename=None):
    """Plot benchmark results, or replot previously saved results.

    Two modes, selected by the type of *results*:

    * mapping of method name -> [(x_value, benchmark), ...]: the data is
      flattened, dumped to a JSON sidecar (so plots can be regenerated
      without rerunning the benchmarks), then plotted.  ``title``,
      ``filename`` and ``x_label`` are required in this mode.
    * str/Path to such a JSON file: the stored rows are replotted;
      explicit keyword arguments override the stored metadata.

    Raises ValueError when required metadata is missing in either mode.
    """
    if isinstance(results, (str, Path)):
        _replot_from_json(Path(results), title, filename, x_label, log_x, log_y)
        return

    if title is None or filename is None or x_label is None:
        raise ValueError("title, filename, and x_label are required for raw benchmark data")
    if log_x is None:
        log_x = False
    if log_y is None:
        log_y = False

    rows = _rows_from_results(results, x_label)
    payload = {
        "title": title,
        "filename": filename,
        "x_label": x_label,
        "log_x": log_x,
        "log_y": log_y,
        "rows": rows,
    }
    if json_filename is None:
        json_filename = f"{_sanitize_filename(filename)}.json"
    # Persist the data before plotting so a plotting failure cannot lose it.
    with open(json_filename, "w", encoding="utf-8") as handle:
        json.dump(payload, handle, indent=2, sort_keys=True)
    _plot_rows(rows, title, filename, x_label, log_x=log_x, log_y=log_y)
94154

95155

96156

@@ -114,4 +174,3 @@ def plot_benches(results, title, filename, x_label, log_x=False, log_y=False):
114174

115175
#plot_benches(mult_samples_neg_10, "time taken for various sample counts (kappa=-10)", "time taken for various sample counts (kappa=-10)", "sample_count")
116176
plot_benches(log_mult_samples_neg_10, "time taken for various sample counts (kappa=-10)", "time taken for various sample counts log scale (kappa=-10)", "sample_count", log_x=True, log_y=True)
117-

0 commit comments

Comments
 (0)