2020import random
2121from dataclasses import dataclass
2222from functools import total_ordering
23-
23+ import time
2424
2525# Seed randomness for reproducibility
2626RANDOM_SEED = 50
@@ -53,6 +53,8 @@ def __eq__(self, other: object):
5353 return NotImplementedError (f"Cannot compare Node with object of type: { type (other )} " )
5454 return self .position == other .position and self .time == other .time
5555
56+ def __hash__ (self ):
57+ return hash ((self .position , self .time , self .heuristic ))
5658
5759class NodePath :
5860 path : list [Node ]
@@ -98,7 +100,8 @@ def plan(self, verbose: bool = False) -> NodePath:
98100 open_set , Node (self .start , 0 , self .calculate_heuristic (self .start ), - 1 )
99101 )
100102
101- expanded_set : list [Node ] = []
103+ expanded_list : list [Node ] = []
104+ expanded_set : set [Node ] = set ()
102105 while open_set :
103106 expanded_node : Node = heapq .heappop (open_set )
104107 if verbose :
@@ -110,23 +113,24 @@ def plan(self, verbose: bool = False) -> NodePath:
110113 continue
111114
112115 if expanded_node .position == self .goal :
113- print (f"Found path to goal after { len (expanded_set )} expansions" )
116+ print (f"Found path to goal after { len (expanded_list )} expansions" )
114117 path = []
115118 path_walker : Node = expanded_node
116119 while True :
117120 path .append (path_walker )
118121 if path_walker .parent_index == - 1 :
119122 break
120- path_walker = expanded_set [path_walker .parent_index ]
123+ path_walker = expanded_list [path_walker .parent_index ]
121124
122125 # reverse path so it goes start -> goal
123126 path .reverse ()
124127 return NodePath (path )
125128
126- expanded_idx = len (expanded_set )
127- expanded_set .append (expanded_node )
129+ expanded_idx = len (expanded_list )
130+ expanded_list .append (expanded_node )
131+ expanded_set .add (expanded_node )
128132
129- for child in self .generate_successors (expanded_node , expanded_idx , verbose ):
133+ for child in self .generate_successors (expanded_node , expanded_idx , verbose , expanded_set ):
130134 heapq .heappush (open_set , child )
131135
132136 raise Exception ("No path found" )
@@ -135,7 +139,7 @@ def plan(self, verbose: bool = False) -> NodePath:
135139 Generate possible successors of the provided `parent_node`
136140 """
137141 def generate_successors (
138- self , parent_node : Node , parent_node_idx : int , verbose : bool
142+ self , parent_node : Node , parent_node_idx : int , verbose : bool , expanded_set : set [ Node ]
139143 ) -> Generator [Node , None , None ]:
140144 diffs = [
141145 Position (0 , 0 ),
@@ -146,13 +150,17 @@ def generate_successors(
146150 ]
147151 for diff in diffs :
148152 new_pos = parent_node .position + diff
153+ new_node = Node (
154+ new_pos ,
155+ parent_node .time + 1 ,
156+ self .calculate_heuristic (new_pos ),
157+ parent_node_idx ,
158+ )
159+
160+ if new_node in expanded_set :
161+ continue
162+
149163 if self .grid .valid_position (new_pos , parent_node .time + 1 ):
150- new_node = Node (
151- new_pos ,
152- parent_node .time + 1 ,
153- self .calculate_heuristic (new_pos ),
154- parent_node_idx ,
155- )
156164 if verbose :
157165 print ("\t New successor node: " , new_node )
158166 yield new_node
@@ -166,9 +174,12 @@ def calculate_heuristic(self, position) -> int:
166174verbose = False
167175
168176def main ():
169- start = Position (1 , 11 )
177+ start = Position (1 , 5 )
170178 goal = Position (19 , 19 )
171179 grid_side_length = 21
180+
181+ start_time = time .time ()
182+
172183 grid = Grid (
173184 np .array ([grid_side_length , grid_side_length ]),
174185 num_obstacles = 40 ,
@@ -179,6 +190,9 @@ def main():
179190 planner = SpaceTimeAStar (grid , start , goal )
180191 path = planner .plan (verbose )
181192
193+ runtime = time .time () - start_time
194+ print (f"Planning took: { runtime :.5f} seconds" )
195+
182196 if verbose :
183197 print (f"Path: { path } " )
184198
0 commit comments