Skip to content

Commit a69f382

Browse files
committed
format STA* file
1 parent 15087e6 commit a69f382

File tree

1 file changed

+60
-19
lines changed

1 file changed

+60
-19
lines changed

PathPlanning/TimeBasedPathPlanning/SpaceTimeAStar.py

Lines changed: 60 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,14 @@
55
Reference: https://www.davidsilver.uk/wp-content/uploads/2020/03/coop-path-AIWisdom.pdf
66
"""
77

8-
from __future__ import annotations # For typehints of a class within itself
8+
from __future__ import annotations # For typehints of a class within itself
99
import numpy as np
1010
import matplotlib.pyplot as plt
11-
from PathPlanning.TimeBasedPathPlanning.GridWithDynamicObstacles import Grid, ObstacleArrangement, Position
11+
from PathPlanning.TimeBasedPathPlanning.GridWithDynamicObstacles import (
12+
Grid,
13+
ObstacleArrangement,
14+
Position,
15+
)
1216
import heapq
1317
from collections.abc import Generator
1418
import random
@@ -18,13 +22,16 @@
1822
random.seed(RANDOM_SEED)
1923
np.random.seed(RANDOM_SEED)
2024

25+
2126
class Node:
2227
position: Position
2328
time: int
2429
heuristic: int
2530
parent_index: int
2631

27-
def __init__(self, position: Position, time: int, heuristic: int, parent_index: int):
32+
def __init__(
33+
self, position: Position, time: int, heuristic: int, parent_index: int
34+
):
2835
self.position = position
2936
self.time = time
3037
self.heuristic = heuristic
@@ -34,12 +41,14 @@ def __init__(self, position: Position, time: int, heuristic: int, parent_index:
3441
This is what is used to drive node expansion. The node with the lowest value is expanded next.
3542
This comparison prioritizes the node with the lowest cost-to-come (self.time) + cost-to-go (self.heuristic)
3643
"""
44+
3745
def __lt__(self, other: Node):
3846
return (self.time + self.heuristic) < (other.time + other.heuristic)
3947

4048
def __repr__(self):
4149
return f"Node(position={self.position}, time={self.time}, heuristic={self.heuristic}, parent_index={self.parent_index})"
4250

51+
4352
class NodePath:
4453
path: list[Node]
4554
positions_at_time: dict[int, Position] = {}
@@ -52,21 +61,24 @@ def __init__(self, path: list[Node]):
5261
"""
5362
Get the position of the path at a given time
5463
"""
64+
5565
def get_position(self, time: int) -> Position:
5666
return self.positions_at_time.get(time)
5767

5868
"""
5969
Time stamp of the last node in the path
6070
"""
71+
6172
def goal_reached_time(self) -> int:
6273
return self.path[-1].time
6374

6475
def __repr__(self):
6576
repr_string = ""
66-
for (i, node) in enumerate(self.path):
77+
for i, node in enumerate(self.path):
6778
repr_string += f"{i}: {node}\n"
6879
return repr_string
6980

81+
7082
class TimeBasedAStar:
7183
grid: Grid
7284
start: Position
@@ -79,7 +91,9 @@ def __init__(self, grid: Grid, start: Position, goal: Position):
7991

8092
def plan(self, verbose: bool = False) -> NodePath:
8193
open_set = []
82-
heapq.heappush(open_set, Node(self.start, 0, self.calculate_heuristic(self.start), -1))
94+
heapq.heappush(
95+
open_set, Node(self.start, 0, self.calculate_heuristic(self.start), -1)
96+
)
8397

8498
expanded_set = []
8599
while open_set:
@@ -117,12 +131,26 @@ def plan(self, verbose: bool = False) -> NodePath:
117131
"""
118132
Generate possible successors of the provided `parent_node`
119133
"""
120-
def generate_successors(self, parent_node: Node, parent_node_idx: int, verbose: bool) -> Generator[Node, None, None]:
121-
diffs = [Position(0, 1), Position(0, -1), Position(1, 0), Position(-1, 0), Position(0, 0)]
134+
135+
def generate_successors(
136+
self, parent_node: Node, parent_node_idx: int, verbose: bool
137+
) -> Generator[Node, None, None]:
138+
diffs = [
139+
Position(0, 1),
140+
Position(0, -1),
141+
Position(1, 0),
142+
Position(-1, 0),
143+
Position(0, 0),
144+
]
122145
for diff in diffs:
123146
new_pos = parent_node.position + diff
124-
if self.grid.valid_position(new_pos, parent_node.time+1):
125-
new_node = Node(new_pos, parent_node.time+1, self.calculate_heuristic(new_pos), parent_node_idx)
147+
if self.grid.valid_position(new_pos, parent_node.time + 1):
148+
new_node = Node(
149+
new_pos,
150+
parent_node.time + 1,
151+
self.calculate_heuristic(new_pos),
152+
parent_node_idx,
153+
)
126154
if verbose:
127155
print("\tNew successor node: ", new_node)
128156
yield new_node
@@ -131,12 +159,20 @@ def calculate_heuristic(self, position) -> int:
131159
diff = self.goal - position
132160
return abs(diff.x) + abs(diff.y)
133161

162+
134163
show_animation = True
164+
165+
135166
def main():
136167
start = Position(1, 11)
137168
goal = Position(19, 19)
138169
grid_side_length = 21
139-
grid = Grid(np.array([grid_side_length, grid_side_length]), num_obstacles=40, obstacle_avoid_points=[start, goal], obstacle_arrangement=ObstacleArrangement.ARRANGEMENT1)
170+
grid = Grid(
171+
np.array([grid_side_length, grid_side_length]),
172+
num_obstacles=40,
173+
obstacle_avoid_points=[start, goal],
174+
obstacle_arrangement=ObstacleArrangement.ARRANGEMENT1,
175+
)
140176

141177
planner = TimeBasedAStar(grid, start, goal)
142178
verbose = False
@@ -149,22 +185,26 @@ def main():
149185
return
150186

151187
fig = plt.figure(figsize=(10, 7))
152-
ax = fig.add_subplot(autoscale_on=False, xlim=(0, grid.grid_size[0]-1), ylim=(0, grid.grid_size[1]-1))
153-
ax.set_aspect('equal')
188+
ax = fig.add_subplot(
189+
autoscale_on=False,
190+
xlim=(0, grid.grid_size[0] - 1),
191+
ylim=(0, grid.grid_size[1] - 1),
192+
)
193+
ax.set_aspect("equal")
154194
ax.grid()
155195
ax.set_xticks(np.arange(0, grid_side_length, 1))
156196
ax.set_yticks(np.arange(0, grid_side_length, 1))
157197

158-
start_and_goal, = ax.plot([], [], 'mD', ms=15, label="Start and Goal")
198+
(start_and_goal,) = ax.plot([], [], "mD", ms=15, label="Start and Goal")
159199
start_and_goal.set_data([start.x, goal.x], [start.y, goal.y])
160-
obs_points, = ax.plot([], [], 'ro', ms=15, label="Obstacles")
161-
path_points, = ax.plot([], [], 'bo', ms=10, label="Path Found")
200+
(obs_points,) = ax.plot([], [], "ro", ms=15, label="Obstacles")
201+
(path_points,) = ax.plot([], [], "bo", ms=10, label="Path Found")
162202
ax.legend(bbox_to_anchor=(1.05, 1))
163203

164204
# for stopping simulation with the esc key.
165-
plt.gcf().canvas.mpl_connect('key_release_event',
166-
lambda event: [exit(
167-
0) if event.key == 'escape' else None])
205+
plt.gcf().canvas.mpl_connect(
206+
"key_release_event", lambda event: [exit(0) if event.key == "escape" else None]
207+
)
168208

169209
for i in range(0, path.goal_reached_time()):
170210
obs_positions = grid.get_obstacle_positions_at_time(i)
@@ -174,5 +214,6 @@ def main():
174214
plt.pause(0.2)
175215
plt.show()
176216

177-
if __name__ == '__main__':
217+
218+
if __name__ == "__main__":
178219
main()

0 commit comments

Comments
 (0)