Skip to content

Commit d3939cb

Browse files
author
Kent Sommer
committed
Added success rate measure
1 parent 8032d93 commit d3939cb

File tree

1 file changed

+14
-3
lines changed

1 file changed

+14
-3
lines changed

test.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import sys
12
import argparse
23

34
import matplotlib.pyplot as plt
@@ -15,9 +16,10 @@
1516
from generators.obstacle_gen import *
1617

1718

18-
def main(config, n_domains=100, max_obs=10,
19+
def main(config, n_domains=5000, max_obs=40,
1920
max_obs_size=None, n_traj=1, n_actions=8):
20-
21+
# Correct vs total:
22+
correct, total = 0.0, 0.0
2123
# Automatic swith of GPU mode if available
2224
use_GPU = torch.cuda.is_available()
2325
vin = torch.load(config.weights)
@@ -95,7 +97,15 @@ def main(config, n_domains=100, max_obs=10,
9597
pred_traj[j+1:,1] = nc
9698
break
9799
# Plot optimal and predicted path (also start, end)
98-
visualize(G.image.T, states_xy[i], pred_traj)
100+
if pred_traj[-1, 0] == goal[0] and pred_traj[-1, 1] == goal[1]:
101+
correct += 1
102+
total += 1
103+
if config.plot == True:
104+
visualize(G.image.T, states_xy[i], pred_traj)
105+
sys.stdout.write("\r" + str(int((float(dom)/n_domains) * 100.0)) + "%")
106+
sys.stdout.flush()
107+
sys.stdout.write("\n")
108+
print('Rollout Accuracy: {:.2f}%'.format(100*(correct/total)))
99109

100110

101111
def visualize(dom, states_xy, pred_traj):
@@ -122,6 +132,7 @@ def visualize(dom, states_xy, pred_traj):
122132
type=str,
123133
default='trained/vin_8x8.pth',
124134
help='Path to trained weights')
135+
parser.add_argument('--plot', action='store_true', default=False)
125136
parser.add_argument('--imsize',
126137
type=int,
127138
default=8,

0 commit comments

Comments
 (0)