@@ -73,6 +73,8 @@ class TimeScale(Enum):
7373 TimeScale .CYCLES : 1 ,
7474}
7575
76+ DebugHandle : TypeAlias = Tuple [int , ...]
77+
7678
7779class NodeSource (Enum ):
7880 AOT = 1
@@ -528,7 +530,7 @@ def compare_results(
528530 return results
529531
530532
531- def merge_overlapping_debug_handles (intermediate_outputs : Dict [Tuple [ int , ...] , Any ]):
533+ def merge_overlapping_debug_handles (intermediate_outputs : Dict [DebugHandle , Any ]):
532534 """
533535 Merge overlapping debug handles int a single key
534536 """
@@ -558,7 +560,7 @@ def merge_overlapping_debug_handles(intermediate_outputs: Dict[Tuple[int, ...],
558560
559561
560562def _debug_handles_have_overlap (
561- aot_debug_hanlde : Tuple [ int , ...], runtime_debug_handle : Tuple [ int , ...]
563+ aot_debug_hanlde : DebugHandle , runtime_debug_handle : DebugHandle
562564) -> bool :
563565 """
564566 Check if the AOT debug handle and the runtime debug handle have any overlap.
@@ -568,7 +570,7 @@ def _debug_handles_have_overlap(
568570 return len (aot_set .intersection (runtime_set )) > 0
569571
570572
571- def _combine_debug_hanldes (debug_handles : List [Tuple [ int , ...]] ) -> Tuple [ int , ...] :
573+ def _combine_debug_hanldes (debug_handles : List [DebugHandle ] ) -> DebugHandle :
572574 """Combine multiple debug handles into one debug handle"""
573575 combined_debug_handles_set = set ()
574576 for debug_handle in debug_handles :
@@ -577,8 +579,8 @@ def _combine_debug_hanldes(debug_handles: List[Tuple[int, ...]]) -> Tuple[int, .
577579
578580
579581def _combine_overlapped_intermediate_outputs (
580- nodes : List [Tuple [Tuple [ int , ...] , Any ]]
581- ) -> Tuple [Tuple [ int , ...] , Any ]:
582+ nodes : List [Tuple [DebugHandle , Any ]]
583+ ) -> Tuple [DebugHandle , Any ]:
582584 """Combine multiple overlapped intermediate outputs into one with combined debug_handles and last output"""
583585 debug_handles = [debug_handle for debug_handle , _ in nodes ]
584586 outputs = [output for _ , output in nodes ]
@@ -588,8 +590,8 @@ def _combine_overlapped_intermediate_outputs(
588590
589591
590592def _create_debug_handle_overlap_graph (
591- aot_intermediate_outputs : Dict [Tuple [ int , ...] , Any ],
592- runtime_intermediate_outputs : Dict [Tuple [ int , ...] , Any ],
593+ aot_intermediate_outputs : Dict [DebugHandle , Any ],
594+ runtime_intermediate_outputs : Dict [DebugHandle , Any ],
593595) -> Tuple [List [NodeData ], Dict [int , List [int ]]]:
594596 """
595597 Create a graph representing overlapping debug handles between AOT and runtime outputs.
@@ -659,15 +661,15 @@ def dfs(node_id, component):
659661
660662
661663def map_runtime_aot_intermediate_outputs (
662- aot_intermediate_outputs : Dict [Tuple [ int , ...] , Any ],
663- runtime_intermediate_outputs : Dict [Tuple [ int , ...] , Any ],
664- ) -> Dict [Tuple [Tuple [ int , ...], Any ], Tuple [Tuple [ int , ...] , Any ]]:
664+ aot_intermediate_outputs : Dict [DebugHandle , Any ],
665+ runtime_intermediate_outputs : Dict [DebugHandle , Any ],
666+ ) -> Dict [Tuple [DebugHandle , Any ], Tuple [DebugHandle , Any ]]:
665667 """
666668 Map the runtime intermediate outputs to the AOT intermediate outputs
667669 by finding overlapping debug handles and combining them into a single debug_handle
668670
669671 Returns:
670- Dict[Tuple[Tuple[int, ...], Any], Tuple[Tuple[int, ...] , Any]] - Mapping
672+ Dict[Tuple[DebugHandle, Any], Tuple[DebugHandle , Any]] - Mapping
671673 from runtime intermediate output to AOT intermediate output
672674 """
673675 # Merge overlapping debug handles
@@ -760,13 +762,13 @@ def convert_to_float_tensor(input_data: Any) -> torch.Tensor:
760762
761763def get_aot_debug_handle_to_op_name_mapping (
762764 graph_module : torch .fx .GraphModule ,
763- ) -> Dict [Tuple [ int , ...] , str ]:
765+ ) -> Dict [DebugHandle , str ]:
764766 """
765767 Get a mapping from debug handle to operator name from the ETRecord edge_dialect_program's graph module.
766768 Parameters:
767769 graph_module (torch.fx.GraphModule): The graph module to get the mapping from.
768770 Returns:
769- Dict[Tuple[int, ...] , str]: A dictionary mapping debug handles to operator names.
771+ Dict[DebugHandle , str]: A dictionary mapping debug handles to operator names.
770772 """
771773 node_filters = [
772774 NodeFilter ("debug_handle" , "call_function" , exclude_ops = ["getitem" ])
@@ -787,8 +789,8 @@ def get_aot_debug_handle_to_op_name_mapping(
787789
788790
789791def find_op_names (
790- target_debug_handle : Tuple [ int , ...] ,
791- debug_handle_to_op_name : Dict [Tuple [ int , ...] , str ],
792+ target_debug_handle : DebugHandle ,
793+ debug_handle_to_op_name : Dict [DebugHandle , str ],
792794) -> List [str ]:
793795 """
794796 Record the operator names only if their debug handles are part of the target debug handle.
0 commit comments