Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
178 changes: 178 additions & 0 deletions gennav/planners/graph_search_algorithms/grassfire.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,178 @@
import numpy as np
import random
import math

PI = math.pi

'''
referance: https://nrsyed.com/2017/12/30/animating-the-grassfire-path-planning-algorithm/
'''
class Grassfire:
'''Class is a container for constants and methods we'll use
to create, modify, and plot a 2D grid of pixels
demonstrating the grassfire path-planning algorithm.
'''
'''
the grid is rxc array ,where each element contains
a value corresponding to the 5 constants below

'''
START = 0
DEST = -1 # destination
UNVIS = -2 # unvisited
OBST = -3 # obstacle
PATH = -4

# Each of the above cell values is represented by an RGB color
# on the plot. COLOR_VIS refers to visited cells (value > 0).
COLOR_START = np.array([0, 0.75, 0])
COLOR_DEST = np.array([0.75, 0, 0])
COLOR_UNVIS = np.array([1, 1, 1])
COLOR_VIS = np.array([0, 0.5, 1])
COLOR_OBST = np.array([0, 0, 0])
COLOR_PATH = np.array([1, 1, 0])


def select_grid(self, rows, cols, obstacleList, Start, end):
'''Return a 2D numpy array representing a grid of randomly placed
obstacles (where the likelihood of any cell being an obstacle
is given by obstacleProb) and randomized start/destination cells.
'''
global start, dest
start = Start
dest = end

n=len(obstacleList)
grid = Grassfire.UNVIS * np.ones((rows, cols), dtype=np.int)
i=0
for i in range(0,n):
grid[obstacleList[i]] = self.OBST


# Remove existing start and dest cells, if any.
grid[grid == Grassfire.START] = Grassfire.UNVIS
grid[grid == Grassfire.DEST] = Grassfire.UNVIS


grid[start] = Grassfire.START
grid[dest] = Grassfire.DEST
# Randomly set start and destination cells.
#self.set_start_dest(grid)
return grid



def color_grid(self, grid):
'''Return MxNx3 pixel array ("color grid") corresponding to a grid.'''
(rows, cols) = grid.shape
colorGrid = np.zeros((rows, cols, 3), dtype=np.float)

colorGrid[grid == Grassfire.OBST, :] = Grassfire.COLOR_OBST
colorGrid[grid == Grassfire.UNVIS, :] = Grassfire.COLOR_UNVIS
colorGrid[grid == Grassfire.START, :] = Grassfire.COLOR_START
colorGrid[grid == Grassfire.DEST, :] = Grassfire.COLOR_DEST
colorGrid[grid > Grassfire.START, :] = Grassfire.COLOR_VIS
colorGrid[grid == Grassfire.PATH, :] = Grassfire.COLOR_PATH
return colorGrid

def reset_grid(self, grid):
'''Reset cells that are not OBST, START, or DEST to UNVIS.'''
cellsToReset = ~((grid == Grassfire.OBST) + (grid == Grassfire.START)
+ (grid == Grassfire.DEST))
grid[cellsToReset] = Grassfire.UNVIS

def _check_adjacent(self, grid, cell, currentDepth):
'''For given grid, check the cells adjacent to a given
cell. If any have a depth (positive int) greater
than the current depth, update them with the current
depth, where depth represents distance from start cell.
If destination found, return DEST constant; else, return
number of adjacent cells updated.
'''
(rows, cols) = grid.shape

# Track how many adjacent cells are updated.
numCellsUpdated = 0

# From the current cell, examine, using sin and cos:
# cell to right (col + 1), cell below (row + 1),
# cell to left (col - 1), cell above (row - 1).
for i in range(4):
rowToCheck = cell[0] + int(math.sin((PI/2) * i))
colToCheck = cell[1] + int(math.cos((PI/2) * i))

# Ensure cell is within bounds of grid.
if not (0 <= rowToCheck < rows and 0 <= colToCheck < cols):
continue
# Check if destination found.
elif grid[rowToCheck, colToCheck] == Grassfire.DEST:
return Grassfire.DEST
# If adjacent cell unvisited or depth > currentDepth + 1,
# mark with new depth.
elif (grid[rowToCheck, colToCheck] == Grassfire.UNVIS
or grid[rowToCheck, colToCheck] > currentDepth + 1):
grid[rowToCheck, colToCheck] = currentDepth + 1
numCellsUpdated += 1
return numCellsUpdated

def _backtrack(self, grid, cell, currentDepth):
'''This function is used if the destination is found. Similar
to _check_adjacent(), but returns coordinates of first
surrounding cell whose value matches "currentDepth", ie,
the next cell along the path from destination to start.
'''
(rows, cols) = grid.shape

for i in range(4):
rowToCheck = cell[0] + int(math.sin((PI/2) * i))
colToCheck = cell[1] + int(math.cos((PI/2) * i))

if not (0 <= rowToCheck < rows and 0 <= colToCheck < cols):
continue
elif grid[rowToCheck, colToCheck] == currentDepth:
nextCell = (rowToCheck, colToCheck)
grid[nextCell] = Grassfire.PATH
return nextCell

def find_path(self, grid):
'''Execute grassfire algorithm by spreading from the start cell out.
If destination is found, use _backtrack() to trace path from
destination back to start. Returns a generator function to
allow stepping through and animating the algorithm.
'''
nonlocalDict = {'grid': grid}
def find_path_generator():
grid = nonlocalDict['grid']
depth = 0
destFound = False
cellsExhausted = False

while (not destFound) and (not cellsExhausted):
numCellsModified = 0
depthIndices = np.where(grid == depth)
matchingCells = list(zip(depthIndices[0], depthIndices[1]))

for cell in matchingCells:
adjacentVal = self._check_adjacent(grid, cell, depth)
if adjacentVal == Grassfire.DEST:
destFound = True
break
else:
numCellsModified += adjacentVal

if numCellsModified == 0:
cellsExhausted = True
elif not destFound:
depth += 1
yield

if destFound:
destCell = np.where(grid == Grassfire.DEST)
backtrackCell = (destCell[0].item(), destCell[1].item())
while depth > 0:
# Work backwards until return to start cell.
nextCell = self._backtrack(grid, backtrackCell, depth)
backtrackCell = nextCell
depth -= 1
yield
return find_path_generator
1 change: 1 addition & 0 deletions tests/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from tests.test_planner import astar_test # noqa: F401
from tests.test_planner import prm_test # noqa: F401
from tests.test_planner import rrt_test # noqa: F401
from tests.test_planner import grassfire_test
81 changes: 81 additions & 0 deletions tests/test_planner/grassfire_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
from __future__ import division
import math
'''
reference:referance: https://nrsyed.com/2017/12/30/animating-the-grassfire-path-planning-algorithm/
'''
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from grassfire import Grassfire

# Initialize grid rows, columns, and obstacle probability.
rows = 8
cols = 8
obstlist=[(5,5),(3,4),(4,4),(3,3),(2,2),(2,7),(6,6)]
start=(1,1)
end=(6,7)

# Instantiate Grassfire class. Initialize a grid and colorGrid.
Grassfire = Grassfire()
grid = Grassfire.select_grid(rows, cols, obstlist, start, end)

colorGrid = Grassfire.color_grid(grid)

# Initialize figure, imshow object, and axis.
fig = plt.figure()
gridPlot = plt.imshow(colorGrid, interpolation='nearest')
ax = gridPlot._axes
ax.grid(visible=True, ls='solid', color='k', lw=1.5)
ax.set_xticklabels([])
ax.set_yticklabels([])

# Initialize text annotations to display obstacle probability, rows, cols.
obstText = ax.annotate('', (0.15, 0.01), xycoords='figure fraction')
colText = ax.annotate('', (0.15, 0.04), xycoords='figure fraction')
rowText = ax.annotate('', (0.15, 0.07), xycoords='figure fraction')

def set_axis_properties(rows, cols):
'''Set axis/imshow plot properties based on number of rows, cols.'''
ax.set_xlim((0, cols))
ax.set_ylim((rows, 0))
ax.set_xticks(np.arange(0, cols+1, 1))
ax.set_yticks(np.arange(0, rows+1, 1))
gridPlot.set_extent([0, cols, 0, rows])

def update_annotations(rows, cols):
'''Update annotations with obstacle probability, rows, cols.'''

colText.set_text('Rows: {:d}'.format(rows))
rowText.set_text('Columns: {:d}'.format(cols))

set_axis_properties(rows, cols)
update_annotations(rows, cols)

# Disable default figure key bindings.
fig.canvas.mpl_disconnect(fig.canvas.manager.key_press_handler_id)



# Functions init_anim() and update_anim() are for use with FuncAnimation.
def init_anim():
'''Plot grid in its initial state by resetting "grid".'''
Grassfire.reset_grid(grid)
colorGrid = Grassfire.color_grid(grid)
gridPlot.set_data(colorGrid)

def update_anim(dummyFrameArgument):
'''Update plot based on values in "grid" ("grid" is updated
by the generator--this function simply passes "grid" to
the color_grid() function to get an image array).
'''
colorGrid = Grassfire.color_grid(grid)
gridPlot.set_data(colorGrid)

# Create animation object. Supply generator function to frames.
ani = animation.FuncAnimation(fig, update_anim,
init_func=init_anim, frames=Grassfire.find_path(grid),
repeat=False, interval=150)

# Turn on interactive plotting and show figure.
plt.ion()
plt.show(block=True)