Skip to content

Commit ba91001

Browse files
committed
dataclasses are 🔥
1 parent 1121dbb commit ba91001

File tree

2 files changed

+21
-30
lines changed

2 files changed

+21
-30
lines changed

PathPlanning/TimeBasedPathPlanning/GridWithDynamicObstacles.py

Lines changed: 12 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,18 @@
1+
"""
2+
This file implements a grid with a 3d reservation matrix with dimensions for x, y, and time. There
3+
is also infrastructure to generate dynamic obstacles that move around the grid. The obstacles' paths
4+
are stored in the reservation matrix on creation.
5+
"""
16
import numpy as np
27
import matplotlib.pyplot as plt
38
from enum import Enum
9+
from dataclasses import dataclass
410

11+
@dataclass(order=True)
512
class Position:
613
x: int
714
y: int
815

9-
def __init__(self, x: int, y: int):
10-
self.x = x
11-
self.y = y
12-
1316
def as_ndarray(self) -> np.ndarray:
1417
return np.array([self.x, self.y])
1518

@@ -27,14 +30,6 @@ def __sub__(self, other):
2730
f"Subtraction not supported for Position and {type(other)}"
2831
)
2932

30-
def __eq__(self, other):
31-
if isinstance(other, Position):
32-
return self.x == other.x and self.y == other.y
33-
return False
34-
35-
def __repr__(self):
36-
return f"Position({self.x}, {self.y})"
37-
3833

3934
class ObstacleArrangement(Enum):
4035
# Random obstacle positions and movements
@@ -46,7 +41,7 @@ class ObstacleArrangement(Enum):
4641
class Grid:
4742
# Set in constructor
4843
grid_size: np.ndarray
49-
grid: np.ndarray
44+
reservation_matrix: np.ndarray
5045
obstacle_paths: list[list[Position]] = []
5146
# Obstacles will never occupy these points. Useful to avoid impossible scenarios
5247
obstacle_avoid_points: list[Position] = []
@@ -68,7 +63,7 @@ def __init__(
6863
self.obstacle_avoid_points = obstacle_avoid_points
6964
self.time_limit = time_limit
7065
self.grid_size = grid_size
71-
self.grid = np.zeros((grid_size[0], grid_size[1], self.time_limit))
66+
self.reservation_matrix = np.zeros((grid_size[0], grid_size[1], self.time_limit))
7267

7368
if num_obstacles > self.grid_size[0] * self.grid_size[1]:
7469
raise Exception("Number of obstacles is greater than grid size!")
@@ -83,8 +78,8 @@ def __init__(
8378
for t, position in enumerate(path):
8479
# Reserve old & new position at this time step
8580
if t > 0:
86-
self.grid[path[t - 1].x, path[t - 1].y, t] = obs_idx
87-
self.grid[position.x, position.y, t] = obs_idx
81+
self.reservation_matrix[path[t - 1].x, path[t - 1].y, t] = obs_idx
82+
self.reservation_matrix[position.x, position.y, t] = obs_idx
8883

8984
"""
9085
Generate dynamic obstacles that move around the grid. Initial positions and movements are random
@@ -191,7 +186,7 @@ def valid_position(self, position: Position, t: int) -> bool:
191186
return False
192187

193188
# Check if new position is not occupied at time t
194-
return self.grid[position.x, position.y, t] == 0
189+
return self.reservation_matrix[position.x, position.y, t] == 0
195190

196191
"""
197192
Returns True if the given position is valid at time t and is not in the set of obstacle_avoid_points

PathPlanning/TimeBasedPathPlanning/SpaceTimeAStar.py

Lines changed: 9 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
"""
22
Space-time A* Algorithm
33
This script demonstrates the Space-time A* algorithm for path planning in a grid world with moving obstacles.
4+
This algorithm is different from normal 2D A* in one key way - the cost (often notated as g(n)) is
5+
the number of time steps it took to get to a given node, instead of the number of cells it has
6+
traversed. This ensures the path is time-optimal, while respescting any dynamic obstacles in the environment.
47
58
Reference: https://www.davidsilver.uk/wp-content/uploads/2020/03/coop-path-AIWisdom.pdf
69
"""
@@ -15,31 +18,27 @@
1518
import heapq
1619
from collections.abc import Generator
1720
import random
18-
from functools import total_ordering
21+
from dataclasses import dataclass
22+
1923

2024
# Seed randomness for reproducibility
2125
RANDOM_SEED = 50
2226
random.seed(RANDOM_SEED)
2327
np.random.seed(RANDOM_SEED)
2428

25-
@total_ordering # so the linter will chill about not implementing __gt__, __ge__, etc
29+
@dataclass(order=True)
2630
class Node:
2731
position: Position
2832
time: int
2933
heuristic: int
3034
parent_index: int
3135

32-
def __init__(
33-
self, position: Position, time: int, heuristic: int, parent_index: int
34-
):
35-
self.position = position
36-
self.time = time
37-
self.heuristic = heuristic
38-
self.parent_index = parent_index
39-
4036
"""
4137
This is what is used to drive node expansion. The node with the lowest value is expanded next.
4238
This comparison prioritizes the node with the lowest cost-to-come (self.time) + cost-to-go (self.heuristic)
39+
40+
This an __eq__ are overridden because we don't care about parent_index when comparing these; two nodes
41+
with the same position, time, and heuristic are equivalent from the search algorithm's perspective.
4342
"""
4443
def __lt__(self, other: object):
4544
if not isinstance(other, Node):
@@ -51,9 +50,6 @@ def __eq__(self, other: object):
5150
return NotImplementedError(f"Cannot compare Node with object of type: {type(other)}")
5251
return self.position == other.position and self.time == other.time
5352

54-
def __repr__(self):
55-
return f"Node(position={self.position}, time={self.time}, heuristic={self.heuristic}, parent_index={self.parent_index})"
56-
5753

5854
class NodePath:
5955
path: list[Node]

0 commit comments

Comments
 (0)