Skip to content

Commit dfe102f

Browse files
ZzEeKkAaDiptorup Deb
authored andcommitted
Add comparison report
1 parent e7377b5 commit dfe102f

File tree

6 files changed

+127
-14
lines changed

6 files changed

+127
-14
lines changed

dpbench/console/_namespace.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,3 +28,26 @@ class Namespace(argparse.Namespace):
2828
timeout: float
2929
precision: Union[str, None]
3030
program: str
31+
comparisons: list[str]
32+
33+
34+
class CommaSeparateStringAction(argparse.Action):
35+
"""Action that reads comma separated string into set of strings.
36+
37+
This action supposed to be used in argparse argument.
38+
"""
39+
40+
def __call__(self, _, namespace, values, __):
41+
"""Split values into set of strings."""
42+
setattr(namespace, self.dest, set(values.split(",")))
43+
44+
45+
class CommaSeparateStringListAction(argparse.Action):
46+
"""Action that reads comma separated string into set of strings.
47+
48+
This action supposed to be used in argparse argument.
49+
"""
50+
51+
def __call__(self, _, namespace, values, __):
52+
"""Split values into list of strings."""
53+
setattr(namespace, self.dest, values.split(","))

dpbench/console/entry.py

Lines changed: 1 addition & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -7,22 +7,11 @@
77
import argparse
88
from importlib.metadata import version
99

10-
from ._namespace import Namespace
10+
from ._namespace import CommaSeparateStringAction, Namespace
1111
from .report import add_report_arguments, execute_report
1212
from .run import add_run_arguments, execute_run
1313

1414

15-
class CommaSeparateStringAction(argparse.Action):
16-
"""Action that reads comma separated string into set of strings.
17-
18-
This action supposed to be used in argparse argument.
19-
"""
20-
21-
def __call__(self, _, namespace, values, __):
22-
"""Split values into set of strings."""
23-
setattr(namespace, self.dest, set(values.split(",")))
24-
25-
2615
def parse_args() -> Namespace:
2716
"""Parse console arguments into dpbench Namespace."""
2817
parser = argparse.ArgumentParser()

dpbench/console/report.py

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
import sqlalchemy
1010

11-
from ._namespace import Namespace
11+
from ._namespace import CommaSeparateStringListAction, Namespace
1212

1313

1414
def add_report_arguments(parser: argparse.ArgumentParser):
@@ -17,7 +17,16 @@ def add_report_arguments(parser: argparse.ArgumentParser):
1717
Args:
1818
parser: argument parser where arguments will be populated.
1919
"""
20-
pass
20+
parser.add_argument(
21+
"-c",
22+
"--comparisons",
23+
type=str,
24+
action=CommaSeparateStringListAction,
25+
nargs="?",
26+
default=[],
27+
help="Comma separated list of implementation pairs that need to be"
28+
+ " compared.",
29+
)
2130

2231

2332
def execute_report(args: Namespace, conn: sqlalchemy.Engine):
@@ -29,6 +38,12 @@ def execute_report(args: Namespace, conn: sqlalchemy.Engine):
2938
args: object with all input arguments.
3039
conn: database connection.
3140
"""
41+
if len(args.comparisons) % 2 != 0:
42+
raise ValueError(
43+
"--comparisons must contain pairs, but odd number of"
44+
+ " elements was provided"
45+
)
46+
3247
import dpbench.config as cfg
3348
from dpbench.infrastructure.reporter import update_run_id
3449
from dpbench.infrastructure.runner import print_report
@@ -39,9 +54,15 @@ def execute_report(args: Namespace, conn: sqlalchemy.Engine):
3954
load_implementations=False,
4055
)
4156

57+
comparison_pairs = [
58+
tuple(args.comparisons[i : i + 2])
59+
for i in range(0, len(args.comparisons), 2)
60+
]
61+
4262
args.run_id = update_run_id(conn, args.run_id)
4363
print_report(
4464
conn=conn,
4565
run_id=args.run_id,
4666
implementations=args.implementations,
67+
comparison_pairs=comparison_pairs,
4768
)

dpbench/infrastructure/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
NumbaMlirFramework,
2222
)
2323
from .reporter import (
24+
generate_comparison_report,
2425
generate_impl_summary_report,
2526
generate_performance_report,
2627
get_unexpected_failures,
@@ -45,6 +46,7 @@
4546
"store_results",
4647
"generate_impl_summary_report",
4748
"generate_performance_report",
49+
"generate_comparison_report",
4850
"get_unexpected_failures",
4951
"validate",
5052
]

dpbench/infrastructure/reporter.py

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
__all__ = [
2626
"generate_impl_summary_report",
2727
"generate_performance_report",
28+
"generate_comparison_report",
2829
]
2930

3031

@@ -216,6 +217,74 @@ def generate_performance_report(
216217
generate_summary(df)
217218

218219

220+
def generate_comparison_report(
221+
conn: sqlalchemy.Engine,
222+
run_id: int,
223+
implementations: list[str],
224+
comparison_pairs: list[tuple[str, str]],
225+
headless=False,
226+
):
227+
"""generate comparison report with median times for each benchmark"""
228+
if len(comparison_pairs) == 0:
229+
return
230+
231+
legends = read_legends()
232+
233+
if not headless:
234+
generate_header(conn, run_id)
235+
generate_legend(legends)
236+
237+
columns = [
238+
dm.Result.input_size_human.label("input_size"),
239+
dm.Result.benchmark,
240+
dm.Result.problem_preset,
241+
]
242+
243+
for impl in implementations:
244+
columns.append(
245+
func.ifnull(
246+
func.max(
247+
case(
248+
(
249+
dm.Result.implementation == impl,
250+
dm.Result.median_exec_time,
251+
),
252+
)
253+
),
254+
None,
255+
).label(impl),
256+
)
257+
258+
sql = (
259+
sqlalchemy.select(*columns)
260+
.group_by(
261+
dm.Result.benchmark,
262+
dm.Result.problem_preset,
263+
)
264+
.where(dm.Result.run_id == run_id)
265+
)
266+
267+
df = pd.read_sql_query(
268+
sql=sql,
269+
con=conn.connect(),
270+
)
271+
272+
for index, row in df.iterrows():
273+
for target, reference in comparison_pairs:
274+
if row[reference] == 0 or row[target] == 0:
275+
boost = "n/a"
276+
else:
277+
boost = (
278+
str(round((row[target] / row[reference]) * 100, 2)) + "%"
279+
)
280+
df.at[index, target + "_to_" + reference] = boost
281+
282+
for impl in implementations:
283+
df = df.drop(impl, axis=1)
284+
285+
generate_summary(df)
286+
287+
219288
def get_failures_from_results(
220289
results_db: Union[str, sqlalchemy.Engine] = "results.db",
221290
run_id: int = None,

dpbench/infrastructure/runner.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -213,6 +213,7 @@ def print_report(
213213
conn: sqlalchemy.Engine,
214214
run_id: int,
215215
implementations: set[str],
216+
comparison_pairs: list[tuple[str, str]] = [],
216217
):
217218
if not implementations:
218219
implementations = {impl.postfix for impl in cfg.GLOBAL.implementations}
@@ -231,6 +232,14 @@ def print_report(
231232
headless=True,
232233
)
233234

235+
dpbi.generate_comparison_report(
236+
conn,
237+
run_id=run_id,
238+
implementations=implementations,
239+
comparison_pairs=comparison_pairs,
240+
headless=True,
241+
)
242+
234243
unexpected_failures = dpbi.get_unexpected_failures(conn, run_id=run_id)
235244

236245
if len(unexpected_failures) > 0:

0 commit comments

Comments
 (0)