Skip to content

Commit d918947

Browse files
authored
Collaborative astar (#1247)
* consolidate Node definition * add base class for single agent planner * add base class for single agent planner * its working * use single agent plotting util * cleanup, bug fix, add some results to docs * remove seeding from sta* - it happens in Node * remove stale todo * rename CA* and speed up plotting * paper * proper paper (ofc its csail) * some cleanup * update docs * add unit test * add logic for saving animation as gif * address github bot * Revert "add logic for saving animation as gif" This reverts commit 6391677. * fix tests * docs lint * add gifs * copilot review * appease mypy
1 parent e9c86ab commit d918947

File tree

11 files changed

+566
-288
lines changed

11 files changed

+566
-288
lines changed
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
from abc import ABC, abstractmethod
2+
from dataclasses import dataclass
3+
from PathPlanning.TimeBasedPathPlanning.GridWithDynamicObstacles import (
4+
Grid,
5+
Position,
6+
)
7+
from PathPlanning.TimeBasedPathPlanning.Node import NodePath
8+
import random
9+
import numpy.random as numpy_random
10+
11+
# Seed randomness for reproducibility
12+
RANDOM_SEED = 50
13+
random.seed(RANDOM_SEED)
14+
numpy_random.seed(RANDOM_SEED)
15+
16+
class SingleAgentPlanner(ABC):
17+
"""
18+
Base class for single agent planners
19+
"""
20+
21+
@staticmethod
22+
@abstractmethod
23+
def plan(grid: Grid, start: Position, goal: Position, verbose: bool = False) -> NodePath:
24+
pass
25+
26+
@dataclass
27+
class StartAndGoal:
28+
# Index of this agent
29+
index: int
30+
# Start position of the robot
31+
start: Position
32+
# Goal position of the robot
33+
goal: Position
34+
35+
def distance_start_to_goal(self) -> float:
36+
return pow(self.goal.x - self.start.x, 2) + pow(self.goal.y - self.start.y, 2)
37+
38+
class MultiAgentPlanner(ABC):
39+
"""
40+
Base class for multi-agent planners
41+
"""
42+
43+
@staticmethod
44+
@abstractmethod
45+
def plan(grid: Grid, start_and_goal_positions: list[StartAndGoal], single_agent_planner_class: SingleAgentPlanner, verbose: bool = False) -> tuple[list[StartAndGoal], list[NodePath]]:
46+
"""
47+
Plan for all agents. Returned paths are in order corresponding to the returned list of `StartAndGoal` objects
48+
"""
49+
pass

PathPlanning/TimeBasedPathPlanning/GridWithDynamicObstacles.py

Lines changed: 67 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -7,31 +7,7 @@
77
import matplotlib.pyplot as plt
88
from enum import Enum
99
from dataclasses import dataclass
10-
11-
@dataclass(order=True)
12-
class Position:
13-
x: int
14-
y: int
15-
16-
def as_ndarray(self) -> np.ndarray:
17-
return np.array([self.x, self.y])
18-
19-
def __add__(self, other):
20-
if isinstance(other, Position):
21-
return Position(self.x + other.x, self.y + other.y)
22-
raise NotImplementedError(
23-
f"Addition not supported for Position and {type(other)}"
24-
)
25-
26-
def __sub__(self, other):
27-
if isinstance(other, Position):
28-
return Position(self.x - other.x, self.y - other.y)
29-
raise NotImplementedError(
30-
f"Subtraction not supported for Position and {type(other)}"
31-
)
32-
33-
def __hash__(self):
34-
return hash((self.x, self.y))
10+
from PathPlanning.TimeBasedPathPlanning.Node import NodePath, Position
3511

3612
@dataclass
3713
class Interval:
@@ -43,6 +19,8 @@ class ObstacleArrangement(Enum):
4319
RANDOM = 0
4420
# Obstacles start in a line in y at center of grid and move side-to-side in x
4521
ARRANGEMENT1 = 1
22+
# Static obstacle arrangement
23+
NARROW_CORRIDOR = 2
4624

4725
"""
4826
Generates a 2d numpy array with lists for elements.
@@ -87,6 +65,8 @@ def __init__(
8765
self.obstacle_paths = self.generate_dynamic_obstacles(num_obstacles)
8866
elif obstacle_arrangement == ObstacleArrangement.ARRANGEMENT1:
8967
self.obstacle_paths = self.obstacle_arrangement_1(num_obstacles)
68+
elif obstacle_arrangement == ObstacleArrangement.NARROW_CORRIDOR:
69+
self.obstacle_paths = self.generate_narrow_corridor_obstacles(num_obstacles)
9070

9171
for i, path in enumerate(self.obstacle_paths):
9272
obs_idx = i + 1 # avoid using 0 - that indicates free space in the grid
@@ -184,6 +164,26 @@ def obstacle_arrangement_1(self, obs_count: int) -> list[list[Position]]:
184164
obstacle_paths.append(path)
185165

186166
return obstacle_paths
167+
168+
def generate_narrow_corridor_obstacles(self, obs_count: int) -> list[list[Position]]:
169+
obstacle_paths = []
170+
171+
for y in range(0, self.grid_size[1]):
172+
if y > obs_count:
173+
break
174+
175+
if y == self.grid_size[1] // 2:
176+
# Skip the middle row
177+
continue
178+
179+
obstacle_path = []
180+
x = self.grid_size[0] // 2 # middle of the grid
181+
for t in range(0, self.time_limit - 1):
182+
obstacle_path.append(Position(x, y))
183+
184+
obstacle_paths.append(obstacle_path)
185+
186+
return obstacle_paths
187187

188188
"""
189189
Check if the given position is valid at time t
@@ -196,11 +196,11 @@ def obstacle_arrangement_1(self, obs_count: int) -> list[list[Position]]:
196196
bool: True if position/time combination is valid, False otherwise
197197
"""
198198
def valid_position(self, position: Position, t: int) -> bool:
199-
# Check if new position is in grid
199+
# Check if position is in grid
200200
if not self.inside_grid_bounds(position):
201201
return False
202202

203-
# Check if new position is not occupied at time t
203+
# Check if position is not occupied at time t
204204
return self.reservation_matrix[position.x, position.y, t] == 0
205205

206206
"""
@@ -289,9 +289,48 @@ def get_safe_intervals_at_cell(self, cell: Position) -> list[Interval]:
289289
# both the time step when it is entering the cell, and the time step when it is leaving the cell.
290290
intervals = [interval for interval in intervals if interval.start_time != interval.end_time]
291291
return intervals
292+
293+
"""
294+
Reserve an agent's path in the grid. Raises an exception if the agent's index is 0, or if a position is
295+
already reserved by a different agent.
296+
"""
297+
def reserve_path(self, node_path: NodePath, agent_index: int):
298+
if agent_index == 0:
299+
raise Exception("Agent index cannot be 0")
300+
301+
for i, node in enumerate(node_path.path):
302+
reservation_finish_time = node.time + 1
303+
if i < len(node_path.path) - 1:
304+
reservation_finish_time = node_path.path[i + 1].time
292305

293-
show_animation = True
306+
self.reserve_position(node.position, agent_index, Interval(node.time, reservation_finish_time))
307+
308+
"""
309+
Reserve a position for the provided agent during the provided time interval.
310+
Raises an exception if the agent's index is 0, or if the position is already reserved by a different agent during the interval.
311+
"""
312+
def reserve_position(self, position: Position, agent_index: int, interval: Interval):
313+
if agent_index == 0:
314+
raise Exception("Agent index cannot be 0")
315+
316+
for t in range(interval.start_time, interval.end_time + 1):
317+
current_reserver = self.reservation_matrix[position.x, position.y, t]
318+
if current_reserver not in [0, agent_index]:
319+
raise Exception(
320+
f"Agent {agent_index} tried to reserve a position already reserved by another agent: {position} at time {t}, reserved by {current_reserver}"
321+
)
322+
self.reservation_matrix[position.x, position.y, t] = agent_index
323+
324+
"""
325+
Clears the initial reservation for an agent by clearing reservations at its start position with its index for
326+
from time 0 to the time limit.
327+
"""
328+
def clear_initial_reservation(self, position: Position, agent_index: int):
329+
for t in range(self.time_limit):
330+
if self.reservation_matrix[position.x, position.y, t] == agent_index:
331+
self.reservation_matrix[position.x, position.y, t] = 0
294332

333+
show_animation = True
295334

296335
def main():
297336
grid = Grid(
Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
from dataclasses import dataclass
2+
from functools import total_ordering
3+
import numpy as np
4+
from typing import Sequence
5+
6+
@dataclass(order=True)
7+
class Position:
8+
x: int
9+
y: int
10+
11+
def as_ndarray(self) -> np.ndarray:
12+
return np.array([self.x, self.y])
13+
14+
def __add__(self, other):
15+
if isinstance(other, Position):
16+
return Position(self.x + other.x, self.y + other.y)
17+
raise NotImplementedError(
18+
f"Addition not supported for Position and {type(other)}"
19+
)
20+
21+
def __sub__(self, other):
22+
if isinstance(other, Position):
23+
return Position(self.x - other.x, self.y - other.y)
24+
raise NotImplementedError(
25+
f"Subtraction not supported for Position and {type(other)}"
26+
)
27+
28+
def __hash__(self):
29+
return hash((self.x, self.y))
30+
31+
@dataclass()
32+
# Note: Total_ordering is used instead of adding `order=True` to the @dataclass decorator because
33+
# this class needs to override the __lt__ and __eq__ methods to ignore parent_index. Parent
34+
# index is just used to track the path found by the algorithm, and has no effect on the quality
35+
# of a node.
36+
@total_ordering
37+
class Node:
38+
position: Position
39+
time: int
40+
heuristic: int
41+
parent_index: int
42+
43+
"""
44+
This is what is used to drive node expansion. The node with the lowest value is expanded next.
45+
This comparison prioritizes the node with the lowest cost-to-come (self.time) + cost-to-go (self.heuristic)
46+
"""
47+
def __lt__(self, other: object):
48+
if not isinstance(other, Node):
49+
return NotImplementedError(f"Cannot compare Node with object of type: {type(other)}")
50+
return (self.time + self.heuristic) < (other.time + other.heuristic)
51+
52+
"""
53+
Note: cost and heuristic are not included in eq or hash, since they will always be the same
54+
for a given (position, time) pair. Including either cost or heuristic would be redundant.
55+
"""
56+
def __eq__(self, other: object):
57+
if not isinstance(other, Node):
58+
return NotImplementedError(f"Cannot compare Node with object of type: {type(other)}")
59+
return self.position == other.position and self.time == other.time
60+
61+
def __hash__(self):
62+
return hash((self.position, self.time))
63+
64+
class NodePath:
65+
path: Sequence[Node]
66+
positions_at_time: dict[int, Position]
67+
# Number of nodes expanded while finding this path
68+
expanded_node_count: int
69+
70+
def __init__(self, path: Sequence[Node], expanded_node_count: int):
71+
self.path = path
72+
self.expanded_node_count = expanded_node_count
73+
74+
self.positions_at_time = {}
75+
for i, node in enumerate(path):
76+
reservation_finish_time = node.time + 1
77+
if i < len(path) - 1:
78+
reservation_finish_time = path[i + 1].time
79+
80+
for t in range(node.time, reservation_finish_time):
81+
self.positions_at_time[t] = node.position
82+
83+
"""
84+
Get the position of the path at a given time
85+
"""
86+
def get_position(self, time: int) -> Position | None:
87+
return self.positions_at_time.get(time)
88+
89+
"""
90+
Time stamp of the last node in the path
91+
"""
92+
def goal_reached_time(self) -> int:
93+
return self.path[-1].time
94+
95+
def __repr__(self):
96+
repr_string = ""
97+
for i, node in enumerate(self.path):
98+
repr_string += f"{i}: {node}\n"
99+
return repr_string

0 commit comments

Comments
 (0)