Skip to content

Commit 81fe20a

Browse files
committed
Support pass filtering when recording compile time
1 parent d8d99d1 commit 81fe20a

File tree

2 files changed

+10
-5
lines changed

2 files changed

+10
-5
lines changed

util/analyze/utils.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ def sum_dicts(ds):
1111
return {k: sum(d[k] for d in ds) for k in ds[0].keys()}
1212

1313

14-
def foreach_bench(analysis_f, *logs, combine=None):
14+
def foreach_bench(analysis_f, *logs, combine=None, **kwargs):
1515
'''
1616
Repeats `analysis_f` for each benchmark in `logs`.
1717
Also computes the analysis for the entire thing.
@@ -25,11 +25,11 @@ def foreach_bench(analysis_f, *logs, combine=None):
2525
'''
2626

2727
if combine is None:
28-
combine = lambda *args: analysis_f(*logs)
28+
combine = lambda *args: analysis_f(*logs, **kwargs)
2929

3030
benchmarks = zip(*[log.benchmarks for log in logs])
3131

32-
bench_stats = {bench[0].name: analysis_f(*bench) for bench in benchmarks}
32+
bench_stats = {bench[0].name: analysis_f(*bench, **kwargs) for bench in benchmarks}
3333
total = combine(bench_stats.values())
3434

3535
return {

util/gt_analysis/gt_cmp.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,11 +67,15 @@ def is_improved(before: Block, after: Block):
6767
return cost_for_blk(before) > cost_for_blk(after)
6868

6969

70-
def compute_stats(nogt: Logs, gt: Logs):
70+
def compute_stats(nogt: Logs, gt: Logs, *, pass_num: int):
7171
TOTAL_BLOCKS = utils.count(nogt)
7272

7373
nogt_all, gt_all = nogt, gt
7474

75+
if pass_num is not None:
76+
nogt = nogt.keep_blocks_if(lambda b: b.single('PassFinished')['num'] == pass_num)
77+
gt = gt.keep_blocks_if(lambda b: b.single('PassFinished')['num'] == pass_num)
78+
7579
NUM_PROVED_OPTIMAL_WITHOUT_ENUMERATING = utils.count(utils.zipped_keep_blocks_if(
7680
nogt, gt, pred=lambda a, b: block_stats.is_enumerated(a) and not block_stats.is_enumerated(b))[0])
7781
nogt, gt = utils.zipped_keep_blocks_if(
@@ -129,9 +133,10 @@ def compute_stats(nogt: Logs, gt: Logs):
129133
parser = argparse.ArgumentParser()
130134
parser.add_argument('nogt')
131135
parser.add_argument('gt')
136+
parser.add_argument('--pass-num', type=int, default=None, help='Which pass to analyze (default: all passes)')
132137
args = analyze.parse_args(parser, 'nogt', 'gt')
133138

134-
results = utils.foreach_bench(compute_stats, args.nogt, args.gt)
139+
results = utils.foreach_bench(compute_stats, args.nogt, args.gt, pass_num=args.pass_num)
135140

136141
writer = csv.DictWriter(sys.stdout,
137142
fieldnames=['Benchmark'] + list(results['Total'].keys()))

0 commit comments

Comments
 (0)