|
28 | 28 | # ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. |
29 | 29 |
|
30 | 30 | import logging |
| 31 | + |
31 | 32 | from pathlib import Path |
32 | 33 | from typing import Dict, List, Optional |
33 | 34 |
|
34 | 35 | import jde |
35 | 36 | import torch |
| 37 | + |
36 | 38 | from jde.models import Darknet |
37 | 39 | from jde.tracker import matching |
38 | 40 | from jde.tracker.basetrack import TrackState |
@@ -78,12 +80,12 @@ def __init__(self, device: str, **kwargs): |
78 | 80 | } |
79 | 81 |
|
80 | 82 | self.model_configs = { |
81 | | - "iou_thres": float(kwargs["iou_thres"]), |
82 | | - "conf_thres": float(kwargs["conf_thres"]), |
83 | | - "nms_thres": float(kwargs["nms_thres"]), |
84 | | - "min_box_area": int(kwargs["min_box_area"]), |
85 | | - "track_buffer": int(kwargs["track_buffer"]), |
86 | | - "frame_rate": float(kwargs["frame_rate"]), |
| 83 | + "iou_thres": float(kwargs.get("iou_thres", 0.5)), |
| 84 | + "conf_thres": float(kwargs.get("conf_thres", 0.5)), |
| 85 | + "nms_thres": float(kwargs.get("nms_thres", 0.4)), |
| 86 | + "min_box_area": int(kwargs.get("min_box_area", 200)), |
| 87 | + "track_buffer": int(kwargs.get("track_buffer", 30)), |
| 88 | + "frame_rate": float(kwargs.get("frame_rate", 30)), |
87 | 89 | } |
88 | 90 | self.max_time_on_hold = int( |
89 | 91 | self.model_configs["frame_rate"] / 30.0 * self.model_configs["track_buffer"] |
@@ -116,6 +118,15 @@ def __init__(self, device: str, **kwargs): |
116 | 118 | self.logger.level = kwargs["logging_level"] |
117 | 119 | # logging.DEBUG |
118 | 120 |
|
| 121 | + if kwargs.get("hyper_params", {}).get("update", False): |
| 122 | + hyper_params = { |
| 123 | + "conf_threshold": kwargs.get("hyper_params", {}).get( |
| 124 | + "conf_threshold", None |
| 125 | + ), |
| 126 | + "max_dets": kwargs.get("hyper_params", {}).get("max_dets", None), |
| 127 | + } |
| 128 | + self._apply_infer_overrides(hyper_params) |
| 129 | + |
119 | 130 | # reset member variables to use over a sequence of frame |
120 | 131 | self.reset() |
121 | 132 |
|
@@ -210,8 +221,7 @@ def _feature_pyramid_to_output( |
210 | 221 | return {"tlwhs": online_tlwhs, "ids": online_ids} |
211 | 222 |
|
212 | 223 | def _apply_infer_overrides(self, overrides: Dict): |
213 | | - |
214 | | - if "conf_threshold" in overrides: |
| 224 | + if overrides.get("conf_threshold") is not None: |
215 | 225 | self.model_configs["conf_thres"] = float(overrides["conf_threshold"]) |
216 | 226 |
|
217 | 227 | @torch.no_grad() |
@@ -337,9 +347,7 @@ def _jde_process(self, pred, org_img_size: tuple, input_img_size: tuple): |
337 | 347 |
|
338 | 348 | detections = [detections[i] for i in u_detection] |
339 | 349 | # detections is now a list of the unmatched detections |
340 | | - r_tracked_stracks = ( |
341 | | - [] |
342 | | - ) # This is container for stracks which were tracked till the |
| 350 | + r_tracked_stracks = [] # This is container for stracks which were tracked till the |
343 | 351 | # previous frame but no detection was found for it in the current frame |
344 | 352 | for i in u_track: |
345 | 353 | if track_candidates_pool[i].state == TrackState.Tracked: |
|
0 commit comments