@@ -67,11 +67,15 @@ def is_improved(before: Block, after: Block):
67
67
return cost_for_blk (before ) > cost_for_blk (after )
68
68
69
69
70
- def compute_stats (nogt : Logs , gt : Logs ):
70
+ def compute_stats (nogt : Logs , gt : Logs , * , pass_num : int ):
71
71
TOTAL_BLOCKS = utils .count (nogt )
72
72
73
73
nogt_all , gt_all = nogt , gt
74
74
75
+ if pass_num is not None :
76
+ nogt = nogt .keep_blocks_if (lambda b : b .single ('PassFinished' )['num' ] == pass_num )
77
+ gt = gt .keep_blocks_if (lambda b : b .single ('PassFinished' )['num' ] == pass_num )
78
+
75
79
NUM_PROVED_OPTIMAL_WITHOUT_ENUMERATING = utils .count (utils .zipped_keep_blocks_if (
76
80
nogt , gt , pred = lambda a , b : block_stats .is_enumerated (a ) and not block_stats .is_enumerated (b ))[0 ])
77
81
nogt , gt = utils .zipped_keep_blocks_if (
@@ -129,9 +133,10 @@ def compute_stats(nogt: Logs, gt: Logs):
129
133
parser = argparse .ArgumentParser ()
130
134
parser .add_argument ('nogt' )
131
135
parser .add_argument ('gt' )
136
+ parser .add_argument ('--pass-num' , type = int , default = None , help = 'Which pass to analyze (default: all passes)' )
132
137
args = analyze .parse_args (parser , 'nogt' , 'gt' )
133
138
134
- results = utils .foreach_bench (compute_stats , args .nogt , args .gt )
139
+ results = utils .foreach_bench (compute_stats , args .nogt , args .gt , pass_num = args . pass_num )
135
140
136
141
writer = csv .DictWriter (sys .stdout ,
137
142
fieldnames = ['Benchmark' ] + list (results ['Total' ].keys ()))
0 commit comments