diff --git a/README.md b/README.md index 2f09183..09fbc21 100644 --- a/README.md +++ b/README.md @@ -54,7 +54,7 @@ Note that the environment have been left blank empty here, they should be update For more details you can refer to our [documentation](https://gennav.readthedocs.io/en/latest/index.html). -### ROS Integration +## ROS Integration If you wish to use gennav in a ROS based stack, check out [gennav_ros](https://github.com/ERC-BPGC/gennav_ros). diff --git a/gennav/planners/__init__.py b/gennav/planners/__init__.py index eb0419a..405c778 100644 --- a/gennav/planners/__init__.py +++ b/gennav/planners/__init__.py @@ -1,3 +1,4 @@ from gennav.planners.base import Planner # noqa: F401 -from gennav.planners.prm import PRM # noqa: F401 +from gennav.planners.prm.prm import PRM # noqa: F401 from gennav.planners.rrt.rrt import RRT # noqa: F401 +from gennav.planners.dstar.dstar import DStar # noqa: F401 diff --git a/gennav/planners/dstar/__init__.py b/gennav/planners/dstar/__init__.py new file mode 100644 index 0000000..140fe2f --- /dev/null +++ b/gennav/planners/dstar/__init__.py @@ -0,0 +1 @@ +from gennav.planners.dstar.dstar import DStar # noqa: F401 diff --git a/gennav/planners/dstar/dstar.py b/gennav/planners/dstar/dstar.py new file mode 100644 index 0000000..be22b49 --- /dev/null +++ b/gennav/planners/dstar/dstar.py @@ -0,0 +1,407 @@ +from gennav.planners.base import Planner +from gennav.utils import Trajectory +from gennav.utils.geometry import compute_distance +from gennav.utils.graph import Graph +from gennav.utils.common import Node +from gennav.utils.custom_exceptions import ( + InvalidGoalState, + InvalidStartState, + PathNotFound, +) + + +def c(node1, node2): + if node1.t == float("inf") or node1.t == float("inf"): + return float("inf") + else: + return compute_distance(node1.state.position, node2.state.position) + + +def insert(node, h_): + global open_ + if node.tag == "NEW": + node.k = h_ + node.h = h_ + node.tag = "OPEN" + open_.append(node) + if node not in open_: + node.k = min(node.h, h_) + node.h = h_ + closed.remove(node) + node.tag = "OPEN" + open_.append(node) + else: + node.k = min(node.k, h_) + node.h = h_ + + +class NodeDstar(Node): + """ + Node class for Dstar Node + """ + + # Initialize the class + def __init__(self, **data): + Node.__init__(self, **data) + self.h = None + self.k = None + self.t = None + self.tag = "NEW" + + # Sort nodes + def __lt__(self, other): + return self.k < other.k + + def __le__(self, other): + return self.k <= other.k + + def __ge__(self, other): + return self.k >= other.k + + def __gt__(self, other): + return self.k > other.k + + # Compare nodes + def __eq__(self, other): + if not isinstance(other, NodeDstar): + return False + return self.state.position == other.state.position + + def __hash__(self): + return hash(self.state) + + +class DStar(Planner): + """DStar Class. + + Args: + sampler (gennav.utils.sampler.Sampler): sampler to get random states + r (float): maximum radius to look for neighbours + n (int): total no. of nodes to be sampled in sample_area + """ + + def __init__(self, sampler, r, n): + super(DStar, self) + self.sampler = sampler + self.r = r + self.n = n + self.flag = 0 + + def construct(self, env): + """Constructs DStar graph. + + Args: + env (gennav.envs.Environment): Base class for an envrionment. + + Returns: + graph (gennav.utils.graph): A dict where the keys correspond to nodes and + the values for each key is a list of the neighbour nodes + """ + nodes = [] + graph = Graph() + i = 0 + + # samples points from the sample space until n points + # outside obstacles are obtained + while i < self.n: + sample = self.sampler() + if not env.get_status(sample): + continue + else: + i += 1 + node = NodeDstar(state=sample) + nodes.append(node) + + # finds neighbours for each node in a fixed radius r + for node1 in nodes: + for node2 in nodes: + if node1 != node2: + dist = compute_distance(node1.state.position, node2.state.position) + if dist < self.r: + if env.get_traj_status(Trajectory([node1.state, node2.state])): + if node1 not in graph.nodes: + graph.add_node(node1) + + if node2 not in graph.nodes: + graph.add_node(node2) + + if ( + node2 not in graph.edges[node1] + and node1 not in graph.edges[node2] + ): + graph.add_edge( + node1, node2, + ) + + return graph + + def plan(self, start, goal, env): + """Constructs a graph avoiding obstacles and then plans path from start to goal within the graph. + + Args: + start (gennav.utils.RobotState): tuple with start point coordinates. + goal (gennav.utils.RobotState): tuple with end point coordinates. + env (gennav.envs.Environment): Base class for an envrionment. + Returns: + gennav.utils.Trajectory: The planned path as trajectory + + """ + # construct graph + global graph, traj + graph = self.construct(env) + # find collision free point in graph closest to start_point + min_dist = float("inf") + for node in graph.nodes: + dist = compute_distance(node.state.position, start.position) + traj = Trajectory([node.state, start]) + if dist < min_dist and (env.get_traj_status(traj)): + min_dist = dist + start_node = node + # find collision free point in graph closest to end_point + min_dist = float("inf") + for node in graph.nodes: + dist = compute_distance(node.state.position, goal.position) + traj = Trajectory([node.state, goal]) + if dist < min_dist and (env.get_traj_status(traj)): + min_dist = dist + goal_node = node + global open_, closed + open_ = [] + closed = [] + goal_node.h = 0 + goal_node.k = 0 + open_.append(goal_node) + while len(open_) > 0: + open_.sort() + current_node = open_.pop(0) + current_node.tag = "CLOSED" + closed.append(current_node) + if current_node.state.position == start_node.state.position: + path = [] + path.append(start) + # while current_node.parent is not None: + # print 1 + # path.append(current_node.state) + # h_=float("inf") + # neighbours=graph.edges[current_node] + # for neighbour in neighbours: + # if neighbour.h 0: + open_.sort() + current_node = open_.pop(0) + current_node.tag = "CLOSED" + closed.append(current_node) + if current_node.k >= start_node.h and start_node.tag == "CLOSED": + print 7 + current_node = start_node + path = [] + path.append(start) + # while current_node.parent is not None: + # print 1 + # path.append(current_node.state) + # h_ = float("inf") + # neighbours = graph.edges[current_node] + # for neighbour in neighbours: + # if neighbour.h < h_: + # node_ = neighbour + # h_ = neighbour.h + # current_node = node_ + i = 0 + while current_node.parent is not None: + path.append(current_node.state) + current_node = current_node.parent + if i > 40: + traj = Trajectory(path) + print 88 + return traj + i += 1 + path.append(goal_node.state) + path.append(goal) + traj = Trajectory(path) + print 9 + return traj + if current_node.k < current_node.h and current_node.k is not None: + print 10 + for neighbour in graph.edges[current_node]: + if ( + neighbour.tag != "NEW" + and neighbour.h <= current_node.k + and current_node.h > neighbour.h + c(current_node, neighbour) + ): + current_node.parent = neighbour + current_node.h = neighbour.h + c(current_node, neighbour) + if current_node.k == current_node.h and current_node.k is not None: + print 11 + for neighbour in graph.edges[current_node]: + if ( + neighbour.tag == "NEW" + or ( + neighbour.parent == current_node + and neighbour.h + != current_node.h + c(current_node, neighbour) + ) + or ( + neighbour.parent != current_node + and neighbour.h + > current_node.h + c(current_node, neighbour) + ) + ): + neighbour.parent = current_node + insert(neighbour, current_node.h + c(current_node, neighbour)) + elif current_node.k is not None: + print 12 + for neighbour in graph.edges[current_node]: + if neighbour.tag == "NEW" or ( + neighbour.parent == current_node + and neighbour.h != current_node.h + c(current_node, neighbour) + ): + neighbour.parent = current_node + insert(neighbour, current_node.h + c(current_node, neighbour)) + elif ( + neighbour.parent != current_node + and neighbour.h > current_node.h + c(current_node, neighbour) + ): + insert(current_node, current_node.h) + + elif ( + neighbour.parent != current_node + and current_node.h > neighbour.h + c(current_node, neighbour) + and (neighbour.tag == "CLOSED") + and neighbour.h > current_node.k + ): + insert(neighbour, neighbour.h) + + path = [start] + traj = Trajectory(path) + raise PathNotFound(traj, message="Path contains only one state") diff --git a/gennav/utils/graph.py b/gennav/utils/graph.py index 57b6769..5e89bf4 100644 --- a/gennav/utils/graph.py +++ b/gennav/utils/graph.py @@ -1,5 +1,5 @@ -from collections import defaultdict -from math import sqrt +from gennav.utils import RobotState +from .geometry import compute_distance class Graph: @@ -8,7 +8,7 @@ class Graph: def __init__(self): self.nodes = set() - self.edges = defaultdict(list) + self.edges = {} self.distances = {} def add_node(self, node): @@ -18,6 +18,7 @@ def add_node(self, node): node (gennav.utils.RobotState):to be added to the set of nodes. """ self.nodes.add(node) + self.edges[node] = [] def add_edge( self, node1, node2, @@ -30,7 +31,14 @@ def add_edge( """ self.edges[node1].append(node2) self.edges[node2].append(node1) - self.distances[(node1, node2)] = self.calc_dist(node1, node2) + if isinstance(node1, RobotState): + self.distances[(node1, node2)] = compute_distance( + node1.position, node2.position + ) + else: + self.distances[(node1, node2)] = compute_distance( + node1.state.position, node2.state.position + ) def del_edge(self, node1, node2): """Deletes edge connecting two nodes to the graph. @@ -40,7 +48,7 @@ def del_edge(self, node1, node2): node2 (gennav.utils.RobotState): other end of the edge. """ - if len(self.edges[node1]) == 0: + if node1 not in self.edges: raise ValueError("Edge does not exist.") else: if len(self.edges[node1]) == 1: @@ -50,27 +58,10 @@ def del_edge(self, node1, node2): del self.distances[(node1, node2)] - if len(self.edges[node2]) == 0: + if node1 not in self.edges: raise ValueError("Edge does not exist.") else: if len(self.edges[node2]) == 1: del self.edges[node2] else: self.edges[node2].remove(node1) - - def calc_dist(self, node1, node2): - """Calculates distance between two nodes. - - Args: - node1 (gennav.utils.RobotState): one end of the edge. - node2 (gennav.utils.RobotState): other end of the edge. - - Returns: - dist (float): distance between two nodes. - """ - self.dist = sqrt( - (node1.position.x - node2.position.x) ** 2 - + (node1.position.y - node2.position.y) ** 2 - + (node1.position.z - node2.position.z) ** 2 - ) - return self.dist diff --git a/gennav/utils/visualisation.py b/gennav/utils/visualisation.py index 538fa7e..cda3fcc 100644 --- a/gennav/utils/visualisation.py +++ b/gennav/utils/visualisation.py @@ -1,6 +1,7 @@ from descartes import PolygonPatch from matplotlib import pyplot as plt from shapely.geometry import Polygon +from gennav.utils import RobotState def visualize_graph(graph, env): @@ -15,13 +16,27 @@ def visualize_graph(graph, env): # Clear the figure plt.clf() # Plot each edge of the tree + chk = False for node in graph.nodes: - for neighbour in graph.edges[node]: - plt.plot( - [node.position.x, neighbour.position.x], - [node.position.y, neighbour.position.y], - color="red", - ) + if isinstance(node, RobotState): + chk = True + break + if chk: + for node in graph.nodes: + for neighbour in graph.edges[node]: + plt.plot( + [node.position.x, neighbour.position.x], + [node.position.y, neighbour.position.y], + color="red", + ) + else: + for node in graph.nodes: + for neighbour in graph.edges[node]: + plt.plot( + [node.state.position.x, neighbour.state.position.x], + [node.state.position.y, neighbour.state.position.y], + color="red", + ) # Draw the obstacles in the environment for obstacle in obstacle_list: diff --git a/tests/test_planners/dstar_test.py b/tests/test_planners/dstar_test.py new file mode 100644 index 0000000..be20592 --- /dev/null +++ b/tests/test_planners/dstar_test.py @@ -0,0 +1,38 @@ +from gennav.utils.graph import Graph +from gennav.planners.dstar.dstar import DStar +from gennav.envs import PolygonEnv +from gennav.utils import RobotState +from gennav.utils.geometry import Point +from gennav.utils.samplers import UniformRectSampler + + +def test_dstar(): + obstacles = [[(8, 5), (7, 8), (2, 9), (3, 5)], [(3, 3), (3, 5), (5, 5), (5, 3)]] + + sampler = UniformRectSampler(-5, 15, -5, 15) + poly = PolygonEnv() + start = RobotState(position=Point(0, 0)) + goal = RobotState(position=Point(12, 10)) + my_tree = DStar(sampler=sampler, r=3, n=75) + + poly.update(obstacles) + path = my_tree.plan(start, goal, poly) + from gennav.envs.common import visualize_path + + visualize_path(path, poly) + obstacles = [ + [(8, 5), (7, 8), (2, 9), (3, 5)], + [(3, 3), (3, 5), (5, 5), (5, 3)], + [(10, 8), (12, 8), (11, 6)], + ] + + # obstacles = [[(10, 7), (9, 10), (4, 11), (5, 7)], [(5, 5), (5, 7), (7, 7), (7, 5)],[(9,8),(13,8),(11,5)]] + + poly.update(obstacles) + path_new = my_tree.replan(start, goal, poly) + from gennav.envs.common import visualize_path + + visualize_path(path_new, poly) + + +test_dstar() diff --git a/tests/test_planners/prm_test.py b/tests/test_planners/prm_test.py index 28b6e2f..c9b52c5 100644 --- a/tests/test_planners/prm_test.py +++ b/tests/test_planners/prm_test.py @@ -42,7 +42,7 @@ def test_prm_construct(): ], ] - sampler = UniformRectSampler(-5, -5, 15, 15) + sampler = UniformRectSampler(-5, 15, -5, 15) poly = PolygonEnv() my_tree = PRM(sampler=sampler, r=5, n=50)