88
99import math
1010import sys
11+ from dataclasses import dataclass
1112from enum import Enum
1213from typing import Any , Dict , IO , List , Mapping , Optional , Tuple , TypeAlias , Union
1314
@@ -72,6 +73,25 @@ class TimeScale(Enum):
7273}
7374
7475
class NodeSource(Enum):
    """Identifies which set of intermediate outputs a graph node was built from."""

    # Node was created from an entry of the AOT intermediate outputs.
    AOT = 1
    # Node was created from an entry of the runtime intermediate outputs.
    RUNTIME = 2
79+
80+
@dataclass
class NodeData:
    """
    Each node in the graph is an instance of NodeData, which contains:
    - source: A NodeSource member indicating the origin of the node
      (either NodeSource.AOT or NodeSource.RUNTIME).
    - debug_handle: A tuple of ints representing the unique identifier for the output.
    - output: The actual output data associated with the debug handle.
    """

    source: NodeSource
    # Variable-length tuple of ints, matching the debug-handle keys used by
    # the intermediate-output dictionaries in this module.
    debug_handle: Tuple[int, ...]
    output: Any
93+
94+
7595def calculate_time_scale_factor (
7696 source_time_scale : TimeScale , target_time_scale : TimeScale
7797) -> float :
@@ -489,7 +509,7 @@ def merge_overlapping_debug_handles(intermediate_outputs: Dict[Tuple[int, ...],
489509 """
490510 Merge overlapping debug handles int a single key
491511 """
492- if not intermediate_outputs :
512+ if len ( intermediate_outputs ) == 0 :
493513 return
494514 # Extract and normalize into (start, end, val)
495515 intervals = [(min (key ), max (key ), val ) for key , val in intermediate_outputs .items ()]
@@ -512,3 +532,161 @@ def merge_overlapping_debug_handles(intermediate_outputs: Dict[Tuple[int, ...],
512532 intermediate_outputs .clear ()
513533 for start , end , val in merged_intermediate_outputs :
514534 intermediate_outputs [tuple (range (start , end + 1 ))] = val
535+
536+
537+ def _debug_handles_have_overlap (
538+ aot_debug_hanlde : Tuple [int , ...], runtime_debug_handle : Tuple [int , ...]
539+ ) -> bool :
540+ """
541+ Check if the AOT debug handle and the runtime debug handle have any overlap.
542+ """
543+ aot_set = set (aot_debug_hanlde )
544+ runtime_set = set (runtime_debug_handle )
545+ return len (aot_set .intersection (runtime_set )) > 0
546+
547+
548+ def _combine_debug_hanldes (debug_handles : List [Tuple [int , ...]]) -> Tuple [int , ...]:
549+ """Combine multiple debug handles into one debug handle"""
550+ combined_debug_handles_set = set ()
551+ for debug_handle in debug_handles :
552+ combined_debug_handles_set .update (set (debug_handle ))
553+ return tuple (sorted (combined_debug_handles_set ))
554+
555+
def _combine_overlapped_intermediate_outputs(
    nodes: List[Tuple[Tuple[int, ...], Any]]
) -> Tuple[Tuple[int, ...], Any]:
    """
    Collapse several overlapped (debug_handle, output) pairs into a single pair.

    The returned debug handle is the combination of every input debug handle;
    the returned output is the output of the last pair in ``nodes``.
    """
    merged_handle = _combine_debug_hanldes([handle for handle, _ in nodes])
    # Only the final entry's output is kept.
    _, last_output = nodes[-1]
    return merged_handle, last_output
565+
566+
def _create_debug_handle_overlap_graph(
    aot_intermediate_outputs: Dict[Tuple[int, ...], Any],
    runtime_intermediate_outputs: Dict[Tuple[int, ...], Any],
) -> Tuple[List[NodeData], Dict[int, List[int]]]:
    """
    Build a graph linking AOT and runtime outputs whose debug handles overlap.

    The node list holds one NodeData per AOT entry (in dict iteration order)
    followed by one NodeData per runtime entry (in dict iteration order).

    Edges are stored as a dictionary where:
    - The key is the index of a node in the node list.
    - The value is the list of indices of nodes whose debug handles overlap
      with the key node. Edges only ever join an AOT node to a runtime node.

    Returns:
    - A tuple containing:
        - The list of NodeData instances that make up the graph.
        - The edge dictionary; every node index is present as a key, with an
          empty list when the node overlaps with nothing.
    """
    aot_nodes = [
        NodeData(NodeSource.AOT, handle, value)
        for handle, value in aot_intermediate_outputs.items()
    ]
    runtime_nodes = [
        NodeData(NodeSource.RUNTIME, handle, value)
        for handle, value in runtime_intermediate_outputs.items()
    ]
    nodes = aot_nodes + runtime_nodes

    edges = {index: [] for index in range(len(nodes))}
    # Runtime nodes sit after all AOT nodes, so their indices start here.
    first_runtime_index = len(aot_nodes)
    for aot_index, aot_node in enumerate(aot_nodes):
        for runtime_index, runtime_node in enumerate(
            runtime_nodes, start=first_runtime_index
        ):
            # Only AOT/runtime pairs are candidates; link them when they overlap.
            if _debug_handles_have_overlap(
                aot_node.debug_handle, runtime_node.debug_handle
            ):
                edges[aot_index].append(runtime_index)
                edges[runtime_index].append(aot_index)
    return (nodes, edges)
601+
602+
603+ def _find_connected_components (
604+ nodes : List [NodeData ], edges : Dict [int , List [int ]]
605+ ) -> List [List [int ]]:
606+ """
607+ Find groups of connected nodes in a graph using DFS.
608+ Parameters:
609+ - nodes: A list of nodes in the graph.
610+ - edges: A dictionary where each key is a node index, and the value is a list
611+ of indices of connected nodes.
612+ Returns:
613+ - A list of connected components, each represented as a list of node indices.
614+ """
615+ visited = [False ] * len (nodes )
616+ connected_components = []
617+
618+ def dfs (node_id , component ):
619+ visited [node_id ] = True
620+ component .append (node_id )
621+ # Iterate over all neighbors of the current node
622+ for neighbor_node_id in edges [node_id ]:
623+ # If a neighbor has not been visited yet, recursively visit it
624+ if not visited [neighbor_node_id ]:
625+ dfs (neighbor_node_id , component )
626+
627+ # Perform DFS on all nodes to find connected components
628+ for i in range (len (nodes )):
629+ # If a node has not been visited yet, start a new DFS from it
630+ if not visited [i ]:
631+ component = []
632+ dfs (i , component )
633+ # After visiting all reachable nodes, add the current component to the list
634+ connected_components .append (component )
635+ return connected_components
636+
637+
def map_runtime_aot_intermediate_outputs(
    aot_intermediate_outputs: Dict[Tuple[int, ...], Any],
    runtime_intermediate_outputs: Dict[Tuple[int, ...], Any],
) -> Dict[Tuple[Tuple[int, ...], Any], Tuple[Tuple[int, ...], Any]]:
    """
    Map the AOT intermediate outputs to the runtime intermediate outputs
    by finding overlapping debug handles and combining them into a single debug_handle

    Both input dictionaries are first passed to merge_overlapping_debug_handles,
    which rewrites them in place.

    Returns:
        Dict[Tuple[Tuple[int, ...], Any], Tuple[Tuple[int, ...], Any]] - Mapping
        from AOT intermediate output to runtime intermediate output. Each key is
        (combined AOT debug handle, AOT output) and each value is
        (combined runtime debug handle, runtime output). Groups that do not
        contain both an AOT and a runtime entry are left out.
    """
    # Merge overlapping debug handles
    merge_overlapping_debug_handles(aot_intermediate_outputs)
    merge_overlapping_debug_handles(runtime_intermediate_outputs)

    # Create a graph(nodes and edges) of overlapping(between aot and runtime) debug handles
    nodes, edges = _create_debug_handle_overlap_graph(
        aot_intermediate_outputs, runtime_intermediate_outputs
    )
    # Find connected(between aot and runtime) components
    connected_components = _find_connected_components(nodes, edges)

    aot_runtime_mapping = {}
    for comp in connected_components:
        # Separate nodes into AOT and runtime lists based on their source,
        # each list is combined into a single element and mapped to each other.
        aot_list = []
        runtime_list = []
        for node_id in comp:
            node = nodes[node_id]
            entry = (node.debug_handle, node.output)
            if node.source == NodeSource.AOT:
                aot_list.append(entry)
            elif node.source == NodeSource.RUNTIME:
                runtime_list.append(entry)

        # Map only if both AOT and runtime data are present.
        if not aot_list or not runtime_list:
            continue

        # Combine aot debug handles into a single key
        aot_combined_debug_handle, aot_output = (
            _combine_overlapped_intermediate_outputs(aot_list)
        )
        # Combine runtime debug handles into a single key
        runtime_combined_debug_handle, runtime_output = (
            _combine_overlapped_intermediate_outputs(runtime_list)
        )
        # Create a mapping from aot to runtime.
        # NOTE: the AOT output is part of the dictionary key, so it must be hashable.
        aot_runtime_mapping[(aot_combined_debug_handle, aot_output)] = (
            runtime_combined_debug_handle,
            runtime_output,
        )

    return aot_runtime_mapping
0 commit comments