Skip to content

Commit 8d305a8

Browse files
committed
Add plaidml total compile time analysis
1 parent 6dd293c commit 8d305a8

File tree

2 files changed

+31
-10
lines changed

2 files changed

+31
-10
lines changed

util/analyze/lib/compile_times.py

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,15 @@ def sched_time(logs):
2222

2323
_CPU2017_TIME_ELAPSED = re.compile(r"Elapsed compile for '(?P<bench>[^']+)': \S+ \((?P<elapsed>\d+)\)")
2424
_BACKUP_TIME_ELAPSED = re.compile(r'(?P<elapsed>\d+) total seconds elapsed')
25+
_PLAIDML_TIME_ELAPSED = re.compile(
26+
r'Example finished, elapsed: (?P<elapsed>\S+)s \(compile\), (?P<exec>\S+)s \(execution\)')
27+
28+
29+
def plaidml_total_compile_time_seconds(logs):
30+
try:
31+
return sum(float(_PLAIDML_TIME_ELAPSED.search(bench.blocks[-1].raw_log)['elapsed']) for bench in logs.benchmarks)
32+
except TypeError:
33+
raise KeyError('Logs must contain "Example finished, elapsed:" output by the PlaidML benchmark suite')
2534

2635

2736
def total_compile_time_seconds(logs):
@@ -33,23 +42,34 @@ def total_compile_time_seconds(logs):
3342
if m:
3443
if len(m) != 1:
3544
logging.warning('Multiple CPU2017 elapsed time indicators. Using the first one out of: %s', m)
36-
return m[0]['elapsed']
45+
return int(m[0]['elapsed'])
3746

3847
m = _BACKUP_TIME_ELAPSED.search(last_logs)
3948
assert m, \
4049
'Logs must contain "total seconds elapsed" output by the SPEC benchmark suite'
4150

42-
return m['elapsed']
51+
return int(m['elapsed'])
52+
53+
54+
def total_compile_time_seconds_f(benchsuite):
55+
return {
56+
'spec': total_compile_time_seconds,
57+
'plaidml': plaidml_total_compile_time_seconds
58+
}[benchsuite]
4359

4460

4561
if __name__ == '__main__':
4662
parser = argparse.ArgumentParser()
47-
parser.add_argument('--variant', choices=('sched', 'total'),
63+
parser.add_argument('--variant', choices=('sched', 'total', 'plaidml'),
4864
help='Which timing variant to use')
4965
parser.add_argument('logs', help='The logs to analyze')
5066
args = analyze.parse_args(parser, 'logs')
5167

52-
fn = total_compile_time_seconds if args.variant == 'total' else sched_time
68+
fn = {
69+
'sched': sched_time,
70+
'total': total_compile_time_seconds,
71+
'plaidml': plaidml_total_compile_time_seconds,
72+
}[args.variant]
5373
results = foreach_bench(fn, args.logs, combine=sum)
5474
writer = csv.DictWriter(sys.stdout, fieldnames=results.keys())
5575
writer.writeheader()

util/gt_analysis/gt_cmp.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
from analyze.lib import block_stats, compile_times
99

1010
sched_time = compile_times.sched_time
11-
total_compile_time_seconds = compile_times.total_compile_time_seconds
1211

1312

1413
def blocks_enumerated_optimally(blocks):
@@ -87,9 +86,7 @@ def blk_relative_cost(nogt, gt) -> Tuple[int, int]:
8786
return no_sum, yes_sum
8887

8988

90-
def compute_stats(nogt: Logs, gt: Logs, *, pass_num: int):
91-
TOTAL_BLOCKS = utils.count(nogt)
92-
89+
def compute_stats(nogt: Logs, gt: Logs, *, pass_num: int, total_compile_time_seconds):
9390
nogt_all, gt_all = nogt, gt
9491

9592
if pass_num is not None:
@@ -106,7 +103,7 @@ def compute_stats(nogt: Logs, gt: Logs, *, pass_num: int):
106103
nogt_rel, gt_rel = blk_relative_cost(nogt, gt)
107104

108105
result = {
109-
'Total Blocks in Benchsuite': TOTAL_BLOCKS,
106+
'Total Blocks in Benchsuite': utils.count(nogt_all),
110107
'Num Blocks enumerated with & without GT': utils.count(nogt),
111108
'Num Blocks proved optimal just by GT': NUM_PROVED_OPTIMAL_WITHOUT_ENUMERATING,
112109

@@ -158,7 +155,11 @@ def compute_stats(nogt: Logs, gt: Logs, *, pass_num: int):
158155
parser.add_argument('--pass-num', type=int, default=None, help='Which pass to analyze (default: all passes)')
159156
args = analyze.parse_args(parser, 'nogt', 'gt')
160157

161-
results = utils.foreach_bench(compute_stats, args.nogt, args.gt, pass_num=args.pass_num)
158+
results = utils.foreach_bench(
159+
compute_stats, args.nogt, args.gt,
160+
pass_num=args.pass_num,
161+
total_compile_time_seconds=compile_times.total_compile_time_seconds_f(args.benchsuite),
162+
)
162163

163164
writer = csv.DictWriter(sys.stdout,
164165
fieldnames=['Benchmark'] + list(results['Total'].keys()))

0 commit comments

Comments
 (0)