Skip to content

Commit 3eff4cb

Browse files
author
Malmahrouqi3
committed
Improve bench_diff
1 parent 4f0704f commit 3eff4cb

File tree

2 files changed

+56
-5
lines changed

2 files changed

+56
-5
lines changed

toolchain/mfc/args.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,8 @@ def add_common_arguments(p: argparse.ArgumentParser, mask = None):
138138
add_common_arguments(bench_diff, "t")
139139
bench_diff.add_argument("lhs", metavar="LHS", type=str, help="Path to a benchmark result YAML file.")
140140
bench_diff.add_argument("rhs", metavar="RHS", type=str, help="Path to a benchmark result YAML file.")
141+
bench_diff.add_argument("-f", "--file", metavar="FILE", type=str, required=False, default=None, help="Path to the data file (e.g., data.js).")
142+
bench_diff.add_argument("-n", "--name", metavar="NAME", nargs="+", type=str, required=False, default=[], help="Test name (e.g. GT Phoenix (CPU)).")
141143

142144
# COUNT
143145
add_common_arguments(count, "g")
@@ -170,9 +172,6 @@ def add_common_arguments(p: argparse.ArgumentParser, mask = None):
170172
parser.print_help()
171173
exit(-1)
172174

173-
# "Slugify" the name of the job
174-
args["name"] = re.sub(r'[\W_]+', '-', args["name"])
175-
176175
# We need to check for some invalid combinations of arguments because of
177176
# the limitations of argparse.
178177
if args["command"] == "build":
@@ -181,6 +180,11 @@ def add_common_arguments(p: argparse.ArgumentParser, mask = None):
181180
if args["command"] == "run":
182181
if args["binary"] is not None and args["engine"] != "interactive":
183182
raise MFCException("./mfc.sh run's --binary can only be used with --engine=interactive.")
183+
if isinstance(args["name"], str):
184+
# "Slugify" the name of the job
185+
args["name"] = re.sub(r'[\W_]+', '-', args["name"]).strip('-')
186+
elif args["command"] == "bench_diff" and len(args["name"]) > 0:
187+
args["name"] = " ".join(args["name"])
184188

185189
# Input files to absolute paths
186190
for e in ["input", "input1", "input2"]:

toolchain/mfc/bench.py

Lines changed: 49 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
import os, sys, uuid, subprocess, dataclasses, typing, math
1+
import os, sys, uuid, subprocess, dataclasses, typing, math, json
22

33
import rich.table
44

@@ -106,16 +106,51 @@ def diff():
106106
Using intersection: {slugs} with {len(slugs)} elements.
107107
""")
108108

109+
cb_stats = {}
110+
if ARG("file") is not None and ARG("name") is not None:
111+
try:
112+
with open(ARG("file"), 'r') as f:
113+
data_json = json.load(f)
114+
115+
cb_test = ARG("name")
116+
if "entries" in data_json and cb_test in data_json["entries"]:
117+
benchmark_runs = data_json["entries"][cb_test]
118+
case_times = {}
119+
for run in benchmark_runs:
120+
if "benches" not in run:
121+
continue
122+
for bench in run["benches"]:
123+
case_name = bench.get("name")
124+
grind_value = bench.get("value")
125+
if case_name is None or grind_value is None:
126+
continue
127+
if case_name not in case_times:
128+
case_times[case_name] = []
129+
case_times[case_name].append(grind_value)
130+
for case_name, values in case_times.items():
131+
if len(values) > 0:
132+
avg = sum(values) / len(values)
133+
cb_stats[case_name] = {"avg": avg, "count": len(values)}
134+
135+
cons.print(f"[bold]Loaded cb data for test: [bold]{cb_test}[/bold] ({len(cb_stats)} cases)[/bold]")
136+
else:
137+
cons.print(f"[bold yellow]Warning[/bold yellow]: Test '[bold]{cb_test}[/bold]' not found in data file.")
138+
except Exception as e:
139+
cons.print(f"[bold yellow]Warning[/bold yellow]: Could not load data file: {e}")
140+
109141
table = rich.table.Table(show_header=True, box=rich.table.box.SIMPLE)
110142
table.add_column("[bold]Case[/bold]", justify="left")
111143
table.add_column("[bold]Pre Process[/bold]", justify="right")
112144
table.add_column("[bold]Simulation[/bold]", justify="right")
113145
table.add_column("[bold]Post Process[/bold]", justify="right")
146+
if cb_stats:
147+
table.add_column("[bold] CB (Grind)[/bold]", justify="right")
114148

115149
err = 0
116150
for slug in slugs:
117151
lhs_summary, rhs_summary = lhs["cases"][slug]["output_summary"], rhs["cases"][slug]["output_summary"]
118152
speedups = ['N/A', 'N/A', 'N/A']
153+
grind_comparison = 'N/A'
119154

120155
for i, target in enumerate(sorted(DEFAULT_TARGETS, key=lambda t: t.runOrder)):
121156
if (target.name not in lhs_summary) or (target.name not in rhs_summary):
@@ -140,11 +175,23 @@ def diff():
140175
if grind_time_value <0.95:
141176
cons.print(f"[bold red]Error[/bold red]: Benchmarking failed since grind time speedup for {target.name} below acceptable threshold (<0.95) - Case: {slug}")
142177
err = 1
178+
if slug in cb_stats:
179+
rhs_grind = rhs_summary[target.name]["grind"]
180+
stats = cb_stats[slug]
181+
avg = stats["avg"]
182+
offset_pct = ((rhs_grind - avg) / avg * 100)
183+
color = "red" if (offset_pct) > 0 else "yellow" if abs(offset_pct) == 0 else "green"
184+
grind_comparison = f"[{color}]{offset_pct:+.2f}%[/{color}] (avg: {avg:.2f})"
185+
143186
except Exception as _:
144187
pass
145188

146-
table.add_row(f"[magenta]{slug}[/magenta]", *speedups)
189+
row = [f"[magenta]{slug}[/magenta]", *speedups]
190+
if cb_stats:
191+
row.append(grind_comparison)
192+
table.add_row(*row)
147193

148194
cons.raw.print(table)
149195
if err:
150196
raise MFCException("Benchmarking failed")
197+

0 commit comments

Comments (0)