1+ import plotly .graph_objects as go
2+ import plotly .express as px
3+ from plotly .subplots import make_subplots
4+ import networkx as nx
5+ from typing import Optional , Dict , Any
6+ from dataclasses import dataclass
7+ from PathPlanning .TimeBasedPathPlanning .ConstraintTree import AgentId , AppliedConstraint , ConstraintTree , ConstraintTreeNode , ForkingConstraint
8+
9+ def visualize_cbs_tree (
10+ expanded_nodes : Dict [int , ConstraintTreeNode ],
11+ nodes_to_expand : list [ConstraintTreeNode ],
12+ initial_size : int = 15
13+ ) -> None :
14+ """
15+ Visualize the CBS constraint tree with interactive nodes.
16+ Click a node to print its details to console.
17+ """
18+
19+ # Build networkx graph
20+ G = nx .DiGraph ()
21+
22+ # Add all nodes with metadata
23+ node_colors = []
24+ node_sizes = []
25+ node_labels = {}
26+
27+ for idx , node in expanded_nodes .items ():
28+ G .add_node (idx )
29+ node_labels [idx ] = f"<b>Node { idx } </b><br>Cost: { node .cost } <br>Parent: { node .parent_idx } <br>Constraint:<br>{ node .constraint } "
30+ node_colors .append ("lightblue" )
31+ node_sizes .append (initial_size )
32+
33+ # Add edge from parent
34+ # if node.parent_idx is not None:
35+ if node .parent_idx is not None and node .parent_idx in expanded_nodes :
36+ G .add_edge (node .parent_idx , idx )
37+ # G.add_edge(0, 5)
38+ print (f"adding edge btwn { node .parent_idx } and { idx } " )
39+
40+ # Add unexpanded nodes
41+ # unexpanded_node_map = {}
42+ # for node in nodes_to_expand:
43+ # idx = id(node) # Use object id for heap nodes
44+ # if idx not in G.nodes():
45+ # G.add_node(idx)
46+ # # node_labels[idx] = f"Node {idx}\n(cost: {node.cost})"
47+ # node_labels[idx] = f"<b>Node {idx}</b><br>Cost: {node.cost}<br>Constraint:<br>{node.constraint}"
48+ # node_colors.append("lightyellow")
49+ # node_sizes.append(initial_size)
50+ # unexpanded_node_map[idx] = node
51+
52+ # if node.parent_idx is not None and node.parent_idx >= 0:
53+ # G.add_edge(node.parent_idx, idx)
54+
55+ # Use hierarchical layout with fixed horizontal spacing
56+ pos = _hierarchy_pos (G , root = next (iter (G .nodes ()), None ), vert_gap = 0.3 , horiz_gap = 1.5 )
57+
58+ # Extract edge coordinates
59+ edge_x = []
60+ edge_y = []
61+ for edge in G .edges ():
62+ print (f"Drawing edge: { edge } " )
63+ if edge [0 ] in pos and edge [1 ] in pos :
64+ x0 , y0 = pos [edge [0 ]]
65+ x1 , y1 = pos [edge [1 ]]
66+ edge_x .extend ([x0 , x1 , None ])
67+ edge_y .extend ([y0 , y1 , None ])
68+ else :
69+ edge_x .extend ([1 , 1 , None ])
70+ edge_y .extend ([5 , 5 , None ])
71+
72+ # Extract node coordinates
73+ node_x = []
74+ node_y = []
75+ for node in G .nodes ():
76+ x , y = 1 , 1
77+ if node in pos :
78+ x , y = pos [node ]
79+ node_x .append (x )
80+ node_y .append (y )
81+
82+ # Create figure
83+ fig = go .Figure ()
84+
85+ # Add edges
86+ fig .add_trace (go .Scatter (
87+ x = edge_x , y = edge_y ,
88+ mode = 'lines' ,
89+ line = dict (width = 2 , color = '#888' ),
90+ hoverinfo = 'none' ,
91+ showlegend = False
92+ ))
93+ # Add nodes
94+ fig .add_trace (go .Scatter (
95+ x = node_x , y = node_y ,
96+ mode = 'markers' ,
97+ marker = dict (
98+ size = node_sizes ,
99+ color = node_colors ,
100+ line = dict (width = 2 , color = 'darkblue' )
101+ ),
102+ text = [node_labels [node ] for node in G .nodes () if node in node_labels ],
103+ hoverinfo = 'text' ,
104+ showlegend = False ,
105+ customdata = list (G .nodes ())
106+ ))
107+
108+ fig .update_layout (
109+ title = "CBS Constraint Tree" ,
110+ showlegend = False ,
111+ hovermode = 'closest' ,
112+ margin = dict (b = 20 , l = 5 , r = 5 , t = 40 ),
113+ xaxis = dict (
114+ showgrid = False ,
115+ zeroline = False ,
116+ showticklabels = False ,
117+ scaleanchor = "y" ,
118+ scaleratio = 1
119+ ),
120+ yaxis = dict (
121+ showgrid = False ,
122+ zeroline = False ,
123+ showticklabels = False ,
124+ scaleanchor = "x" ,
125+ scaleratio = 1
126+ ),
127+ plot_bgcolor = 'white' ,
128+ autosize = True ,
129+ )
130+
131+ # Add click event
132+ fig .update_traces (
133+ selector = dict (mode = 'markers' ),
134+ customdata = list (G .nodes ()),
135+ hovertemplate = '%{text}<extra></extra>'
136+ )
137+
138+ fig .update_xaxes (fixedrange = False )
139+ fig .update_yaxes (fixedrange = False )
140+
141+ # Show and set up click handler
142+ fig .show ()
143+
144+ # Print handler instructions
145+ print ("\n CBS Tree Visualization" )
146+ print ("=" * 50 )
147+ print ("Hover over nodes to see cost" )
148+ print ("Right-click → 'Inspect' → Open browser console" )
149+ print ("Then paste this to get node info:\n " )
150+ print ("for (let node of document.querySelectorAll('circle')) {" )
151+ print (" node.onclick = (e) => {" )
152+ print (" console.log('Clicked node:', e.target);" )
153+ print (" }" )
154+ print ("}\n " )
155+ print ("Or use the alternative: Print all nodes programmatically:\n " )
156+
157+ def _hierarchy_pos (G , root = None , vert_gap = 0.2 , horiz_gap = 1.0 , xcenter = 0.5 ):
158+ """
159+ Create hierarchical layout for tree visualization with fixed horizontal spacing.
160+ """
161+ if not nx .is_tree (G ):
162+ G = nx .DiGraph (G )
163+
164+ def _hierarchy_pos_recursive (G , root , vert_gap = 0.2 , horiz_gap = 1.0 , xcenter = 0.5 , pos = None , parent = None , child_index = 0 ):
165+ if pos is None :
166+ pos = {root : (xcenter , 0 )}
167+ else :
168+ pos [root ] = (xcenter , pos [parent ][1 ] - vert_gap )
169+
170+ neighbors = list (G .neighbors (root ))
171+
172+ if len (neighbors ) != 0 :
173+ num_children = len (neighbors )
174+ # Spread children horizontally with fixed gap
175+ start_x = xcenter - (num_children - 1 ) * horiz_gap / 2
176+ for i , neighbor in enumerate (neighbors ):
177+ nextx = start_x + i * horiz_gap
178+ _hierarchy_pos_recursive (G , neighbor , vert_gap = vert_gap , horiz_gap = horiz_gap ,
179+ xcenter = nextx , pos = pos , parent = root , child_index = i )
180+
181+ return pos
182+
183+ return _hierarchy_pos_recursive (G , root , vert_gap , horiz_gap , xcenter )
184+
185+
186+ # Example usage:
187+ if __name__ == "__main__" :
188+ from dataclasses import dataclass
189+ from typing import Optional
190+
191+ @dataclass
192+ class MockConstraint :
193+ agent : int
194+ time : int
195+ location : tuple
196+
197+ def __repr__ (self ):
198+ return f"Constraint(agent={ self .agent } , t={ self .time } , loc={ self .location } )"
199+
200+ @dataclass
201+ class MockNode :
202+ parent_idx : Optional [int ]
203+ constraint : Optional [MockConstraint ]
204+ paths : dict
205+ cost : int
206+
207+ # Create mock tree
208+ nodes = {
209+ 0 : MockNode (None , None , {"a" : [], "b" : []}, 10 ),
210+ 1 : MockNode (0 , MockConstraint (0 , 2 , (0 , 0 )), {"a" : [(0 ,0 ), (1 ,0 )], "b" : [(0 ,1 ), (0 ,2 )]}, 12 ),
211+ 2 : MockNode (0 , MockConstraint (1 , 1 , (0 ,1 )), {"a" : [(0 ,0 ), (1 ,0 )], "b" : [(0 ,1 ), (0 ,2 )]}, 11 ),
212+ 3 : MockNode (1 , MockConstraint (0 , 3 , (1 ,0 )), {"a" : [(0 ,0 ), (1 ,0 ), (1 ,1 )], "b" : [(0 ,1 ), (0 ,2 )]}, 14 ),
213+ 4 : MockNode (2 , None , {"a" : [(0 ,0 ), (1 ,0 )], "b" : [(0 ,1 ), (1 ,1 )]}, 12 ),
214+ }
215+
216+ visualize_cbs_tree (nodes , [])
0 commit comments