1212import analyze
1313from analyze import Logs
1414
15+ # Some reference point to compute occupancies against.
16+ # This would ideally be the maximum possible occupancy so that the .cost property will never be negative
17+ OCCUPANCY_REFERENCE_POINT = 10
18+
19+ SPILL_COST_WEIGHT = 0
20+
1521
1622@dataclass
1723class DagInfo :
@@ -24,10 +30,20 @@ class DagInfo:
2430 relative_cost : int
2531 length : int
2632 is_optimal : bool
33+ # Spill cost is not absolute for SCF = TARGET. By recording the baseline, we can adjust the costs.
34+ target_occupancy : Optional [int ]
35+ spill_cost : int
2736
2837 @property
2938 def cost (self ):
30- return self .lower_bound + self .relative_cost
39+ cost = self .lower_bound + self .relative_cost
40+ if self .target_occupancy is not None :
41+ # TargetOcc - SC is a "complement"-like operation, meaning that it undoes itself.
42+ actual_occupancy = self .target_occupancy - self .spill_cost
43+ absolute_spill_cost = OCCUPANCY_REFERENCE_POINT - actual_occupancy
44+ cost += SPILL_COST_WEIGHT * absolute_spill_cost
45+
46+ return cost
3147
3248
3349class MismatchKind (Enum ):
@@ -133,6 +149,8 @@ def extract_dag_info(logs: Logs) -> Dict[str, List[List[DagInfo]]]:
133149 print (block .raw_log )
134150 exit (2 )
135151
152+ target_occ = block .single ('TargetOccupancy' )['target' ] if 'TargetOccupancy' in block else None
153+
136154 dags .setdefault (block .name , []).append (DagInfo (
137155 id = block .name ,
138156 benchmark = block .benchmark ,
@@ -142,6 +160,8 @@ def extract_dag_info(logs: Logs) -> Dict[str, List[List[DagInfo]]]:
142160 relative_cost = best_result ['cost' ],
143161 length = best_result ['length' ],
144162 is_optimal = is_optimal ,
163+ spill_cost = best_result ['spill_cost' ],
164+ target_occupancy = target_occ ,
145165 ))
146166
147167 for k , block_passes in dags .items ():
@@ -338,6 +358,8 @@ def print_small_summary(mismatches: List[Mismatch]):
338358
339359 parser .add_argument ('-q' , '--quiet' , action = 'store_true' ,
340360 help = 'Only print mismatch info, and only if there are mismatches' )
361+ parser .add_argument ('--scw' , '--spill-cost-weight' , type = int , required = True ,
362+ help = 'The weight of the spill cost in the cost calculation. Only relevant if the reported spill costs are not absolute (e.g. SCF = TARGET); put any value otherwise.' , dest = 'spill_cost_weight' , metavar = 'SCW' )
341363 parser .add_argument ('--no-summarize-largest-cost-difference' , action = 'store_true' ,
342364 help = 'Do not summarize the mismatches with the biggest difference in cost' )
343365 parser .add_argument ('--no-summarize-smallest-mismatches' , action = 'store_true' ,
@@ -358,6 +380,7 @@ def print_small_summary(mismatches: List[Mismatch]):
358380 NUM_SMALLEST_BLOCKS_PRINT = args .num_smallest_mismatches_print
359381 MISSING_LOWER_BOUND_DUMP_COUNT = args .missing_lb_dump_count
360382 MISSING_LOWER_BOUND_DUMP_LINES = args .missing_lb_dump_lines
383+ SPILL_COST_WEIGHT = args .spill_cost_weight
361384
362385 main (
363386 args .first , args .second ,
0 commit comments