Skip to content

Commit 92689ed

Browse files
committed
fix tests
1 parent 8f63c88 commit 92689ed

File tree

8 files changed

+27
-18
lines changed

8 files changed

+27
-18
lines changed

PathPlanning/TimeBasedPathPlanning/BaseClasses.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,8 +42,8 @@ class MultiAgentPlanner(ABC):
4242

4343
@staticmethod
4444
@abstractmethod
45-
def plan(grid: Grid, start_and_goal_positions: list[StartAndGoal], verbose: bool = False) -> list[NodePath]:
45+
def plan(grid: Grid, start_and_goal_positions: list[StartAndGoal], single_agent_planner_class: SingleAgentPlanner, verbose: bool = False) -> tuple[list[StartAndGoal], list[NodePath]]:
4646
"""
47-
Plan for all agents. Returned paths are in the order of the `StartAndGoal` list this object was instantiated with
47+
Plan for all agents. Returned paths are in order corresponding to the returned list of `StartAndGoal` objects
4848
"""
4949
pass

PathPlanning/TimeBasedPathPlanning/Node.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from dataclasses import dataclass
22
from functools import total_ordering
33
import numpy as np
4+
from typing import Sequence
45

56
@dataclass(order=True)
67
class Position:
@@ -61,11 +62,15 @@ def __hash__(self):
6162
return hash((self.position, self.time))
6263

6364
class NodePath:
64-
path: list[Node]
65+
path: Sequence[Node]
6566
positions_at_time: dict[int, Position]
67+
# Number of nodes expanded while finding this path
68+
expanded_node_count: int
6669

67-
def __init__(self, path: list[Node]):
70+
def __init__(self, path: Sequence[Node], expanded_node_count: int):
6871
self.path = path
72+
self.expanded_node_count = expanded_node_count
73+
6974
self.positions_at_time = {}
7075
for i, node in enumerate(path):
7176
reservation_finish_time = node.time + 1

PathPlanning/TimeBasedPathPlanning/Plotting.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import numpy as np
22
import matplotlib.pyplot as plt
3+
from matplotlib.backend_bases import KeyEvent
34
from PathPlanning.TimeBasedPathPlanning.GridWithDynamicObstacles import (
45
Grid,
56
Position,
@@ -30,13 +31,18 @@ def PlotNodePath(grid: Grid, start: Position, goal: Position, path: NodePath):
3031

3132
# for stopping simulation with the esc key.
3233
plt.gcf().canvas.mpl_connect(
33-
"key_release_event", lambda event: [exit(0) if event.key == "escape" else None]
34+
"key_release_event",
35+
lambda event: [exit(0) if event.key == "escape" else None]
36+
if isinstance(event, KeyEvent) else None
3437
)
3538

3639
for i in range(0, path.goal_reached_time()):
3740
obs_positions = grid.get_obstacle_positions_at_time(i)
3841
obs_points.set_data(obs_positions[0], obs_positions[1])
3942
path_position = path.get_position(i)
43+
if not path_position:
44+
raise Exception(f"Path position not found for time {i}.")
45+
4046
path_points.set_data([path_position.x], [path_position.y])
4147
plt.pause(0.2)
4248
plt.show()
@@ -91,7 +97,9 @@ def PlotNodePaths(grid: Grid, start_and_goals: list[StartAndGoal], paths: list[N
9197

9298
# For stopping simulation with the esc key
9399
plt.gcf().canvas.mpl_connect(
94-
"key_release_event", lambda event: [exit(0) if event.key == "escape" else None]
100+
"key_release_event",
101+
lambda event: [exit(0) if event.key == "escape" else None]
102+
if isinstance(event, KeyEvent) else None
95103
)
96104

97105
# Find the maximum time across all paths
@@ -112,6 +120,8 @@ def PlotNodePaths(grid: Grid, start_and_goals: list[StartAndGoal], paths: list[N
112120
print(path)
113121
print(i)
114122
path_position = path.get_position(i)
123+
if not path_position:
124+
raise Exception(f"Path position not found for time {i}.")
115125

116126
# Verify position is valid
117127
assert not path_position in obs_positions

PathPlanning/TimeBasedPathPlanning/PriorityBasedPlanner.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
class PriorityBasedPlanner(MultiAgentPlanner):
2424

2525
@staticmethod
26-
def plan(grid: Grid, start_and_goals: list[StartAndGoal], single_agent_planner_class: SingleAgentPlanner, verbose: bool) -> tuple[list[StartAndGoal], list[NodePath]]:
26+
def plan(grid: Grid, start_and_goals: list[StartAndGoal], single_agent_planner_class: SingleAgentPlanner, verbose: bool = False) -> tuple[list[StartAndGoal], list[NodePath]]:
2727
"""
2828
Generate a path from the start to the goal for each agent in the `start_and_goals` list.
2929
Returns the re-ordered StartAndGoal combinations, and a list of path plans. The order of the plans
@@ -76,8 +76,6 @@ def main():
7676
)
7777

7878
start_time = time.time()
79-
start_and_goals: list[StartAndGoal]
80-
paths: list[NodePath]
8179
start_and_goals, paths = PriorityBasedPlanner.plan(grid, start_and_goals, SafeIntervalPathPlanner, verbose)
8280

8381
runtime = time.time() - start_time

PathPlanning/TimeBasedPathPlanning/SafeInterval.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ def plan(grid: Grid, start: Position, goal: Position, verbose: bool = False) ->
8181

8282
# reverse path so it goes start -> goal
8383
path.reverse()
84-
return NodePath(path)
84+
return NodePath(path, len(expanded_list))
8585

8686
expanded_idx = len(expanded_list)
8787
expanded_list.append(expanded_node)

PathPlanning/TimeBasedPathPlanning/SpaceTimeAStar.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ def plan(grid: Grid, start: Position, goal: Position, verbose: bool = False) ->
5454

5555
# reverse path so it goes start -> goal
5656
path.reverse()
57-
return NodePath(path)
57+
return NodePath(path, len(expanded_set))
5858

5959
expanded_idx = len(expanded_list)
6060
expanded_list.append(expanded_node)

tests/test_safe_interval_path_planner.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,7 @@ def test_1():
1818
)
1919

2020
m.show_animation = False
21-
planner = m.SafeIntervalPathPlanner(grid, start, goal)
22-
23-
path = planner.plan(False)
21+
path = m.SafeIntervalPathPlanner.plan(grid, start, goal)
2422

2523
# path should have 31 entries
2624
assert len(path.path) == 31

tests/test_space_time_astar.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,17 +18,15 @@ def test_1():
1818
)
1919

2020
m.show_animation = False
21-
planner = m.SpaceTimeAStar(grid, start, goal)
22-
23-
path = planner.plan(False)
21+
path = m.SpaceTimeAStar.plan(grid, start, goal)
2422

2523
# path should have 28 entries
2624
assert len(path.path) == 31
2725

2826
# path should end at the goal
2927
assert path.path[-1].position == goal
3028

31-
assert planner.expanded_node_count < 1000
29+
assert path.expanded_node_count < 1000
3230

3331
if __name__ == "__main__":
3432
conftest.run_this_test(__file__)

0 commit comments

Comments
 (0)