diff --git a/yolox/tracker/basetrack.py b/yolox/tracker/basetrack.py index 4fe22336..a5a2a643 100644 --- a/yolox/tracker/basetrack.py +++ b/yolox/tracker/basetrack.py @@ -10,31 +10,35 @@ class TrackState(object): class BaseTrack(object): - _count = 0 + def __init__(self, count_gen=None): + self._count = 0 + self.count_gen = count_gen - track_id = 0 - is_activated = False - state = TrackState.New + self.track_id = 0 + self.is_activated = False + self.state = TrackState.New - history = OrderedDict() - features = [] - curr_feature = None - score = 0 - start_frame = 0 - frame_id = 0 - time_since_update = 0 + self.history = OrderedDict() + self.features = [] + self.curr_feature = None + self.score = 0 + self.start_frame = 0 + self.frame_id = 0 + self.time_since_update = 0 - # multi-camera - location = (np.inf, np.inf) + # multi-camera + self.location = (np.inf, np.inf) @property def end_frame(self): return self.frame_id - @staticmethod - def next_id(): - BaseTrack._count += 1 - return BaseTrack._count + def next_id(self): + if self.count_gen: + self._count = self.count_gen.__next__() + else: + self._count += 1 + return self._count def activate(self, *args): raise NotImplementedError diff --git a/yolox/tracker/byte_tracker.py b/yolox/tracker/byte_tracker.py index 2d004599..b9609e9a 100644 --- a/yolox/tracker/byte_tracker.py +++ b/yolox/tracker/byte_tracker.py @@ -5,6 +5,7 @@ import copy import torch import torch.nn.functional as F +import itertools from .kalman_filter import KalmanFilter from yolox.tracker import matching @@ -12,8 +13,8 @@ class STrack(BaseTrack): shared_kalman = KalmanFilter() - def __init__(self, tlwh, score): - + def __init__(self, tlwh, score, count_gen=None): + super().__init__(count_gen) # wait activate self._tlwh = np.asarray(tlwh, dtype=np.float) self.kalman_filter = None @@ -155,6 +156,7 @@ def __init__(self, args, frame_rate=30): self.buffer_size = int(frame_rate / 30.0 * args.track_buffer) self.max_time_lost = self.buffer_size self.kalman_filter = KalmanFilter() + self.track_count_gen = itertools.count(start=1) def update(self, output_results, img_info, img_size): self.frame_id += 1 @@ -186,7 +188,7 @@ def update(self, output_results, img_info, img_size): if len(dets) > 0: '''Detections''' - detections = [STrack(STrack.tlbr_to_tlwh(tlbr), s) for + detections = [STrack(STrack.tlbr_to_tlwh(tlbr), s, self.track_count_gen) for (tlbr, s) in zip(dets, scores_keep)] else: detections = [] @@ -223,7 +225,7 @@ def update(self, output_results, img_info, img_size): # association the untrack to the low score detections if len(dets_second) > 0: '''Detections''' - detections_second = [STrack(STrack.tlbr_to_tlwh(tlbr), s) for + detections_second = [STrack(STrack.tlbr_to_tlwh(tlbr), s, self.track_count_gen) for (tlbr, s) in zip(dets_second, scores_second)] else: detections_second = []