22from scipy .sparse import csr_matrix
33from scipy .sparse .csgraph import dijkstra
44
5+
56class gridworld :
67 """A class for making gridworlds"""
78 def __init__ (self , image , targetx , targety ):
@@ -25,6 +26,8 @@ def __init__(self, image, targetx, targety):
2526
2627
2728 def set_vals (self ):
29+ # Setup function to initialize all necessary
30+ # data
2831 row_obs , col_obs = np .where (self .image == 0 )
2932 row_free , col_free = np .where (self .image != 0 )
3033 self .obstacles = [row_obs , col_obs ]
@@ -125,24 +128,28 @@ def set_vals(self):
125128
126129
def get_graph(self):
    # Returns the stored graph (self.G) unchanged,
    # together with only those entries of self.W
    # that are different from zero (selected with
    # a boolean mask).
    nonzero_mask = self.W != 0
    return self.G, self.W[nonzero_mask]
131135
132136
def get_graph_inv(self):
    # Returns the transpose of the graph (self.G.T)
    # and the transpose of the weights (self.W.T).
    # Unlike get_graph, the weights are not filtered.
    return self.G.T, self.W.T
137142
138143
def val_2_image(self, val):
    # Builds an (n_row, n_col) image that is zero
    # everywhere except at the free-space indices
    # (self.freespace[0], self.freespace[1]), which
    # receive val.
    shape = (self.n_row, self.n_col)
    free_rows = self.freespace[0]
    free_cols = self.freespace[1]
    image = np.zeros(shape)
    image[free_rows, free_cols] = val
    return image
143149
144150
145151 def get_value_prior (self ):
152+ # Returns value prior for gridworld
146153 s_map_col , s_map_row = np .meshgrid (np .arange (0 ,self .n_col ),
147154 np .arange (0 , self .n_row ))
148155 im = np .sqrt (np .square (s_map_col - self .targety )
@@ -151,30 +158,37 @@ def get_value_prior(self):
151158
152159
def get_reward_prior(self):
    # Returns an (n_row, n_col) float image filled
    # with -1, except for the cell indexed by
    # [targetx, targety], which is set to 10.
    reward = np.full((self.n_row, self.n_col), -1.0)
    reward[self.targetx, self.targety] = 10
    return reward
157165
158166
def t_get_reward_prior(self):
    # Reward prior variant used for dataset
    # generation: an (n_row, n_col) image of zeros
    # with 10 at the cell indexed by
    # [targetx, targety].
    grid_shape = (self.n_row, self.n_col)
    reward = np.zeros(grid_shape)
    reward[self.targetx, self.targety] = 10
    return reward
163173
164174
def get_state_image(self, row, col):
    # Returns an (n_row, n_col) image that is zero
    # everywhere except at [row, col], which is 1.
    grid_shape = (self.n_row, self.n_col)
    indicator = np.zeros(grid_shape)
    indicator[row, col] = 1
    return indicator
169180
170181
def map_ind_to_state(self, row, col):
    # Maps a [row, col] pair to a state index:
    # collect the indices where state_map_row equals
    # row and where state_map_col equals col, then
    # return the first index common to both.
    # Indexing with [0] raises IndexError when the
    # two index sets have nothing in common.
    row_hits = np.where(self.state_map_row == row)
    col_hits = np.where(self.state_map_col == col)
    shared = np.intersect1d(row_hits, col_hits)
    return shared[0]
175187
176188
177189 def get_coords (self , states ):
190+ # Given a state or states, returns
191+ # [row,col] pairs for the state(s)
178192 non_obstacles = np .ravel_multi_index (
179193 [self .freespace [0 ], self .freespace [1 ]],
180194 (self .n_row ,self .n_col ), order = 'F' )
@@ -186,6 +200,7 @@ def get_coords(self, states):
186200
187201
188202 def rand_choose (self , in_vec ):
203+ # Samples
189204 if len (in_vec .shape ) > 1 :
190205 if in_vec .shape [1 ] == 1 :
191206 in_vec = in_vec .T
@@ -197,6 +212,8 @@ def rand_choose(self, in_vec):
197212
198213
199214 def next_state_prob (self , s , a ):
215+ # Gets next state probability for
216+ # a given action (a)
200217 if hasattr (a , "__iter__" ):
201218 p = np .squeeze (self .P [s , :, a ])
202219 else :
@@ -205,16 +222,22 @@ def next_state_prob(self, s, a):
205222
206223
def sample_next_state(self, s, a):
    # Draws the next state for the current state (s)
    # and action (a): the vector obtained from
    # next_state_prob(s, a) is handed to rand_choose
    # and whatever rand_choose returns is returned.
    return self.rand_choose(self.next_state_prob(s, a))
211231
212232
def get_size(self):
    # Returns the domain size as the pair
    # (n_row, n_col).
    size = (self.n_row, self.n_col)
    return size
215236
216237
217238 def north (self , row , col ):
239+ # Returns new [row,col]
240+ # if we take the action
218241 new_row = np .max ([row - 1 , 0 ])
219242 new_col = col
220243 if self .image [new_row , new_col ] == 0 :
@@ -224,6 +247,8 @@ def north(self, row, col):
224247
225248
226249 def northeast (self , row , col ):
250+ # Returns new [row,col]
251+ # if we take the action
227252 new_row = np .max ([row - 1 , 0 ])
228253 new_col = np .min ([col + 1 , self .n_col - 1 ])
229254 if self .image [new_row , new_col ] == 0 :
@@ -233,6 +258,8 @@ def northeast(self, row, col):
233258
234259
235260 def northwest (self , row , col ):
261+ # Returns new [row,col]
262+ # if we take the action
236263 new_row = np .max ([row - 1 , 0 ])
237264 new_col = np .max ([col - 1 , 0 ])
238265 if self .image [new_row , new_col ] == 0 :
@@ -242,6 +269,8 @@ def northwest(self, row, col):
242269
243270
244271 def south (self , row , col ):
272+ # Returns new [row,col]
273+ # if we take the action
245274 new_row = np .min ([row + 1 , self .n_row - 1 ])
246275 new_col = col
247276 if self .image [new_row , new_col ] == 0 :
@@ -251,6 +280,8 @@ def south(self, row, col):
251280
252281
253282 def southeast (self , row , col ):
283+ # Returns new [row,col]
284+ # if we take the action
254285 new_row = np .min ([row + 1 , self .n_row - 1 ])
255286 new_col = np .min ([col + 1 , self .n_col - 1 ])
256287 if self .image [new_row , new_col ] == 0 :
@@ -260,6 +291,8 @@ def southeast(self, row, col):
260291
261292
262293 def southwest (self , row , col ):
294+ # Returns new [row,col]
295+ # if we take the action
263296 new_row = np .min ([row + 1 , self .n_row - 1 ])
264297 new_col = np .max ([col - 1 , 0 ])
265298 if self .image [new_row , new_col ] == 0 :
@@ -269,6 +302,8 @@ def southwest(self, row, col):
269302
270303
271304 def east (self , row , col ):
305+ # Returns new [row,col]
306+ # if we take the action
272307 new_row = row
273308 new_col = np .min ([col + 1 , self .n_col - 1 ])
274309 if self .image [new_row , new_col ] == 0 :
@@ -278,6 +313,8 @@ def east(self, row, col):
278313
279314
280315 def west (self , row , col ):
316+ # Returns new [row,col]
317+ # if we take the action
281318 new_row = row
282319 new_col = np .max ([col - 1 , 0 ])
283320 if self .image [new_row , new_col ] == 0 :
@@ -307,6 +344,9 @@ def neighbors(self, row, col):
307344
308345
309346def trace_path (pred , source , target ):
347+ # traces back shortest path from
348+ # source to target given pred
349+ # (a predecessor list)
310350 max_len = 1000
311351 path = np .zeros ((max_len , 1 ))
312352 i = max_len - 1
@@ -325,6 +365,8 @@ def trace_path(pred, source, target):
325365
326366
327367def sample_trajectory (M , n_states ):
368+ # Samples trajectories from random nodes
369+ # in our domain (M)
328370 G , W = M .get_graph_inv ()
329371 N = G .shape [0 ]
330372 if N >= n_states :
0 commit comments