Skip to content

Commit 6c58b5f

Browse files
committed
constraint tree visualizer
1 parent 2a1abb2 commit 6c58b5f

File tree

2 files changed

+231
-10
lines changed

2 files changed

+231
-10
lines changed

PathPlanning/TimeBasedPathPlanning/ConflictBasedSearch.py

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,14 @@
2121
from PathPlanning.TimeBasedPathPlanning.ConstraintTree import AgentId, AppliedConstraint, ConstraintTree, ConstraintTreeNode, ForkingConstraint
2222
import time
2323

24+
# TODO: dont include this
25+
from constraint_tree_viz import visualize_cbs_tree
2426
class ConflictBasedSearch(MultiAgentPlanner):
2527

28+
29+
# TODO: remove ConstraintTree from return
2630
@staticmethod
27-
def plan(grid: Grid, start_and_goals: list[StartAndGoal], single_agent_planner_class: SingleAgentPlanner, verbose: bool = False) -> tuple[list[StartAndGoal], list[NodePath]]:
31+
def plan(grid: Grid, start_and_goals: list[StartAndGoal], single_agent_planner_class: SingleAgentPlanner, verbose: bool = False) -> tuple[list[StartAndGoal], list[NodePath], ConstraintTree]:
2832
"""
2933
Generate a path from the start to the goal for each agent in the `start_and_goals` list.
3034
Returns the re-ordered StartAndGoal combinations, and a list of path plans. The order of the plans
@@ -97,7 +101,7 @@ def plan(grid: Grid, start_and_goals: list[StartAndGoal], single_agent_planner_c
97101
for constraint in all_constraints:
98102
print(f"\t{constraint}")
99103
print(f"Final cost: {constraint_tree_node.cost}")
100-
return (start_and_goals, paths.values())
104+
return (start_and_goals, paths.values(), constraint_tree)
101105

102106
if verbose:
103107
print(f"Adding new constraint tree node with constraint: {new_constraint_tree_node.constraint}")
@@ -161,7 +165,7 @@ def main():
161165
grid_side_length = 21
162166

163167
# start_and_goals = [StartAndGoal(i, Position(1, i), Position(19, 19-i)) for i in range(1, 12)]
164-
start_and_goals = [StartAndGoal(i, Position(1, 8+i), Position(19, 19-i)) for i in range(5)]
168+
# start_and_goals = [StartAndGoal(i, Position(1, 8+i), Position(19, 19-i)) for i in range(5)]
165169
# start_and_goals = [StartAndGoal(i, Position(1, 2*i), Position(19, 19-i)) for i in range(4)]
166170

167171
# generate random start and goals
@@ -176,9 +180,9 @@ def main():
176180
# start_and_goals.append(StartAndGoal(i, start, goal))
177181

178182
# hallway cross
179-
# start_and_goals = [StartAndGoal(0, Position(6, 10), Position(13, 10)),
180-
# StartAndGoal(1, Position(11, 10), Position(6, 10)),
181-
# StartAndGoal(2, Position(13, 10), Position(7, 10))]
183+
start_and_goals = [StartAndGoal(0, Position(6, 10), Position(13, 10)),
184+
StartAndGoal(1, Position(11, 10), Position(6, 10)),
185+
StartAndGoal(2, Position(13, 10), Position(7, 10))]
182186

183187
# temporary obstacle
184188
# start_and_goals = [StartAndGoal(0, Position(15, 14), Position(15, 16))]
@@ -192,15 +196,15 @@ def main():
192196
num_obstacles=250,
193197
obstacle_avoid_points=obstacle_avoid_points,
194198
# obstacle_arrangement=ObstacleArrangement.TEMPORARY_OBSTACLE,
195-
# obstacle_arrangement=ObstacleArrangement.HALLWAY,
196-
obstacle_arrangement=ObstacleArrangement.NARROW_CORRIDOR,
199+
obstacle_arrangement=ObstacleArrangement.HALLWAY,
200+
# obstacle_arrangement=ObstacleArrangement.NARROW_CORRIDOR,
197201
# obstacle_arrangement=ObstacleArrangement.NONE,
198202
# obstacle_arrangement=ObstacleArrangement.ARRANGEMENT1,
199203
# obstacle_arrangement=ObstacleArrangement.RANDOM,
200204
)
201205

202206
start_time = time.time()
203-
start_and_goals, paths = ConflictBasedSearch.plan(grid, start_and_goals, SafeIntervalPathPlanner, verbose)
207+
start_and_goals, paths, constraint_tree = ConflictBasedSearch.plan(grid, start_and_goals, SafeIntervalPathPlanner, verbose)
204208
# start_and_goals, paths = ConflictBasedSearch.plan(grid, start_and_goals, SpaceTimeAStar, verbose)
205209

206210
runtime = time.time() - start_time
@@ -214,7 +218,8 @@ def main():
214218
if not show_animation:
215219
return
216220

217-
PlotNodePaths(grid, start_and_goals, paths)
221+
visualize_cbs_tree(constraint_tree.expanded_nodes, constraint_tree.nodes_to_expand)
222+
# PlotNodePaths(grid, start_and_goals, paths)
218223

219224
if __name__ == "__main__":
220225
main()
Lines changed: 216 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,216 @@
1+
import plotly.graph_objects as go
2+
import plotly.express as px
3+
from plotly.subplots import make_subplots
4+
import networkx as nx
5+
from typing import Optional, Dict, Any
6+
from dataclasses import dataclass
7+
from PathPlanning.TimeBasedPathPlanning.ConstraintTree import AgentId, AppliedConstraint, ConstraintTree, ConstraintTreeNode, ForkingConstraint
8+
9+
def visualize_cbs_tree(
10+
expanded_nodes: Dict[int, ConstraintTreeNode],
11+
nodes_to_expand: list[ConstraintTreeNode],
12+
initial_size: int = 15
13+
) -> None:
14+
"""
15+
Visualize the CBS constraint tree with interactive nodes.
16+
Click a node to print its details to console.
17+
"""
18+
19+
# Build networkx graph
20+
G = nx.DiGraph()
21+
22+
# Add all nodes with metadata
23+
node_colors = []
24+
node_sizes = []
25+
node_labels = {}
26+
27+
for idx, node in expanded_nodes.items():
28+
G.add_node(idx)
29+
node_labels[idx] = f"<b>Node {idx}</b><br>Cost: {node.cost}<br>Parent: {node.parent_idx}<br>Constraint:<br>{node.constraint}"
30+
node_colors.append("lightblue")
31+
node_sizes.append(initial_size)
32+
33+
# Add edge from parent
34+
# if node.parent_idx is not None:
35+
if node.parent_idx is not None and node.parent_idx in expanded_nodes:
36+
G.add_edge(node.parent_idx, idx)
37+
# G.add_edge(0, 5)
38+
print(f"adding edge btwn {node.parent_idx} and {idx}")
39+
40+
# Add unexpanded nodes
41+
# unexpanded_node_map = {}
42+
# for node in nodes_to_expand:
43+
# idx = id(node) # Use object id for heap nodes
44+
# if idx not in G.nodes():
45+
# G.add_node(idx)
46+
# # node_labels[idx] = f"Node {idx}\n(cost: {node.cost})"
47+
# node_labels[idx] = f"<b>Node {idx}</b><br>Cost: {node.cost}<br>Constraint:<br>{node.constraint}"
48+
# node_colors.append("lightyellow")
49+
# node_sizes.append(initial_size)
50+
# unexpanded_node_map[idx] = node
51+
52+
# if node.parent_idx is not None and node.parent_idx >= 0:
53+
# G.add_edge(node.parent_idx, idx)
54+
55+
# Use hierarchical layout with fixed horizontal spacing
56+
pos = _hierarchy_pos(G, root=next(iter(G.nodes()), None), vert_gap=0.3, horiz_gap=1.5)
57+
58+
# Extract edge coordinates
59+
edge_x = []
60+
edge_y = []
61+
for edge in G.edges():
62+
print(f"Drawing edge: {edge}")
63+
if edge[0] in pos and edge[1] in pos:
64+
x0, y0 = pos[edge[0]]
65+
x1, y1 = pos[edge[1]]
66+
edge_x.extend([x0, x1, None])
67+
edge_y.extend([y0, y1, None])
68+
else:
69+
edge_x.extend([1, 1, None])
70+
edge_y.extend([5, 5, None])
71+
72+
# Extract node coordinates
73+
node_x = []
74+
node_y = []
75+
for node in G.nodes():
76+
x, y = 1, 1
77+
if node in pos:
78+
x, y = pos[node]
79+
node_x.append(x)
80+
node_y.append(y)
81+
82+
# Create figure
83+
fig = go.Figure()
84+
85+
# Add edges
86+
fig.add_trace(go.Scatter(
87+
x=edge_x, y=edge_y,
88+
mode='lines',
89+
line=dict(width=2, color='#888'),
90+
hoverinfo='none',
91+
showlegend=False
92+
))
93+
# Add nodes
94+
fig.add_trace(go.Scatter(
95+
x=node_x, y=node_y,
96+
mode='markers',
97+
marker=dict(
98+
size=node_sizes,
99+
color=node_colors,
100+
line=dict(width=2, color='darkblue')
101+
),
102+
text=[node_labels[node] for node in G.nodes() if node in node_labels],
103+
hoverinfo='text',
104+
showlegend=False,
105+
customdata=list(G.nodes())
106+
))
107+
108+
fig.update_layout(
109+
title="CBS Constraint Tree",
110+
showlegend=False,
111+
hovermode='closest',
112+
margin=dict(b=20, l=5, r=5, t=40),
113+
xaxis=dict(
114+
showgrid=False,
115+
zeroline=False,
116+
showticklabels=False,
117+
scaleanchor="y",
118+
scaleratio=1
119+
),
120+
yaxis=dict(
121+
showgrid=False,
122+
zeroline=False,
123+
showticklabels=False,
124+
scaleanchor="x",
125+
scaleratio=1
126+
),
127+
plot_bgcolor='white',
128+
autosize=True,
129+
)
130+
131+
# Add click event
132+
fig.update_traces(
133+
selector=dict(mode='markers'),
134+
customdata=list(G.nodes()),
135+
hovertemplate='%{text}<extra></extra>'
136+
)
137+
138+
fig.update_xaxes(fixedrange=False)
139+
fig.update_yaxes(fixedrange=False)
140+
141+
# Show and set up click handler
142+
fig.show()
143+
144+
# Print handler instructions
145+
print("\nCBS Tree Visualization")
146+
print("=" * 50)
147+
print("Hover over nodes to see cost")
148+
print("Right-click → 'Inspect' → Open browser console")
149+
print("Then paste this to get node info:\n")
150+
print("for (let node of document.querySelectorAll('circle')) {")
151+
print(" node.onclick = (e) => {")
152+
print(" console.log('Clicked node:', e.target);")
153+
print(" }")
154+
print("}\n")
155+
print("Or use the alternative: Print all nodes programmatically:\n")
156+
157+
def _hierarchy_pos(G, root=None, vert_gap=0.2, horiz_gap=1.0, xcenter=0.5):
158+
"""
159+
Create hierarchical layout for tree visualization with fixed horizontal spacing.
160+
"""
161+
if not nx.is_tree(G):
162+
G = nx.DiGraph(G)
163+
164+
def _hierarchy_pos_recursive(G, root, vert_gap=0.2, horiz_gap=1.0, xcenter=0.5, pos=None, parent=None, child_index=0):
165+
if pos is None:
166+
pos = {root: (xcenter, 0)}
167+
else:
168+
pos[root] = (xcenter, pos[parent][1] - vert_gap)
169+
170+
neighbors = list(G.neighbors(root))
171+
172+
if len(neighbors) != 0:
173+
num_children = len(neighbors)
174+
# Spread children horizontally with fixed gap
175+
start_x = xcenter - (num_children - 1) * horiz_gap / 2
176+
for i, neighbor in enumerate(neighbors):
177+
nextx = start_x + i * horiz_gap
178+
_hierarchy_pos_recursive(G, neighbor, vert_gap=vert_gap, horiz_gap=horiz_gap,
179+
xcenter=nextx, pos=pos, parent=root, child_index=i)
180+
181+
return pos
182+
183+
return _hierarchy_pos_recursive(G, root, vert_gap, horiz_gap, xcenter)
184+
185+
186+
# Example usage:
187+
if __name__ == "__main__":
188+
from dataclasses import dataclass
189+
from typing import Optional
190+
191+
@dataclass
192+
class MockConstraint:
193+
agent: int
194+
time: int
195+
location: tuple
196+
197+
def __repr__(self):
198+
return f"Constraint(agent={self.agent}, t={self.time}, loc={self.location})"
199+
200+
@dataclass
201+
class MockNode:
202+
parent_idx: Optional[int]
203+
constraint: Optional[MockConstraint]
204+
paths: dict
205+
cost: int
206+
207+
# Create mock tree
208+
nodes = {
209+
0: MockNode(None, None, {"a": [], "b": []}, 10),
210+
1: MockNode(0, MockConstraint(0, 2, (0, 0)), {"a": [(0,0), (1,0)], "b": [(0,1), (0,2)]}, 12),
211+
2: MockNode(0, MockConstraint(1, 1, (0,1)), {"a": [(0,0), (1,0)], "b": [(0,1), (0,2)]}, 11),
212+
3: MockNode(1, MockConstraint(0, 3, (1,0)), {"a": [(0,0), (1,0), (1,1)], "b": [(0,1), (0,2)]}, 14),
213+
4: MockNode(2, None, {"a": [(0,0), (1,0)], "b": [(0,1), (1,1)]}, 12),
214+
}
215+
216+
visualize_cbs_tree(nodes, [])

0 commit comments

Comments
 (0)