diff --git a/gennav/planners/graph_search_algorithms/grassfire.py b/gennav/planners/graph_search_algorithms/grassfire.py new file mode 100644 index 0000000..8486f40 --- /dev/null +++ b/gennav/planners/graph_search_algorithms/grassfire.py @@ -0,0 +1,178 @@ +import numpy as np +import random +import math + +PI = math.pi + +''' +referance: https://nrsyed.com/2017/12/30/animating-the-grassfire-path-planning-algorithm/ +''' +class Grassfire: + '''Class is a container for constants and methods we'll use + to create, modify, and plot a 2D grid of pixels + demonstrating the grassfire path-planning algorithm. + ''' + ''' + the grid is rxc array ,where each element contains + a value corresponding to the 5 constants below + + ''' + START = 0 + DEST = -1 # destination + UNVIS = -2 # unvisited + OBST = -3 # obstacle + PATH = -4 + + # Each of the above cell values is represented by an RGB color + # on the plot. COLOR_VIS refers to visited cells (value > 0). + COLOR_START = np.array([0, 0.75, 0]) + COLOR_DEST = np.array([0.75, 0, 0]) + COLOR_UNVIS = np.array([1, 1, 1]) + COLOR_VIS = np.array([0, 0.5, 1]) + COLOR_OBST = np.array([0, 0, 0]) + COLOR_PATH = np.array([1, 1, 0]) + + + def select_grid(self, rows, cols, obstacleList, Start, end): + '''Return a 2D numpy array representing a grid of randomly placed + obstacles (where the likelihood of any cell being an obstacle + is given by obstacleProb) and randomized start/destination cells. + ''' + global start, dest + start = Start + dest = end + + n=len(obstacleList) + grid = Grassfire.UNVIS * np.ones((rows, cols), dtype=np.int) + i=0 + for i in range(0,n): + grid[obstacleList[i]] = self.OBST + + + # Remove existing start and dest cells, if any. + grid[grid == Grassfire.START] = Grassfire.UNVIS + grid[grid == Grassfire.DEST] = Grassfire.UNVIS + + + grid[start] = Grassfire.START + grid[dest] = Grassfire.DEST + # Randomly set start and destination cells. + #self.set_start_dest(grid) + return grid + + + + def color_grid(self, grid): + '''Return MxNx3 pixel array ("color grid") corresponding to a grid.''' + (rows, cols) = grid.shape + colorGrid = np.zeros((rows, cols, 3), dtype=np.float) + + colorGrid[grid == Grassfire.OBST, :] = Grassfire.COLOR_OBST + colorGrid[grid == Grassfire.UNVIS, :] = Grassfire.COLOR_UNVIS + colorGrid[grid == Grassfire.START, :] = Grassfire.COLOR_START + colorGrid[grid == Grassfire.DEST, :] = Grassfire.COLOR_DEST + colorGrid[grid > Grassfire.START, :] = Grassfire.COLOR_VIS + colorGrid[grid == Grassfire.PATH, :] = Grassfire.COLOR_PATH + return colorGrid + + def reset_grid(self, grid): + '''Reset cells that are not OBST, START, or DEST to UNVIS.''' + cellsToReset = ~((grid == Grassfire.OBST) + (grid == Grassfire.START) + + (grid == Grassfire.DEST)) + grid[cellsToReset] = Grassfire.UNVIS + + def _check_adjacent(self, grid, cell, currentDepth): + '''For given grid, check the cells adjacent to a given + cell. If any have a depth (positive int) greater + than the current depth, update them with the current + depth, where depth represents distance from start cell. + If destination found, return DEST constant; else, return + number of adjacent cells updated. + ''' + (rows, cols) = grid.shape + + # Track how many adjacent cells are updated. + numCellsUpdated = 0 + + # From the current cell, examine, using sin and cos: + # cell to right (col + 1), cell below (row + 1), + # cell to left (col - 1), cell above (row - 1). + for i in range(4): + rowToCheck = cell[0] + int(math.sin((PI/2) * i)) + colToCheck = cell[1] + int(math.cos((PI/2) * i)) + + # Ensure cell is within bounds of grid. + if not (0 <= rowToCheck < rows and 0 <= colToCheck < cols): + continue + # Check if destination found. + elif grid[rowToCheck, colToCheck] == Grassfire.DEST: + return Grassfire.DEST + # If adjacent cell unvisited or depth > currentDepth + 1, + # mark with new depth. + elif (grid[rowToCheck, colToCheck] == Grassfire.UNVIS + or grid[rowToCheck, colToCheck] > currentDepth + 1): + grid[rowToCheck, colToCheck] = currentDepth + 1 + numCellsUpdated += 1 + return numCellsUpdated + + def _backtrack(self, grid, cell, currentDepth): + '''This function is used if the destination is found. Similar + to _check_adjacent(), but returns coordinates of first + surrounding cell whose value matches "currentDepth", ie, + the next cell along the path from destination to start. + ''' + (rows, cols) = grid.shape + + for i in range(4): + rowToCheck = cell[0] + int(math.sin((PI/2) * i)) + colToCheck = cell[1] + int(math.cos((PI/2) * i)) + + if not (0 <= rowToCheck < rows and 0 <= colToCheck < cols): + continue + elif grid[rowToCheck, colToCheck] == currentDepth: + nextCell = (rowToCheck, colToCheck) + grid[nextCell] = Grassfire.PATH + return nextCell + + def find_path(self, grid): + '''Execute grassfire algorithm by spreading from the start cell out. + If destination is found, use _backtrack() to trace path from + destination back to start. Returns a generator function to + allow stepping through and animating the algorithm. + ''' + nonlocalDict = {'grid': grid} + def find_path_generator(): + grid = nonlocalDict['grid'] + depth = 0 + destFound = False + cellsExhausted = False + + while (not destFound) and (not cellsExhausted): + numCellsModified = 0 + depthIndices = np.where(grid == depth) + matchingCells = list(zip(depthIndices[0], depthIndices[1])) + + for cell in matchingCells: + adjacentVal = self._check_adjacent(grid, cell, depth) + if adjacentVal == Grassfire.DEST: + destFound = True + break + else: + numCellsModified += adjacentVal + + if numCellsModified == 0: + cellsExhausted = True + elif not destFound: + depth += 1 + yield + + if destFound: + destCell = np.where(grid == Grassfire.DEST) + backtrackCell = (destCell[0].item(), destCell[1].item()) + while depth > 0: + # Work backwards until return to start cell. + nextCell = self._backtrack(grid, backtrackCell, depth) + backtrackCell = nextCell + depth -= 1 + yield + return find_path_generator \ No newline at end of file diff --git a/tests/__init__.py b/tests/__init__.py index 7e33acd..1ec42a8 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -1,3 +1,4 @@ from tests.test_planner import astar_test # noqa: F401 from tests.test_planner import prm_test # noqa: F401 from tests.test_planner import rrt_test # noqa: F401 +from tests.test_planner import grassfire_test diff --git a/tests/test_planner/grassfire_test.py b/tests/test_planner/grassfire_test.py new file mode 100644 index 0000000..18007a4 --- /dev/null +++ b/tests/test_planner/grassfire_test.py @@ -0,0 +1,81 @@ +from __future__ import division +import math +''' +reference:referance: https://nrsyed.com/2017/12/30/animating-the-grassfire-path-planning-algorithm/ +''' +import numpy as np +import matplotlib.pyplot as plt +import matplotlib.animation as animation +from grassfire import Grassfire + +# Initialize grid rows, columns, and obstacle probability. +rows = 8 +cols = 8 +obstlist=[(5,5),(3,4),(4,4),(3,3),(2,2),(2,7),(6,6)] +start=(1,1) +end=(6,7) + +# Instantiate Grassfire class. Initialize a grid and colorGrid. +Grassfire = Grassfire() +grid = Grassfire.select_grid(rows, cols, obstlist, start, end) + +colorGrid = Grassfire.color_grid(grid) + +# Initialize figure, imshow object, and axis. +fig = plt.figure() +gridPlot = plt.imshow(colorGrid, interpolation='nearest') +ax = gridPlot._axes +ax.grid(visible=True, ls='solid', color='k', lw=1.5) +ax.set_xticklabels([]) +ax.set_yticklabels([]) + +# Initialize text annotations to display obstacle probability, rows, cols. +obstText = ax.annotate('', (0.15, 0.01), xycoords='figure fraction') +colText = ax.annotate('', (0.15, 0.04), xycoords='figure fraction') +rowText = ax.annotate('', (0.15, 0.07), xycoords='figure fraction') + +def set_axis_properties(rows, cols): + '''Set axis/imshow plot properties based on number of rows, cols.''' + ax.set_xlim((0, cols)) + ax.set_ylim((rows, 0)) + ax.set_xticks(np.arange(0, cols+1, 1)) + ax.set_yticks(np.arange(0, rows+1, 1)) + gridPlot.set_extent([0, cols, 0, rows]) + +def update_annotations(rows, cols): + '''Update annotations with obstacle probability, rows, cols.''' + + colText.set_text('Rows: {:d}'.format(rows)) + rowText.set_text('Columns: {:d}'.format(cols)) + +set_axis_properties(rows, cols) +update_annotations(rows, cols) + +# Disable default figure key bindings. +fig.canvas.mpl_disconnect(fig.canvas.manager.key_press_handler_id) + + + +# Functions init_anim() and update_anim() are for use with FuncAnimation. +def init_anim(): + '''Plot grid in its initial state by resetting "grid".''' + Grassfire.reset_grid(grid) + colorGrid = Grassfire.color_grid(grid) + gridPlot.set_data(colorGrid) + +def update_anim(dummyFrameArgument): + '''Update plot based on values in "grid" ("grid" is updated + by the generator--this function simply passes "grid" to + the color_grid() function to get an image array). + ''' + colorGrid = Grassfire.color_grid(grid) + gridPlot.set_data(colorGrid) + +# Create animation object. Supply generator function to frames. +ani = animation.FuncAnimation(fig, update_anim, + init_func=init_anim, frames=Grassfire.find_path(grid), + repeat=False, interval=150) + +# Turn on interactive plotting and show figure. +plt.ion() +plt.show(block=True) \ No newline at end of file