@@ -68,6 +68,7 @@ def __init__(
6868 nms_class_iou_thr : float = 0.7 ,
6969 nms_conf_thr : float = 0.5 ,
7070 with_cats : bool = True ,
71+ with_velocities : bool = False ,
7172 bbox_affinity_weight : float = 0.5 ,
7273 ) -> None :
7374 """Creates an instance of the class.
@@ -83,10 +84,12 @@ def __init__(
8384 another detection.
8485 nms_class_iou_thr (float): Maximum IoU of a high score detection
8586 with another of a different class.
87+ nms_conf_thr (float): Confidence threshold for NMS.
8688 with_cats (bool): If to consider category information for
8789 tracking (i.e. all detections within a track must have
8890 consistent category labels).
89- nms_conf_thr (float): Confidence threshold for NMS.
91+ with_velocities (bool): If to use predicted velocities for
92+ matching.
9093 bbox_affinity_weight (float): Weight of bbox affinity in the
9194 overall affinity score.
9295 """
@@ -98,6 +101,7 @@ def __init__(
98101 self .nms_class_iou_thr = nms_class_iou_thr
99102 self .nms_conf_thr = nms_conf_thr
100103 self .with_cats = with_cats
104+ self .with_velocities = with_velocities
101105 self .bbox_affinity_weight = bbox_affinity_weight
102106 self .feat_affinity_weight = 1 - bbox_affinity_weight
103107
@@ -110,7 +114,8 @@ def _filter_detections(
110114 scores_3d : Tensor ,
111115 class_ids : Tensor ,
112116 embeddings : Tensor ,
113- ) -> tuple [Tensor , Tensor , Tensor , Tensor , Tensor , Tensor , Tensor ]:
117+ velocities : Tensor | None = None ,
118+ ) -> tuple [Tensor , Tensor , Tensor , Tensor , Tensor , Tensor , Tensor , Tensor ]:
114119 """Remove overlapping objects across classes via nms.
115120
116121 Args:
@@ -121,6 +126,7 @@ def _filter_detections(
121126 scores_3d (Tensor): [N,] Tensor of 3D confidence scores.
122127 class_ids (Tensor): [N,] Tensor of class ids.
123128 embeddings (Tensor): [N, C] tensor of appearance embeddings.
129+ velocities (Tensor | None): [N, 3] Tensor of velocities.
124130
125131 Returns:
126132 tuple[Tensor]: filtered detections, scores, class_ids,
@@ -142,6 +148,10 @@ def _filter_detections(
142148 detections_3d [inds ],
143149 scores_3d [inds ],
144150 )
151+
152+ if velocities is not None :
153+ velocities = velocities [inds ]
154+
145155 valids = embeddings .new_ones ((len (detections ),), dtype = torch .bool )
146156
147157 ious = bbox_iou (detections , detections )
@@ -158,25 +168,32 @@ def _filter_detections(
158168
159169 if (ious [i , :i ] > thr ).any ():
160170 valids [i ] = False
171+
161172 detections = detections [valids ]
162173 scores = scores [valids ]
163174 detections_3d = detections_3d [valids ]
164175 scores_3d = scores_3d [valids ]
165176 class_ids = class_ids [valids ]
166177 embeddings = embeddings [valids ]
178+
179+ if velocities is not None :
180+ velocities = velocities [valids ]
181+
167182 return (
168183 detections ,
169184 scores ,
170185 detections_3d ,
171186 scores_3d ,
172187 class_ids ,
173188 embeddings ,
189+ velocities ,
174190 inds [valids ],
175191 )
176192
177- @staticmethod
178193 def depth_ordering (
194+ self ,
179195 obsv_boxes_3d : Tensor ,
196+ obsv_velocities : Tensor | None ,
180197 memory_boxes_3d_predict : Tensor ,
181198 memory_boxes_3d : Tensor ,
182199 memory_velocities : Tensor ,
@@ -197,11 +214,11 @@ def depth_ordering(
197214
198215 # Moving distance should be aligned
199216 motion_weight_list = []
200- obsv_velocities = (
217+ moving_dist = (
201218 obsv_boxes_3d [:, :3 , None ]
202219 - memory_boxes_3d [:, :3 , None ].transpose (2 , 0 )
203220 ).transpose (1 , 2 )
204- for v in obsv_velocities :
221+ for v in moving_dist :
205222 motion_weight_list .append (
206223 F .pairwise_distance ( # pylint: disable=not-callable
207224 v , memory_velocities [:, :3 ]
@@ -210,22 +227,41 @@ def depth_ordering(
210227 motion_weight = torch .cat (motion_weight_list , dim = 0 )
211228 motion_weight = torch .exp (- torch .div (motion_weight , 5.0 ))
212229
213- # Moving direction should be aligned
214- # Set to 0.5 when two vector not within +-90 degree
215- cos_sim_list = []
216- obsv_direct = (
217- obsv_boxes_3d [:, :2 , None ]
218- - memory_boxes_3d [:, :2 , None ].transpose (2 , 0 )
219- ).transpose (1 , 2 )
220- for d in obsv_direct :
221- cos_sim_list .append (
222- F .cosine_similarity ( # pylint: disable=not-callable
223- d , memory_velocities [:, :2 ]
224- ).unsqueeze (0 )
230+ # Velocity scores
231+ if self .with_velocities :
232+ assert (
233+ obsv_velocities is not None
234+ ), "Please provide velocities if with_velocities=True!"
235+
236+ velsim_weight_list = []
237+ obsvvv_velocities = obsv_velocities .unsqueeze (1 ).expand_as (
238+ moving_dist
225239 )
226- cos_sim = torch .cat (cos_sim_list , dim = 0 )
227- cos_sim = torch .add (cos_sim , 1.0 )
228- cos_sim = torch .div (cos_sim , 2.0 )
240+ for v in obsvvv_velocities :
241+ velsim_weight_list .append (
242+ F .pairwise_distance (
243+ v , memory_velocities [:, - 3 :]
244+ ).unsqueeze (0 )
245+ )
246+ velsim_weight = torch .cat (velsim_weight_list , dim = 0 )
247+ cos_sim = torch .exp (- velsim_weight / 5.0 )
248+ else :
249+ # Moving direction should be aligned
250+ # Set to 0.5 when two vector not within +-90 degree
251+ cos_sim_list = []
252+ obsv_direct = (
253+ obsv_boxes_3d [:, :2 , None ]
254+ - memory_boxes_3d [:, :2 , None ].transpose (2 , 0 )
255+ ).transpose (1 , 2 )
256+ for d in obsv_direct :
257+ cos_sim_list .append (
258+ F .cosine_similarity ( # pylint: disable=not-callable
259+ d , memory_velocities [:, :2 ]
260+ ).unsqueeze (0 )
261+ )
262+ cos_sim = torch .cat (cos_sim_list , dim = 0 )
263+ cos_sim = torch .add (cos_sim , 1.0 )
264+ cos_sim = torch .div (cos_sim , 2.0 )
229265
230266 scores_depth = (
231267 cos_sim * centroid_weight + (1.0 - cos_sim ) * motion_weight
@@ -242,6 +278,7 @@ def __call__(
242278 detection_scores_3d : Tensor ,
243279 detection_class_ids : Tensor ,
244280 detection_embeddings : Tensor ,
281+ obs_velocities : Tensor | None = None ,
245282 memory_boxes_3d : Tensor | None = None ,
246283 memory_track_ids : Tensor | None = None ,
247284 memory_class_ids : Tensor | None = None ,
@@ -260,6 +297,7 @@ def __call__(
260297 detection_scores_3d (Tensor): [N,] confidence scores in 3D.
261298 detection_class_ids (Tensor): [N,] class indices.
262299 detection_embeddings (Tensor): [N, C] appearance embeddings.
300+ obs_velocities (Tensor | None): [N, 3] velocities of detections.
263301 memory_boxes_3d (Tensor): [M, 7] boxes in memory.
264302 memory_track_ids (Tensor): [M,] track ids in memory.
265303 memory_class_ids (Tensor): [M,] class indices in memory.
@@ -280,6 +318,7 @@ def __call__(
280318 detection_scores_3d ,
281319 detection_class_ids ,
282320 detection_embeddings ,
321+ obs_velocities ,
283322 permute_inds ,
284323 ) = self ._filter_detections (
285324 detections ,
@@ -289,6 +328,7 @@ def __call__(
289328 detection_scores_3d ,
290329 detection_class_ids ,
291330 detection_embeddings ,
331+ obs_velocities ,
292332 )
293333
294334 if with_depth_confidence :
@@ -324,6 +364,7 @@ def __call__(
324364 # Depth Ordering
325365 scores_depth = self .depth_ordering (
326366 detections_3d ,
367+ obs_velocities ,
327368 memory_boxes_3d_predict ,
328369 memory_boxes_3d ,
329370 memory_velocities ,
0 commit comments