Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,15 @@ To display the results you need to:
$ python sort.py --display
```

Added command-line argument --max_id_threshold

```bash
# Unlimited IDs (default behavior)
python sort.py

# Limited to 1000 IDs
python sort.py --max_id_threshold 1000
```

### Main Results

Expand Down
69 changes: 63 additions & 6 deletions sort.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import os
import numpy as np
import matplotlib
matplotlib.use('TkAgg')
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from skimage import io
Expand Down Expand Up @@ -104,7 +104,7 @@ def __init__(self,bbox):
self.kf = KalmanFilter(dim_x=7, dim_z=4)
self.kf.F = np.array([[1,0,0,0,1,0,0],[0,1,0,0,0,1,0],[0,0,1,0,0,0,1],[0,0,0,1,0,0,0], [0,0,0,0,1,0,0],[0,0,0,0,0,1,0],[0,0,0,0,0,0,1]])
self.kf.H = np.array([[1,0,0,0,0,0,0],[0,1,0,0,0,0,0],[0,0,1,0,0,0,0],[0,0,0,1,0,0,0]])

self.kf.R[2:,2:] *= 10.
self.kf.P[4:,4:] *= 1000. #give high uncertainty to the unobservable initial velocities
self.kf.P *= 10.
Expand Down Expand Up @@ -197,15 +197,60 @@ def associate_detections_to_trackers(detections,trackers,iou_threshold = 0.3):


class Sort(object):
def __init__(self, max_age=1, min_hits=3, iou_threshold=0.3):
def __init__(self, max_age=1, min_hits=3, iou_threshold=0.3, max_id_threshold=None):
"""
Sets key parameters for SORT
"""
self.max_age = max_age
self.min_hits = min_hits
self.iou_threshold = iou_threshold
self.max_id_threshold = max_id_threshold # None means unlimited IDs
self.trackers = []
self.frame_count = 0
# Add ID management for keeping track of currently active IDs
self.active_ids = set()
self.next_available_id = 1

def get_next_id(self):
"""
Get the next available ID, respecting the max_id_threshold if set
"""
if self.max_id_threshold is None:
# Unlimited IDs - just increment and reuse released IDs when possible
while self.next_available_id in self.active_ids:
self.next_available_id += 1

current_id = self.next_available_id
self.active_ids.add(current_id)
self.next_available_id += 1
return current_id
else:
# Limited IDs - respect the maximum id threshold
# Find the next available ID that isn't currently active
while self.next_available_id in self.active_ids and self.next_available_id <= self.max_id_threshold:
self.next_available_id += 1

if self.next_available_id > self.max_id_threshold:
# If we've exceeded threshold, find the first available ID from 1 to threshold
for i in range(1, self.max_id_threshold + 1):
if i not in self.active_ids:
self.next_available_id = i
break
else:
# If all IDs 1-threshold are active (rare case), wrap around
self.next_available_id = 1

current_id = self.next_available_id
self.active_ids.add(current_id)
self.next_available_id += 1
return current_id

def release_id(self, tracker_id):
"""
Remove an ID from active set when tracker is deleted
"""
if tracker_id in self.active_ids:
self.active_ids.remove(tracker_id)

def update(self, dets=np.empty((0, 5))):
"""
Expand Down Expand Up @@ -239,14 +284,22 @@ def update(self, dets=np.empty((0, 5))):
for i in unmatched_dets:
trk = KalmanBoxTracker(dets[i,:])
self.trackers.append(trk)

i = len(self.trackers)
for trk in reversed(self.trackers):
d = trk.get_state()[0]
if (trk.time_since_update < 1) and (trk.hit_streak >= self.min_hits or self.frame_count <= self.min_hits):
ret.append(np.concatenate((d,[trk.id+1])).reshape(1,-1)) # +1 as MOT benchmark requires positive
# Use managed ID system instead of the original logic
if not hasattr(trk, 'assigned_id'):
trk.assigned_id = self.get_next_id()

ret.append(np.concatenate((d,[trk.assigned_id])).reshape(1,-1))
i -= 1
# remove dead tracklet
if(trk.time_since_update > self.max_age):
# Release the ID when removing the tracker
if hasattr(trk, 'assigned_id'):
self.release_id(trk.assigned_id)
self.trackers.pop(i)
if(len(ret)>0):
return np.concatenate(ret)
Expand All @@ -265,6 +318,9 @@ def parse_args():
help="Minimum number of associated detections before track is initialised.",
type=int, default=3)
parser.add_argument("--iou_threshold", help="Minimum IOU for match.", type=float, default=0.3)
parser.add_argument("--max_id_threshold",
help="Maximum tracking ID threshold. If not specified, IDs can grow indefinitely.",
type=int, default=None)
args = parser.parse_args()
return args

Expand All @@ -290,7 +346,8 @@ def parse_args():
for seq_dets_fn in glob.glob(pattern):
mot_tracker = Sort(max_age=args.max_age,
min_hits=args.min_hits,
iou_threshold=args.iou_threshold) #create instance of the SORT tracker
iou_threshold=args.iou_threshold,
max_id_threshold=args.max_id_threshold) #create instance of the SORT tracker
seq_dets = np.loadtxt(seq_dets_fn, delimiter=',')
seq = seq_dets_fn[pattern.find('*'):].split(os.path.sep)[0]

Expand Down Expand Up @@ -327,4 +384,4 @@ def parse_args():
print("Total Tracking took: %.3f seconds for %d frames or %.1f FPS" % (total_time, total_frames, total_frames / total_time))

if(display):
print("Note: to get real runtime results run without the option: --display")
print("Note: to get real runtime results run without the option: --display")