@@ -612,16 +612,18 @@ def get_display_name(idx: int):
612
612
"""
613
613
if hasattr (trace , "idx2loop_id" ):
614
614
# FIXME: only keep me after it is stable. Just for compatibility.
615
- return f"L{ trace .idx2loop_id [idx ]} "
615
+ return f"L{ trace .idx2loop_id [idx ]} ( { idx } ) "
616
616
return f"L{ idx } "
617
617
618
618
# Add nodes and edges
619
619
edges = []
620
+ parents_record = {}
620
621
for i , parents in enumerate (trace .dag_parent ):
621
622
for parent in parents :
622
623
edges .append ((get_display_name (parent ), get_display_name (i )))
623
624
if len (parents ) == 0 :
624
625
G .add_node (get_display_name (i ))
626
+ parents_record [get_display_name (i )] = [get_display_name (parent ) for parent in parents ]
625
627
G .add_edges_from (edges )
626
628
627
629
# Check if G is a path (a single line)
@@ -658,27 +660,21 @@ def get_display_name(idx: int):
658
660
pos = {}
659
661
660
662
def parent_avg_pos (node ):
661
- id = int (node [1 :])
662
- parents = trace .dag_parent [id ]
663
-
664
- if not parents :
665
- return 0
666
-
667
- parent_nodes = [f"L{ p } " for p in parents ]
663
+ parent_nodes = parents_record .get (node , [])
668
664
parent_xs = [pos [p ][0 ] for p in parent_nodes if p in pos ]
669
665
return sum (parent_xs ) / len (parent_xs ) if parent_xs else 0
670
666
671
667
for lvl in sorted (layer_nodes ):
672
668
nodes = layer_nodes [lvl ]
673
669
# For root nodes, sort directly by index
674
- if lvl == 0 :
675
- sorted_nodes = sorted (nodes , key = lambda n : int (n [1 :]))
670
+ if lvl == min ( layer_nodes ) :
671
+ sorted_nodes = sorted (nodes , key = lambda n : int (n [1 :]. split ( " " )[ 0 ] ))
676
672
else :
677
673
# Sort by average parent x, so children are below their parents
678
674
sorted_nodes = sorted (nodes , key = parent_avg_pos )
679
675
y = - lvl # y decreases as level increases (children below parents)
680
676
for i , node in enumerate (sorted_nodes ):
681
- if lvl == 0 :
677
+ if lvl == min ( layer_nodes ) :
682
678
x = i
683
679
else :
684
680
# Place child directly below average parent x, offset if multiple at same y
0 commit comments