1010from dataclasses import dataclass
1111import matplotlib .dates as mdates
1212from benches .result import BenchmarkRun , Result
13+ import numpy as np
1314
1415@dataclass
1516class BenchmarkMetadata :
@@ -23,11 +24,14 @@ class BenchmarkSeries:
2324 runs : list [BenchmarkRun ]
2425
@dataclass
class BenchmarkChart:
    """A single rendered chart together with the label it is filed under."""

    # Used as the chart's `data-label` attribute when the page is assembled.
    label: str
    # HTML fragment for the figure, as produced by mpld3.fig_to_html.
    html: str
2930
30- def create_time_series_chart (benchmarks : list [BenchmarkSeries ], github_repo : str ) -> list [BenchmarkTimeSeries ]:
def tooltip_css() -> str:
    """Return the CSS rule applied to the charts' mpld3 tooltips."""
    declarations = (
        'background:white',
        'padding:8px',
        'border:1px solid #ddd',
        'border-radius:4px',
        'font-family:monospace',
        'white-space:pre',
    )
    # Every declaration, including the last, is terminated by a semicolon.
    return '.mpld3-tooltip{' + ';'.join(declarations) + ';}'
33+
34+ def create_time_series_chart (benchmarks : list [BenchmarkSeries ], github_repo : str ) -> list [BenchmarkChart ]:
3135 plt .close ('all' )
3236
3337 num_benchmarks = len (benchmarks )
@@ -66,7 +70,7 @@ def create_time_series_chart(benchmarks: list[BenchmarkSeries], github_repo: str
6670 for point in sorted_points ]
6771
6872 tooltip = mpld3 .plugins .PointHTMLTooltip (scatter , tooltip_labels ,
69- css = '.mpld3-tooltip{background:white;padding:8px;border:1px solid #ddd;border-radius:4px;font-family:monospace;white-space:pre;}' ,
73+ css = tooltip_css () ,
7074 targets = targets )
7175 mpld3 .plugins .connect (fig , tooltip )
7276
@@ -94,7 +98,104 @@ def create_time_series_chart(benchmarks: list[BenchmarkSeries], github_repo: str
9498 ax .xaxis .set_major_formatter (mdates .ConciseDateFormatter ('%Y-%m-%d %H:%M:%S' ))
9599
96100 plt .tight_layout ()
97- html_charts .append (BenchmarkTimeSeries (html = mpld3 .fig_to_html (fig ), label = benchmark .label ))
101+ html_charts .append (BenchmarkChart (html = mpld3 .fig_to_html (fig ), label = benchmark .label ))
102+ plt .close (fig )
103+
104+ return html_charts
105+
@dataclass
class ExplicitGroup:
    """Results sharing one explicit group, collected for a comparison chart."""

    # The group's identifier (the `explicit_group` value of its results).
    name: str
    # Number of run names being compared; create_explicit_groups sets it to
    # len(compare_names).
    nnames: int
    # Unit and lower-is-better flag used when labelling the chart.
    metadata: BenchmarkMetadata
    # Result label -> run name -> result. create_explicit_groups stores None
    # for a run name that has no result for a given label.
    runs: dict[str, dict[str, Result]]
112+
def create_explicit_groups(benchmark_runs: list[BenchmarkRun], compare_names: list[str]) -> list[ExplicitGroup]:
    """Collect results that declare an explicit group, keyed by group name.

    Only runs whose name appears in ``compare_names`` are considered, and
    results with an empty ``explicit_group`` are ignored. Within a group,
    results are indexed by result label and then by run name.

    Args:
        benchmark_runs: Runs to scan for grouped results.
        compare_names: Names of the runs being compared.

    Returns:
        The groups in the order they were first encountered.
    """
    collected = {}

    for bench_run in benchmark_runs:
        if bench_run.name not in compare_names:
            continue

        for result in bench_run.results:
            group_name = result.explicit_group
            if group_name == '':
                continue

            group = collected.get(group_name)
            if group is None:
                # The first result seen for a group supplies its metadata.
                group = ExplicitGroup(
                    name=group_name,
                    nnames=len(compare_names),
                    metadata=BenchmarkMetadata(unit=result.unit,
                                               lower_is_better=result.lower_is_better),
                    runs={},
                )
                collected[group_name] = group

            by_label = group.runs
            if result.label not in by_label:
                # Start with a None slot for every compared run name.
                by_label[result.label] = dict.fromkeys(compare_names)

            slots = by_label[result.label]
            # Keep the first result seen for a (label, run name) pair.
            if slots[bench_run.name] is None:
                slots[bench_run.name] = result

    return list(collected.values())
133+
def create_grouped_bar_charts(groups: list[ExplicitGroup]) -> list[BenchmarkChart]:
    """Render one grouped bar chart per explicit group as mpld3 HTML.

    Each entry of ``group.runs`` becomes one bar series with its own legend
    entry, and each series has one bar per key of its inner mapping. A
    missing result (``None``) is drawn as a zero-height bar and gets neither
    a value label nor a tooltip.

    Args:
        groups: Groups to plot, e.g. as returned by ``create_explicit_groups``.

    Returns:
        One ``BenchmarkChart`` per plotted group, labelled with the group
        name. Groups that contain no runs are skipped.
    """
    plt.close('all')

    html_charts = []

    for group in groups:
        # Nothing to draw, and the bar width below would divide by zero.
        if not group.runs:
            continue

        fig, ax = plt.subplots(figsize=(10, 6))

        x = np.arange(group.nnames)
        x_labels = []
        # The series share 80% of each unit-wide slot on the x axis.
        width = 0.8 / len(group.runs)

        for i, (series_label, run_results) in enumerate(group.runs.items()):
            positions = x + width * i
            x_labels = list(run_results)
            valid_data = [r.value if r is not None else 0 for r in run_results.values()]
            rects = ax.bar(positions, valid_data, width, label=series_label)
            # This is a hack to disable all bar_label. Setting labels to empty doesn't work.
            # We create our own labels below for each bar, this works better in mpld3.
            ax.bar_label(rects, fmt='')

            for rect, (run, res) in zip(rects, run_results.items()):
                # A run without a result for this label has a placeholder
                # zero-height bar; there is nothing to annotate.
                if res is None:
                    continue

                height = rect.get_height()
                ax.text(rect.get_x() + rect.get_width() / 2., height + 2,
                        f'{res.value:.1f}',
                        ha='center', va='bottom', fontsize=9)

                tooltip_labels = [
                    f"Run: {run}\n"
                    f"Label: {res.label}\n"
                    f"Value: {res.value:.2f} {res.unit}\n"
                ]
                tooltip = mpld3.plugins.LineHTMLTooltip(rect, tooltip_labels, css=tooltip_css())
                mpld3.plugins.connect(ax.figure, tooltip)

        ax.set_xticks([])
        ax.grid(True, axis='y', alpha=0.2)
        ax.set_ylabel(f"Value ({group.metadata.unit})")
        ax.legend(loc='upper left')
        ax.set_title(group.name, pad=20)
        performance_indicator = "lower is better" if group.metadata.lower_is_better else "higher is better"
        ax.text(0.5, 1.03, f"({performance_indicator})",
                ha='center',
                transform=ax.transAxes,
                style='italic',
                fontsize=7,
                color='#666666')

        # The axis limits do not change while the labels are placed.
        x_min, x_max = ax.get_xlim()
        for idx, label in enumerate(x_labels):
            # this is a hack to get labels to show above the legend
            # we normalize the idx to transAxes transform and offset it a little.
            x_norm = (idx + 0.3 - x_min) / (x_max - x_min)
            ax.text(x_norm, 1.00, label,
                    transform=ax.transAxes,
                    color='#666666')

        plt.tight_layout()
        html_charts.append(BenchmarkChart(label=group.name, html=mpld3.fig_to_html(fig)))
        plt.close(fig)

    return html_charts
@@ -138,6 +239,11 @@ def generate_html(benchmark_runs: list[BenchmarkRun], github_repo: str, compare_
138239 timeseries = create_time_series_chart (benchmarks , github_repo )
139240 timeseries_charts_html = '\n ' .join (f'<div class="chart" data-label="{ ts .label } "><div>{ ts .html } </div></div>' for ts in timeseries )
140241
242+ explicit_groups = create_explicit_groups (benchmark_runs , compare_names )
243+
244+ bar_charts = create_grouped_bar_charts (explicit_groups )
245+ bar_charts_html = '\n ' .join (f'<div class="chart" data-label="{ bc .label } "><div>{ bc .html } </div></div>' for bc in bar_charts )
246+
141247 html_template = f"""
142248 <!DOCTYPE html>
143249 <html>
@@ -199,21 +305,72 @@ def generate_html(benchmark_runs: list[BenchmarkRun], github_repo: str, compare_
199305 width: 400px;
200306 max-width: 100%;
201307 }}
308+ details {{
309+ margin-bottom: 24px;
310+ }}
311+ summary {{
312+ font-size: 18px;
313+ font-weight: 500;
314+ cursor: pointer;
315+ padding: 12px;
316+ background: #e9ecef;
317+ border-radius: 8px;
318+ user-select: none;
319+ }}
320+ summary:hover {{
321+ background: #dee2e6;
322+ }}
202323 </style>
203324 <script>
325+ function getQueryParam(param) {{
326+ const urlParams = new URLSearchParams(window.location.search);
327+ return urlParams.get(param);
328+ }}
329+
204330 function filterCharts() {{
205331 const regexInput = document.getElementById('bench-filter').value;
206332 const regex = new RegExp(regexInput, 'i');
207333 const charts = document.querySelectorAll('.chart');
334+ let timeseriesVisible = false;
335+ let barChartsVisible = false;
336+
208337 charts.forEach(chart => {{
209338 const label = chart.getAttribute('data-label');
210339 if (regex.test(label)) {{
211340 chart.style.display = '';
341+ if (chart.closest('.timeseries')) {{
342+ timeseriesVisible = true;
343+ }} else if (chart.closest('.bar-charts')) {{
344+ barChartsVisible = true;
345+ }}
212346 }} else {{
213347 chart.style.display = 'none';
214348 }}
215349 }});
350+
351+ updateURL(regexInput);
352+
353+ document.querySelector('.timeseries').open = timeseriesVisible;
354+ document.querySelector('.bar-charts').open = barChartsVisible;
216355 }}
356+
357+ function updateURL(regex) {{
358+ const url = new URL(window.location);
359+ if (regex) {{
360+ url.searchParams.set('regex', regex);
361+ }} else {{
362+ url.searchParams.delete('regex');
363+ }}
364+ history.replaceState(null, '', url);
365+ }}
366+
367+ document.addEventListener('DOMContentLoaded', (event) => {{
368+ const regexParam = getQueryParam('regex');
369+ if (regexParam) {{
370+ document.getElementById('bench-filter').value = regexParam;
371+ filterCharts();
372+ }}
373+ }});
217374 </script>
218375 </head>
219376 <body>
@@ -222,13 +379,20 @@ def generate_html(benchmark_runs: list[BenchmarkRun], github_repo: str, compare_
222379 <div class="filter-container">
223380 <input type="text" id="bench-filter" placeholder="Regex..." oninput="filterCharts()">
224381 </div>
225- <h2>Historical Results</h2>
226- <div class="charts">
227- { timeseries_charts_html }
228- </div>
382+ <details class="timeseries">
383+ <summary>Historical Results</summary>
384+ <div class="charts">
385+ { timeseries_charts_html }
386+ </div>
387+ </details>
388+ <details class="bar-charts">
389+ <summary>Comparisons</summary>
390+ <div class="charts">
391+ { bar_charts_html }
392+ </div>
393+ </details>
229394 </div>
230395 </body>
231396 </html>
232397 """
233-
234398 return html_template
0 commit comments