11import argparse
22import time
33import numpy as np
4+
5+ # import matplotlib
6+ # matplotlib.use('agg')
47import matplotlib .pyplot as plt
58
69from dr_spaam .detector import Detector
@@ -13,25 +16,32 @@ def inference_time():
1316
1417 # inference time
1518 use_gpu = True
16- for use_drow in (True , False ):
17- ckpt = './ckpts/drow_e40.pth' if use_drow else './ckpts/dr_spaam_e40.pth'
18- detector = Detector (ckpt , original_drow = use_drow , gpu = use_gpu , stride = 1 )
19+ model_names = ("DR-SPAAM" , "DROW" , "DROW-T5" )
20+ ckpts = (
21+ "./ckpts/dr_spaam_e40.pth" ,
22+ "./ckpts/drow_e40.pth" ,
23+ "./ckpts/drow5_e40.pth"
24+ )
25+ for model_name , ckpt in zip (model_names , ckpts ):
26+ detector = Detector (model_name = model_name , ckpt_file = ckpt , gpu = use_gpu , stride = 1 )
1927 detector .set_laser_spec (angle_inc = np .radians (0.5 ), num_pts = 450 )
2028
2129 t_list = []
2230 for i in range (60 ):
31+ s = scans [i :i + 5 ] if model_name == "DROW-T5" else scans [i ]
2332 t0 = time .time ()
24- dets_xy , dets_cls , instance_mask = detector (scans [ i ] )
33+ dets_xy , dets_cls , instance_mask = detector (s )
2534 t_list .append (1e3 * (time .time () - t0 ))
2635
2736 t = np .array (t_list [10 :]).mean ()
2837 print ("inference time (model: %s, gpu: %s): %f ms (%.1f FPS)" % (
29- "DROW" if use_drow else "DR-SPAAM" , use_gpu , t , 1e3 / t ))
38+ model_name , use_gpu , t , 1e3 / t ))
3039
3140
3241def play_sequence ():
3342 # scans
3443 seq_name = './data/DROWv2-data/test/run_t_2015-11-26-11-22-03.bag.csv'
44+ # seq_name = './data/DROWv2-data/val/run_2015-11-26-15-52-55-k.bag.csv'
3545 scans_data = np .genfromtxt (seq_name , delimiter = ',' )
3646 scans_t = scans_data [:, 1 ]
3747 scans = scans_data [:, 2 :]
@@ -45,7 +55,7 @@ def play_sequence():
4555
4656 # detector
4757 ckpt = './ckpts/dr_spaam_e40.pth'
48- detector = Detector (ckpt , original_drow = False , gpu = True , stride = 1 )
58+ detector = Detector (model_name = "DR-SPAAM" , ckpt_file = ckpt , gpu = True , stride = 1 )
4959 detector .set_laser_spec (angle_inc = np .radians (0.5 ), num_pts = 450 )
5060
5161 # scanner location
@@ -54,7 +64,7 @@ def play_sequence():
5464 xy_scanner = np .stack (xy_scanner , axis = 1 )
5565
5666 # plot
57- fig = plt .figure (figsize = (9 , 6 ))
67+ fig = plt .figure (figsize = (10 , 10 ))
5868 ax = fig .add_subplot (111 )
5969
6070 _break = False
@@ -67,11 +77,12 @@ def p(event):
6777 # video sequence
6878 odo_idx = 0
6979 for i in range (len (scans )):
80+ # for i in range(0, len(scans), 20):
7081 plt .cla ()
7182
7283 ax .set_aspect ('equal' )
7384 ax .set_xlim (- 15 , 15 )
74- ax .set_ylim (- 5 , 15 )
85+ ax .set_ylim (- 15 , 15 )
7586
7687 # ax.set_title('Frame: %s' % i)
7788 ax .set_title ('Press any key to exit.' )
@@ -104,10 +115,11 @@ def p(event):
104115 for j in range (len (dets_xy )):
105116 if dets_cls [j ] < cls_thresh :
106117 continue
107- c = plt .Circle (dets_xy_rot [j ], radius = 0.5 , color = 'r' , fill = False )
118+ # c = plt.Circle(dets_xy_rot[j], radius=0.5, color='r', fill=False)
119+ c = plt .Circle (dets_xy_rot [j ], radius = 0.5 , color = 'r' , fill = False , linewidth = 2 )
108120 ax .add_artist (c )
109121
110- # plt.savefig('/home/jia/tmp_imgs/dets /frame_%04d.png' % i)
122+ # plt.savefig('/home/dan/tmp/det_img /frame_%04d.png' % i)
111123
112124 plt .pause (0.001 )
113125
@@ -118,7 +130,7 @@ def p(event):
118130def play_sequence_with_tracking ():
119131 # scans
120132 seq_name = './data/DROWv2-data/train/lunch_2015-11-26-12-04-23.bag.csv'
121- seq0 , seq1 = 107000 , 109357
133+ seq0 , seq1 = 109170 , 109360
122134 scans , scans_t = [], []
123135 with open (seq_name ) as f :
124136 for line in f :
@@ -142,7 +154,7 @@ def play_sequence_with_tracking():
142154
143155 # detector
144156 ckpt = './ckpts/dr_spaam_e40.pth'
145- detector = Detector (ckpt , original_drow = False , gpu = True , stride = 1 )
157+ detector = Detector (model_name = "DR-SPAAM" , ckpt_file = ckpt , gpu = True , stride = 1 , tracking = True )
146158 detector .set_laser_spec (angle_inc = np .radians (0.5 ), num_pts = 450 )
147159
148160 # scanner location
@@ -151,7 +163,7 @@ def play_sequence_with_tracking():
151163 xy_scanner = np .stack (xy_scanner , axis = 1 )
152164
153165 # plot
154- fig = plt .figure (figsize = (9 , 6 ))
166+ fig = plt .figure (figsize = (6 , 8 ))
155167 ax = fig .add_subplot (111 )
156168
157169 _break = False
@@ -167,7 +179,7 @@ def p(event):
167179 plt .cla ()
168180
169181 ax .set_aspect ('equal' )
170- ax .set_xlim (- 15 , 15 )
182+ ax .set_xlim (- 10 , 5 )
171183 ax .set_ylim (- 5 , 15 )
172184
173185 # ax.set_title('Frame: %s' % i)
@@ -193,15 +205,15 @@ def p(event):
193205 ax .scatter (scan_x , scan_y , s = 1 , c = 'blue' )
194206
195207 # inference
196- dets_xy , dets_cls , instance_mask = detector (scan , tracking = True )
208+ dets_xy , dets_cls , instance_mask = detector (scan )
197209
198210 # plot detection
199211 dets_xy_rot = np .matmul (dets_xy , odo_rot .T )
200212 cls_thresh = 0.3
201213 for j in range (len (dets_xy )):
202214 if dets_cls [j ] < cls_thresh :
203215 continue
204- c = plt .Circle (dets_xy_rot [j ], radius = 0.5 , color = 'r' , fill = False )
216+ c = plt .Circle (dets_xy_rot [j ], radius = 0.5 , color = 'r' , fill = False , linewidth = 2 )
205217 ax .add_artist (c )
206218
207219 # plot track
@@ -210,9 +222,9 @@ def p(event):
210222 for t , tc in zip (tracks , tracks_cls ):
211223 if tc >= cls_thresh and len (t ) > 1 :
212224 t_rot = np .matmul (t , odo_rot .T )
213- ax .plot (t_rot [:, 0 ], t_rot [:, 1 ], color = 'g' )
225+ ax .plot (t_rot [:, 0 ], t_rot [:, 1 ], color = 'g' , linewidth = 2 )
214226
215- # plt.savefig('/home/jia/tmp_imgs/tracks /frame_%05d .png' % i)
227+ # plt.savefig('/home/dan/tmp/track3_img /frame_%04d .png' % i)
216228
217229 plt .pause (0.001 )
218230
0 commit comments