1+ import sys
12import argparse
23
34import matplotlib .pyplot as plt
1516from generators .obstacle_gen import *
1617
1718
18- def main (config , n_domains = 100 , max_obs = 10 ,
19+ def main (config , n_domains = 5000 , max_obs = 40 ,
1920 max_obs_size = None , n_traj = 1 , n_actions = 8 ):
20-
21+ # Correct vs total:
22+ correct , total = 0.0 , 0.0
2123 # Automatic swith of GPU mode if available
2224 use_GPU = torch .cuda .is_available ()
2325 vin = torch .load (config .weights )
@@ -95,7 +97,15 @@ def main(config, n_domains=100, max_obs=10,
9597 pred_traj [j + 1 :,1 ] = nc
9698 break
9799 # Plot optimal and predicted path (also start, end)
98- visualize (G .image .T , states_xy [i ], pred_traj )
100+ if pred_traj [- 1 , 0 ] == goal [0 ] and pred_traj [- 1 , 1 ] == goal [1 ]:
101+ correct += 1
102+ total += 1
103+ if config .plot == True :
104+ visualize (G .image .T , states_xy [i ], pred_traj )
105+ sys .stdout .write ("\r " + str (int ((float (dom )/ n_domains ) * 100.0 )) + "%" )
106+ sys .stdout .flush ()
107+ sys .stdout .write ("\n " )
108+ print ('Rollout Accuracy: {:.2f}%' .format (100 * (correct / total )))
99109
100110
101111def visualize (dom , states_xy , pred_traj ):
@@ -122,6 +132,7 @@ def visualize(dom, states_xy, pred_traj):
122132 type = str ,
123133 default = 'trained/vin_8x8.pth' ,
124134 help = 'Path to trained weights' )
135+ parser .add_argument ('--plot' , action = 'store_true' , default = False )
125136 parser .add_argument ('--imsize' ,
126137 type = int ,
127138 default = 8 ,
0 commit comments