@@ -32,6 +32,15 @@ class SkipFrames:
3232 then the detector will only be run every `skip` frames. Between frames where the
3333 detector is run, bounding boxes will be computed from the pose estimated in the
3434 previous frame.
35+
 36+ Every `skip` frames, the detector will be run to detect bounding boxes for individuals.
37+ In the "skipped" frames between the frames where the object detector is run, the
38+ bounding boxes will be computed from the poses estimated in the previous frame (with
39+ some margin added around the poses).
40+
41+ Attributes:
42+ skip: The number of frames to skip between each run of the detector.
 43+ margin: The margin (in pixels) to add around poses when generating bounding boxes.
3544 """
3645
3746 skip : int
@@ -78,20 +87,28 @@ class TopDownConfig:
7887 """Configuration for top-down models.
7988
8089 Attributes:
90+ bbox_cutoff: The minimum score required for a bounding box to be considered.
91+ max_detections: The maximum number of detections to keep in a frame. If None,
92+ the `max_detections` will be set to the number of individuals in the model
93+ configuration file when `read_config` is called.
8194 skip_frames: If defined, the detector will only be run every
8295 `skip_frames.skip` frames.
8396 """
8497
85- bbox_cutoff : float
86- max_detections : int
98+ bbox_cutoff : float = 0.6
99+ max_detections : int | None = 30
87100 crop_size : tuple [int , int ] = (256 , 256 )
88101 skip_frames : SkipFrames | None = None
89102
90- def read_config (self , detector_cfg : dict ) -> None :
91- crop = detector_cfg .get ("data" , {}).get ("inference" , {}).get ("top_down_crop" )
103+ def read_config (self , model_cfg : dict ) -> None :
104+ crop = model_cfg .get ("data" , {}).get ("inference" , {}).get ("top_down_crop" )
92105 if crop is not None :
93106 self .crop_size = (crop ["width" ], crop ["height" ])
94107
108+ if self .max_detections is None :
109+ individuals = model_cfg .get ("metadata" , {}).get ("individuals" , [])
110+ self .max_detections = len (individuals )
111+
95112
96113class PyTorchRunner (BaseRunner ):
97114 """PyTorch runner for live pose estimation using DeepLabCut-Live.
@@ -242,7 +259,7 @@ def load_model(self) -> None:
242259 self .model = self .model .half ()
243260
244261 self .detector = None
245- if raw_data .get ("detector" ) is not None :
262+ if self . dynamic is None and raw_data .get ("detector" ) is not None :
246263 self .detector = models .DETECTORS .build (self .cfg ["detector" ]["model" ])
247264 self .detector .to (self .device )
248265 self .detector .load_state_dict (raw_data ["detector" ])
@@ -251,18 +268,23 @@ def load_model(self) -> None:
251268 if self .precision == "FP16" :
252269 self .detector = self .detector .half ()
253270
254- if self .cfg ["method" ] == "td" and self .detector is None :
255- crop_cfg = self .cfg ["data" ]["inference" ]["top_down_crop" ]
256- top_down_crop_size = crop_cfg ["width" ], crop_cfg ["height" ]
257- self .dynamic = dynamic_cropping .TopDownDynamicCropper (
258- top_down_crop_size ,
259- patch_counts = (4 , 3 ),
260- patch_overlap = 50 ,
261- min_bbox_size = (250 , 250 ),
262- threshold = 0.6 ,
263- margin = 25 ,
264- min_hq_keypoints = 2 ,
265- bbox_from_hq = True ,
271+ if self .top_down_config is None :
272+ self .top_down_config = TopDownConfig ()
273+
274+ self .top_down_config .read_config (self .cfg )
275+
276+ if isinstance (self .dynamic , dynamic_cropping .TopDownDynamicCropper ):
277+ crop = self .cfg ["data" ]["inference" ].get ("top_down_crop" , {})
278+ w , h = crop .get ("width" , 256 ), crop .get ("height" , 256 )
279+ self .dynamic .top_down_crop_size = w , h
280+
281+ if (
282+ self .cfg ["method" ] == "td"
283+ and self .detector is None
284+ and self .dynamic is None
285+ ):
286+ raise ValueError (
287+ "Top-down models must either use a detector or a TopDownDynamicCropper."
266288 )
267289
268290 self .transform = v2 .Compose (
@@ -283,9 +305,10 @@ def read_config(self) -> dict:
283305 def _prepare_top_down (
284306 self , frame : torch .Tensor , detections : dict [str , torch .Tensor ]
285307 ):
308+ """Prepares a frame for top-down pose estimation."""
286309 bboxes , scores = detections ["boxes" ], detections ["scores" ]
287310 bboxes = bboxes [scores >= self .top_down_config .bbox_cutoff ]
288- if len (bboxes ) > 0 :
311+ if len (bboxes ) > 0 and self . top_down_config . max_detections is not None :
289312 bboxes = bboxes [: self .top_down_config .max_detections ]
290313
291314 crops = []
0 commit comments