Skip to content

Commit 5cf40fc

Browse files
committed
added unit test
1 parent 5f88530 commit 5cf40fc

File tree

4 files changed

+92
-53
lines changed

4 files changed

+92
-53
lines changed

PathPlanning/TimeBasedPathPlanning/SpaceTimeAStar.py

Lines changed: 38 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,17 @@
1+
"""
2+
Space-time A* Algorithm
3+
This script demonstrates the Space-time A* algorithm for path planning in a grid world with moving obstacles.
4+
5+
Reference: https://www.davidsilver.uk/wp-content/uploads/2020/03/coop-path-AIWisdom.pdf
6+
"""
7+
8+
from __future__ import annotations # For typehints of a class within itself
19
import numpy as np
210
import matplotlib.pyplot as plt
3-
import matplotlib.animation as animation
4-
from moving_obstacles import Grid, Position
11+
from PathPlanning.TimeBasedPathPlanning.moving_obstacles import Grid, ObstacleArrangement, Position
512
import heapq
6-
from typing import Generator
13+
from collections.abc import Generator
714
import random
8-
from __future__ import annotations
915

1016
# Seed randomness for reproducibility
1117
RANDOM_SEED = 50
@@ -23,7 +29,11 @@ def __init__(self, position: Position, time: int, heuristic: int, parent_index:
2329
self.time = time
2430
self.heuristic = heuristic
2531
self.parent_index = parent_index
26-
32+
33+
"""
34+
This is what is used to drive node expansion. The node with the lowest value is expanded next.
35+
This comparison prioritizes the node with the lowest cost-to-come (self.time) + cost-to-go (self.heuristic)
36+
"""
2737
def __lt__(self, other: Node):
2838
return (self.time + self.heuristic) < (other.time + other.heuristic)
2939

@@ -32,25 +42,25 @@ def __repr__(self):
3242

3343
class NodePath:
3444
path: list[Node]
45+
positions_at_time: dict[int, Position] = {}
3546

3647
def __init__(self, path: list[Node]):
3748
self.path = path
38-
49+
for node in path:
50+
self.positions_at_time[node.time] = node.position
51+
52+
"""
53+
Get the position of the path at a given time
54+
"""
3955
def get_position(self, time: int) -> Position:
40-
# TODO: this is inefficient
41-
for i in range(0, len(self.path) - 2):
42-
if self.path[i + 1].time > time:
43-
print(f"position @ {i} is {self.path[i].position}")
44-
return self.path[i].position
45-
46-
if len(self.path) > 0:
47-
return self.path[-1].position
48-
49-
return None
50-
56+
return self.positions_at_time.get(time)
57+
58+
"""
59+
Time stamp of the last node in the path
60+
"""
5161
def goal_reached_time(self) -> int:
5262
return self.path[-1].time
53-
63+
5464
def __repr__(self):
5565
repr_string = ""
5666
for (i, node) in enumerate(self.path):
@@ -71,7 +81,6 @@ def plan(self, verbose: bool = False) -> NodePath:
7181
open_set = []
7282
heapq.heappush(open_set, Node(self.start, 0, self.calculate_heuristic(self.start), -1))
7383

74-
# TODO: is vec good here?
7584
expanded_set = []
7685
while open_set:
7786
expanded_node: Node = heapq.heappop(open_set)
@@ -92,7 +101,7 @@ def plan(self, verbose: bool = False) -> NodePath:
92101
path_walker = expanded_set[path_walker.parent_index]
93102
# TODO: fix hack around bad while condiiotn
94103
path.append(path_walker)
95-
104+
96105
# reverse path so it goes start -> goal
97106
path.reverse()
98107
return NodePath(path)
@@ -102,9 +111,12 @@ def plan(self, verbose: bool = False) -> NodePath:
102111

103112
for child in self.generate_successors(expanded_node, expanded_idx, verbose):
104113
heapq.heappush(open_set, child)
105-
114+
106115
raise Exception("No path found")
107-
116+
117+
"""
118+
Generate possible successors of the provided `parent_node`
119+
"""
108120
def generate_successors(self, parent_node: Node, parent_node_idx: int, verbose: bool) -> Generator[Node, None, None]:
109121
diffs = [Position(0, 1), Position(0, -1), Position(1, 0), Position(-1, 0), Position(0, 0)]
110122
for diff in diffs:
@@ -119,13 +131,12 @@ def calculate_heuristic(self, position) -> int:
119131
diff = self.goal - position
120132
return abs(diff.x) + abs(diff.y)
121133

122-
import imageio.v2 as imageio
123134
show_animation = True
124135
def main():
125-
start = Position(1, 1)
136+
start = Position(1, 11)
126137
goal = Position(19, 19)
127138
grid_side_length = 21
128-
grid = Grid(np.array([grid_side_length, grid_side_length]), num_obstacles=40, obstacle_avoid_points=[start, goal])
139+
grid = Grid(np.array([grid_side_length, grid_side_length]), num_obstacles=40, obstacle_avoid_points=[start, goal], obstacle_arrangement=ObstacleArrangement.ARRANGEMENT1)
129140

130141
planner = TimeBasedAStar(grid, start, goal)
131142
verbose = False
@@ -144,7 +155,7 @@ def main():
144155
ax.set_xticks(np.arange(0, grid_side_length, 1))
145156
ax.set_yticks(np.arange(0, grid_side_length, 1))
146157

147-
start_and_goal, = ax.plot([], [], 'mD', ms=15, label="Start and Goal")
158+
start_and_goal, = ax.plot([], [], 'mD', ms=15, label="Start and Goal")
148159
start_and_goal.set_data([start.x, goal.x], [start.y, goal.y])
149160
obs_points, = ax.plot([], [], 'ro', ms=15, label="Obstacles")
150161
path_points, = ax.plot([], [], 'bo', ms=10, label="Path Found")
@@ -155,17 +166,13 @@ def main():
155166
lambda event: [exit(
156167
0) if event.key == 'escape' else None])
157168

158-
frames = []
159169
for i in range(0, path.goal_reached_time()):
160170
obs_positions = grid.get_obstacle_positions_at_time(i)
161171
obs_points.set_data(obs_positions[0], obs_positions[1])
162172
path_position = path.get_position(i)
163173
path_points.set_data([path_position.x], [path_position.y])
164174
plt.pause(0.2)
165-
plt.savefig(f"frame_{i:03d}.png") # Save each frame as an image
166-
frames.append(imageio.imread(f"frame_{i:03d}.png"))
167-
imageio.mimsave("path_animation.gif", frames, fps=5) # Convert images to GIF
168175
plt.show()
169176

170177
if __name__ == '__main__':
171-
main()
178+
main()

PathPlanning/TimeBasedPathPlanning/__init__.py

Whitespace-only changes.

PathPlanning/TimeBasedPathPlanning/moving_obstacles.py

Lines changed: 30 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,12 @@ def __init__(self, x: int, y: int):
1111

1212
def as_ndarray(self) -> np.ndarray[int, int]:
1313
return np.array([self.x, self.y])
14-
14+
1515
def __add__(self, other):
1616
if isinstance(other, Position):
1717
return Position(self.x + other.x, self.y + other.y)
1818
raise NotImplementedError(f"Addition not supported for Position and {type(other)}")
19-
19+
2020
def __sub__(self, other):
2121
if isinstance(other, Position):
2222
return Position(self.x - other.x, self.y - other.y)
@@ -36,24 +36,21 @@ class ObstacleArrangement(Enum):
3636
# Obstacles start in a line in y at center of grid and move side-to-side in x
3737
ARRANGEMENT1 = 1
3838

39-
class Grid():
40-
39+
class Grid:
4140
# Set in constructor
4241
grid_size = None
4342
grid = None
4443
obstacle_paths: list[list[Position]] = []
4544
# Obstacles will never occupy these points. Useful to avoid impossible scenarios
4645
obstacle_avoid_points = []
4746

48-
# Problem definition
4947
# Number of time steps in the simulation
5048
time_limit: int
5149

5250
# Logging control
5351
verbose = False
5452

55-
def __init__(self, grid_size: np.ndarray[int, int], num_obstacles: int = 2, obstacle_avoid_points: list[Position] = [], obstacle_arrangement: ObstacleArrangement = ObstacleArrangement.RANDOM, time_limit: int = 100):
56-
num_obstacles
53+
def __init__(self, grid_size: np.ndarray[int, int], num_obstacles: int = 40, obstacle_avoid_points: list[Position] = [], obstacle_arrangement: ObstacleArrangement = ObstacleArrangement.RANDOM, time_limit: int = 100):
5754
self.obstacle_avoid_points = obstacle_avoid_points
5855
self.time_limit = time_limit
5956
self.grid_size = grid_size
@@ -66,15 +63,18 @@ def __init__(self, grid_size: np.ndarray[int, int], num_obstacles: int = 2, obst
6663
self.obstacle_paths = self.generate_dynamic_obstacles(num_obstacles)
6764
elif obstacle_arrangement == ObstacleArrangement.ARRANGEMENT1:
6865
self.obstacle_paths = self.obstacle_arrangement_1(num_obstacles)
69-
66+
7067
for (i, path) in enumerate(self.obstacle_paths):
71-
obs_idx = i + 1 # avoid using 0 - that indicates free space
68+
obs_idx = i + 1 # avoid using 0 - that indicates free space in the grid
7269
for (t, position) in enumerate(path):
7370
# Reserve old & new position at this time step
7471
if t > 0:
7572
self.grid[path[t-1].x, path[t-1].y, t] = obs_idx
7673
self.grid[position.x, position.y, t] = obs_idx
7774

75+
"""
76+
Generate dynamic obstacles that move around the grid. Initial positions and movements are random
77+
"""
7878
def generate_dynamic_obstacles(self, obs_count: int) -> list[list[Position]]:
7979
obstacle_paths = []
8080
for _obs_idx in (0, obs_count):
@@ -112,31 +112,39 @@ def generate_dynamic_obstacles(self, obs_count: int) -> list[list[Position]]:
112112
valid_position = positions[-1]
113113

114114
positions.append(valid_position)
115-
115+
116116
obstacle_paths.append(positions)
117117

118118
return obstacle_paths
119-
119+
120+
"""
121+
Generate a line of obstacles in y at the center of the grid that move side-to-side in x
122+
Bottom half start moving right, top half start moving left. If `obs_count` is less than the length of
123+
the grid, only the first `obs_count` obstacles will be generated.
124+
"""
120125
def obstacle_arrangement_1(self, obs_count: int) -> list[list[Position]]:
121-
# bottom half of y values start left -> right
122-
# top half of y values start right -> left
123126
obstacle_paths = []
124127
half_grid_x = self.grid_size[0] // 2
125128
half_grid_y = self.grid_size[1] // 2
126-
127-
for y_idx in range(0, min(obs_count, self.grid_size[1] - 1)):
129+
130+
for y_idx in range(0, min(obs_count, self.grid_size[1])):
128131
moving_right = y_idx < half_grid_y
129132
position = Position(half_grid_x, y_idx)
130133
path = [position]
131134

132-
for _t in range(1, self.time_limit-1):
135+
for t in range(1, self.time_limit-1):
136+
# sit in place every other time step
137+
if t % 2 == 0:
138+
path.append(position)
139+
continue
140+
133141
# first check if we should switch direction (at edge of grid)
134142
if (moving_right and position.x == self.grid_size[0] - 1) or (not moving_right and position.x == 0):
135143
moving_right = not moving_right
136144
# step in direction
137145
position = Position(position.x + (1 if moving_right else -1), position.y)
138146
path.append(position)
139-
147+
140148
obstacle_paths.append(path)
141149

142150
return obstacle_paths
@@ -159,13 +167,13 @@ def valid_position(self, position: Position, t: int) -> bool:
159167

160168
# Check if new position is not occupied at time t
161169
return self.grid[position.x, position.y, t] == 0
162-
170+
163171
"""
164172
Returns True if the given position is valid at time t and is not in the set of obstacle_avoid_points
165173
"""
166174
def valid_obstacle_position(self, position: Position, t: int) -> bool:
167175
return self.valid_position(position, t) and position not in self.obstacle_avoid_points
168-
176+
169177
"""
170178
Returns True if the given position is within the grid's boundaries
171179
"""
@@ -180,12 +188,12 @@ def inside_grid_bounds(self, position: Position) -> bool:
180188
"""
181189
def sample_random_position(self) -> Position:
182190
return Position(np.random.randint(0, self.grid_size[0]), np.random.randint(0, self.grid_size[1]))
183-
191+
184192
"""
185193
Returns a tuple of (x_positions, y_positions) of the obstacles at time t
186194
"""
187195
def get_obstacle_positions_at_time(self, t: int) -> tuple[list[int], list[int]]:
188-
196+
189197
x_positions = []
190198
y_positions = []
191199
for obs_path in self.obstacle_paths:
@@ -221,4 +229,4 @@ def main():
221229
plt.show()
222230

223231
if __name__ == '__main__':
224-
main()
232+
main()

tests/test_space_time_astar.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
from PathPlanning.TimeBasedPathPlanning.moving_obstacles import Grid, ObstacleArrangement, Position
2+
from PathPlanning.TimeBasedPathPlanning import SpaceTimeAStar as m
3+
import numpy as np
4+
import conftest
5+
6+
def test_1():
7+
start = Position(1, 11)
8+
goal = Position(19, 19)
9+
grid_side_length = 21
10+
grid = Grid(np.array([grid_side_length, grid_side_length]), obstacle_arrangement=ObstacleArrangement.ARRANGEMENT1)
11+
12+
m.show_animation = False
13+
planner = m.TimeBasedAStar(grid, start, goal)
14+
15+
path = planner.plan(False)
16+
17+
# path should have 28 entries
18+
assert len(path.path) == 31
19+
20+
# path should end at the goal
21+
assert path.path[-1].position == goal
22+
23+
if __name__ == '__main__':
24+
conftest.run_this_test(__file__)

0 commit comments

Comments
 (0)