88
99import math
1010import sys
11+ from dataclasses import dataclass
1112from enum import Enum
1213from typing import Any , Dict , IO , List , Mapping , Optional , Tuple , TypeAlias , Union
1314
@@ -72,6 +73,17 @@ class TimeScale(Enum):
7273}
7374
7475
76+ FROM_AOT = 1
77+ FROM_RUNTIME = 2
78+
79+
80+ @dataclass
81+ class node_data :
82+ source : int
83+ debug_handle : tuple [int ]
84+ output : Any
85+
86+
7587def calculate_time_scale_factor (
7688 source_time_scale : TimeScale , target_time_scale : TimeScale
7789) -> float :
@@ -489,7 +501,7 @@ def merge_overlapping_debug_handles(intermediate_outputs: Dict[Tuple[int, ...],
489501 """
490502 Merge overlapping debug handles int a single key
491503 """
492- if not intermediate_outputs :
504+ if len ( intermediate_outputs ) == 0 :
493505 return
494506 # Extract and normalize into (start, end, val)
495507 intervals = [(min (key ), max (key ), val ) for key , val in intermediate_outputs .items ()]
@@ -512,3 +524,166 @@ def merge_overlapping_debug_handles(intermediate_outputs: Dict[Tuple[int, ...],
512524 intermediate_outputs .clear ()
513525 for start , end , val in merged_intermediate_outputs :
514526 intermediate_outputs [tuple (range (start , end + 1 ))] = val
527+
528+
529+ def _debug_handles_have_overlap (
530+ aot_debug_hanlde : Tuple [int , ...], runtime_debug_handle : Tuple [int , ...]
531+ ) -> bool :
532+ """
533+ Check if the AOT debug handle and the runtime debug handle have any overlap.
534+ """
535+ aot_set = set (aot_debug_hanlde )
536+ runtime_set = set (runtime_debug_handle )
537+ return len (aot_set .intersection (runtime_set )) > 0
538+
539+
540+ def _combine_debug_hanldes (debug_handles : List [Tuple [int , ...]]) -> Tuple [int , ...]:
541+ """Combine multiple debug handles into one debug handle"""
542+ combined_debug_handles_set = set ()
543+ for debug_handle in debug_handles :
544+ combined_debug_handles_set .update (set (debug_handle ))
545+ return tuple (sorted (combined_debug_handles_set ))
546+
547+
548+ def _combine_overlapped_intermediate_outputs (
549+ nodes : List [Tuple [Tuple [int , ...], Any ]]
550+ ) -> Tuple [Tuple [int , ...], Any ]:
551+ """Combine multiple overlapped intermediate outputs into one with combined debug_handles and last output"""
552+ debug_handles = [debug_handle for debug_handle , _ in nodes ]
553+ outputs = [output for _ , output in nodes ]
554+ combined_debug_handle = _combine_debug_hanldes (debug_handles )
555+ output = outputs [- 1 ] # Pick the last one
556+ return combined_debug_handle , output
557+
558+
559+ def _create_debug_handle_overlap_graph (
560+ aot_intermediate_outputs : Dict [Tuple [int , ...], Any ],
561+ runtime_intermediate_outputs : Dict [Tuple [int , ...], Any ],
562+ ) -> Tuple [List [node_data ], Dict [int , List [int ]]]:
563+ """
564+ Create a graph representing overlapping debug handles between AOT and runtime outputs.
565+
566+ Each node in the graph is an instance of NodeData, which contains:
567+ - source: A string indicating the origin of the node (either FROM_AOT or FROM_RUNTIME).
568+ - debug_handle: A tuple representing the unique identifier for the output.
569+ - output: The actual output data associated with the debug handle.
570+
571+ Edges in the graph are represented as a dictionary where:
572+ - The key is the index of a node in the nodes list.
573+ - The value is a list of indices of nodes that have overlapping debug handles with the key node.
574+
575+ Returns:
576+ - A tuple containing:
577+ - A list of NodeData instances representing the nodes in the graph.
578+ - A dictionary representing the edges, where each key-value pair indicates connected nodes due to overlapping debug handles.
579+ """
580+ nodes = []
581+ for debug_handle , output in aot_intermediate_outputs .items ():
582+ nodes .append (node_data (FROM_AOT , debug_handle , output ))
583+ for debug_handle , output in runtime_intermediate_outputs .items ():
584+ nodes .append (node_data (FROM_RUNTIME , debug_handle , output ))
585+
586+ edges = {i : [] for i in range (len (nodes ))}
587+ for i in range (len (nodes )):
588+ for j in range (i + 1 , len (nodes )):
589+ node_i = nodes [i ]
590+ node_j = nodes [j ]
591+ # Only connect nodes from different sources(aot vs runtime) that overlap
592+ if node_i .source != node_j .source and _debug_handles_have_overlap (
593+ node_i .debug_handle , node_j .debug_handle
594+ ):
595+ edges [i ].append (j )
596+ edges [j ].append (i )
597+ return (nodes , edges )
598+
599+
600+ def _find_connected_components (
601+ nodes : List [node_data ], edges : Dict [int , List [int ]]
602+ ) -> List [List [int ]]:
603+ """
604+ Find groups of connected nodes in a graph using DFS.
605+ Parameters:
606+ - nodes: A list of nodes in the graph.
607+ - edges: A dictionary where each key is a node index, and the value is a list
608+ of indices of connected nodes.
609+ Returns:
610+ - A list of connected components, each represented as a list of node indices.
611+ """
612+ visited = [False ] * len (nodes )
613+ connected_components = []
614+
615+ def dfs (node_id , component ):
616+ visited [node_id ] = True
617+ component .append (node_id )
618+ # Iterate over all neighbors of the current node
619+ for neighbor_node_id in edges [node_id ]:
620+ # If a neighbor has not been visited yet, recursively visit it
621+ if not visited [neighbor_node_id ]:
622+ dfs (neighbor_node_id , component )
623+
624+ # Perform DFS on all nodes to find connected components
625+ for i in range (len (nodes )):
626+ # If a node has not been visited yet, start a new DFS from it
627+ if not visited [i ]:
628+ component = []
629+ dfs (i , component )
630+ # After visiting all reachable nodes, add the current component to the list
631+ connected_components .append (component )
632+ return connected_components
633+
634+
635+ def map_runtime_aot_intermediate_outputs (
636+ aot_intermediate_outputs : Dict [Tuple [int , ...], Any ],
637+ runtime_intermediate_outputs : Dict [Tuple [int , ...], Any ],
638+ ) -> Dict [Tuple [Tuple [int , ...], Any ], Tuple [Tuple [int , ...], Any ]]:
639+ """
640+ Map the runtime intermediate outputs to the AOT intermediate outputs
641+ by finding overlapping debug handles and combining them into a single debug_handle
642+
643+ Returns:
644+ Dict[Tuple[Tuple[int, ...], Any], Tuple[Tuple[int, ...], Any]] - Mapping
645+ from runtime intermediate output to AOT intermediate output
646+ """
647+ # Merge overlapping debug handles
648+ merge_overlapping_debug_handles (aot_intermediate_outputs )
649+ merge_overlapping_debug_handles (runtime_intermediate_outputs )
650+
651+ # Create a graph(nodes and edges) of overlapping(between aot and runtime) debug handles
652+ nodes , edges = _create_debug_handle_overlap_graph (
653+ aot_intermediate_outputs , runtime_intermediate_outputs
654+ )
655+ # Find connected(between aot and runtime) components
656+ connected_components = _find_connected_components (nodes , edges )
657+
658+ aot_runtime_mapping = {}
659+ for comp in connected_components :
660+ # Separate nodes into AOT and runtime lists based on their source,
661+ # each list is combined into a single element and mapped to each other.
662+ aot_list = [
663+ (nodes [node_id ].debug_handle , nodes [node_id ].output )
664+ for node_id in comp
665+ if nodes [node_id ].source == FROM_AOT
666+ ]
667+ runtime_list = [
668+ (nodes [node_id ].debug_handle , nodes [node_id ].output )
669+ for node_id in comp
670+ if nodes [node_id ].source == FROM_RUNTIME
671+ ]
672+
673+ # Map only if both AOT and runtime data are present.
674+ if len (aot_list ) != 0 and len (runtime_list ) != 0 :
675+ # Combine aot debug handles into a single key
676+ aot_combined_debug_handle , aot_output = (
677+ _combine_overlapped_intermediate_outputs (aot_list )
678+ )
679+ # Combine runtime debug handles into a single key
680+ runtime_combined_debug_handle , runtime_output = (
681+ _combine_overlapped_intermediate_outputs (runtime_list )
682+ )
683+ # Create a mapping between runtime and aot
684+ aot_runtime_mapping [(aot_combined_debug_handle , aot_output )] = (
685+ runtime_combined_debug_handle ,
686+ runtime_output ,
687+ )
688+
689+ return aot_runtime_mapping
0 commit comments