Skip to content

Commit 0b14d9b

Browse files
committed
Allow run specific implementations
1 parent 8fb91f2 commit 0b14d9b

File tree

2 files changed

+63
-22
lines changed

2 files changed

+63
-22
lines changed

dpbench/infrastructure/reporter.py

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ def generate_summary(data: pd.DataFrame):
105105
def generate_impl_summary_report(
106106
results_db: Union[str, sqlalchemy.Engine] = "results.db",
107107
run_id: int = None,
108-
postfixes: list[str] = None,
108+
implementations: list[str] = None,
109109
):
110110
"""generate implementation summary report with status of each benchmark"""
111111
conn = update_connection(results_db=results_db)
@@ -121,10 +121,10 @@ def generate_impl_summary_report(
121121
dm.Result.problem_preset,
122122
]
123123

124-
if postfixes is None:
125-
postfixes = [impl.postfix for impl in cfg.GLOBAL.implementations]
124+
if implementations is None:
125+
implementations = [impl.postfix for impl in cfg.GLOBAL.implementations]
126126

127-
for impl in postfixes:
127+
for impl in implementations:
128128
columns.append(
129129
func.ifnull(
130130
func.max(
@@ -161,23 +161,28 @@ def generate_impl_summary_report(
161161
def generate_performance_report(
162162
results_db: Union[str, sqlalchemy.Engine] = "results.db",
163163
run_id: int = None,
164+
implementations: list[str] = None,
165+
headless=False,
164166
):
165167
"""generate performance report with median times for each benchmark"""
166168
conn = update_connection(results_db=results_db)
167169
run_id = update_run_id(conn, run_id)
168170
legends = read_legends()
169171

170-
generate_header(conn, run_id)
171-
generate_legend(legends)
172+
if not headless:
173+
generate_header(conn, run_id)
174+
generate_legend(legends)
172175

173176
columns = [
174177
dm.Result.input_size_human.label("input_size"),
175178
dm.Result.benchmark,
176179
dm.Result.problem_preset,
177180
]
178181

179-
for _, row in legends.iterrows():
180-
impl = row["impl_postfix"]
182+
if implementations is None:
183+
implementations = [impl.postfix for impl in cfg.GLOBAL.implementations]
184+
185+
for impl in implementations:
181186
columns.append(
182187
func.ifnull(
183188
func.max(
@@ -207,9 +212,7 @@ def generate_performance_report(
207212
)
208213

209214
for index, row in df.iterrows():
210-
for _, legend_row in legends.iterrows():
211-
impl = legend_row["impl_postfix"]
212-
215+
for impl in implementations:
213216
time = row[impl]
214217
if time:
215218
NANOSECONDS_IN_MILISECONDS: Final[float] = 1000 * 1000.0

dpbench/runner.py

Lines changed: 49 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,28 @@ def run_benchmark(
7878
print_results=True,
7979
run_id: int = None,
8080
):
81+
"""Run specific benchmark.
82+
83+
Args:
84+
bname (str, semi-optional): Name of the benchmark. Either name, either
85+
configuration must be provided.
86+
benchmark (Benchmark, semi-optional): Benchmark configuration. Either
87+
name, either configuration must be provided.
88+
implementation_postfix: (str, optional): Implementation postfixes
89+
to be executed. If not provided, all possible implementations will
90+
be executed.
91+
preset (str, optional): Problem size. Defaults to "S".
92+
repeat (int, optional): Number of repetitions. Defaults to 1.
93+
validate (bool, optional): Whether to validate against NumPy.
94+
Defaults to True.
95+
timeout (float, optional): Timeout setting. Defaults to 10.0.
96+
conn: connection to database. If not provided results won't be stored.
97+
print_results (bool, optional): Either print results. Defaults to True.
98+
run_id (int, optional): Either store result to specific run_id.
99+
If not provided, new run_id will be created.
100+
101+
Returns: nothing.
102+
"""
81103
bench_cfg = get_benchmark(benchmark=benchmark, benchmark_name=bname)
82104
bname = bench_cfg.name
83105
print("")
@@ -119,37 +141,45 @@ def run_benchmarks(
119141
repeat=10,
120142
validate=True,
121143
timeout=200.0,
122-
dbfile=None,
123144
print_results=True,
124145
run_id=None,
146+
implementations: list[str] = None,
125147
):
126148
"""Run all benchmarks in the dpbench benchmark directory
149+
127150
Args:
128-
bconfig_path (str, optional): Path to benchmark configurations.
129-
Defaults to None.
130151
preset (str, optional): Problem size. Defaults to "S".
131152
repeat (int, optional): Number of repetitions. Defaults to 1.
132153
validate (bool, optional): Whether to validate against NumPy.
133-
Defaults to True.
154+
Defaults to True.
134155
timeout (float, optional): Timeout setting. Defaults to 10.0.
156+
print_results (bool, optional): Either print results. Defaults to True.
157+
run_id (int, optional): Either store result to specific run_id.
158+
If not provided, new run_id will be created.
159+
implementations: (list[str], optional): List of implementation postfixes
160+
to be executed. If not provided, all possible implementations will
161+
be executed.
162+
163+
Returns: nothing.
135164
"""
136165

137166
print("===============================================================")
138167
print("")
139168
print("***Start Running DPBench***")
140-
if not dbfile:
141-
dbfile = "results.db"
142169

143170
dpbi.create_results_table()
144-
conn = dpbi.create_connection(db_file=dbfile)
171+
conn = dpbi.create_connection(db_file="results.db")
145172
if run_id is None:
146173
run_id = dpbi.create_run(conn)
147174

175+
if implementations is None:
176+
implementations = [impl.postfix for impl in cfg.GLOBAL.implementations]
177+
148178
for b in cfg.GLOBAL.benchmarks:
149-
for impl in cfg.GLOBAL.implementations:
179+
for impl in implementations:
150180
run_benchmark(
151181
benchmark=b,
152-
implementation_postfix=impl.postfix,
182+
implementation_postfix=impl,
153183
preset=preset,
154184
repeat=repeat,
155185
validate=validate,
@@ -167,6 +197,14 @@ def run_benchmarks(
167197
print("===============================================================")
168198
print("")
169199

170-
dpbi.generate_impl_summary_report(conn, run_id=run_id)
200+
if print_results:
201+
dpbi.generate_impl_summary_report(
202+
conn, run_id=run_id, implementations=implementations
203+
)
171204

172-
return dbfile
205+
dpbi.generate_performance_report(
206+
conn,
207+
run_id=run_id,
208+
implementations=implementations,
209+
headless=True,
210+
)

0 commit comments

Comments
 (0)