Commit 40005ae

improved docs for PyTorch code
1 parent 82daf43 commit 40005ae

3 files changed: +94 −30 lines changed

dlclive/dlclive.py
Lines changed: 43 additions & 2 deletions

@@ -46,14 +46,55 @@ class DLCLive:
         TensorFlow only. Optional ConfigProto for the TensorFlow session.

     single_animal: bool, default=True
-        PyTorch only.
+        PyTorch only. If True, the predicted pose array returned by the runner will be of shape
+        (num_bodyparts, 3). As multi-animal pose estimation can be run with the PyTorch
+        engine, setting this to False means the returned pose array will be of shape
+        (num_detections, num_bodyparts, 3).

     device: str, optional, default=None
-        PyTorch only.
+        PyTorch only. The device on which to run inference, e.g. "cpu", "cuda" or
+        "cuda:0". If set to None or "auto", the device will be automatically selected
+        based on CUDA availability.

     top_down_config: dict, optional, default=None
+        PyTorch only. Configuration settings for top-down pose estimation models. Must
+        be provided when running top-down models and `top_down_dynamic` is None. The
+        parameters in the dict will be given to the `TopDownConfig` class (in
+        `dlclive/pose_estimation_pytorch/runner.py`). The `crop_size` does not need to
+        be set, as it will be read from the model configuration file.
+        Example parameters:
+        >>> # Running a top-down model with basic parameters
+        >>> top_down_config = {
+        >>>     "bbox_cutoff": 0.5,   # min confidence score for a bbox to be used
+        >>>     "max_detections": 3,  # max number of detections to return in a frame
+        >>> }
+        >>> # Running a top-down model with skip-frames
+        >>> top_down_config = {
+        >>>     "bbox_cutoff": 0.5,   # min confidence score for a bbox to be used
+        >>>     "max_detections": 3,  # max number of detections to return in a frame
+        >>>     "skip_frames": {      # only run the detector every 5 frames
+        >>>         "skip": 5,        # number of frames to skip between detections
+        >>>         "margin": 5,      # margin (in pixels) to use when generating bboxes
+        >>>     },
+        >>> }

     top_down_dynamic: dict, optional, default=None
+        PyTorch only. Single animal only. Top-down models do not need a detector to be
+        used for single animal pose estimation. This is equivalent to dynamic cropping
+        in TensorFlow or for bottom-up models, but crops are resized to the input size
+        required by the model. Pose estimation is never run on the full image. If no
+        animal is detected, the image is split into N by M "patches", and we run pose
+        estimation on the batch of patches. Pose is kept from the patch with the
+        highest likelihood. No need to provide the `top_down_crop_size` parameter, as it
+        is set using the model configuration file.
+        The parameters (except "type") will be passed to the `TopDownDynamicCropper`
+        class (in `dlclive/pose_estimation_pytorch/dynamic_cropping.py`).
+
+        Example parameters:
+        >>> top_down_dynamic = {
+        >>>     "type": "TopDownDynamicCropper",
+        >>>     "min_bbox_size": (50, 50),
+        >>> }

     cropping: list of int
         Cropping parameters in pixel number: [x1, x2, y1, y2]
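
Below is a minimal sketch of how these documented options could be passed when constructing a `DLCLive` object. The keyword names follow the docstring above; the import path is assumed, and the model path and parameter values are hypothetical.

from dlclive import DLCLive

# Hypothetical exported PyTorch model path; device=None or "auto" would pick
# the device automatically based on CUDA availability.
dlc_live = DLCLive(
    "path/to/exported_pytorch_model",
    device="cuda",
    single_animal=False,  # pose array will be (num_detections, num_bodyparts, 3)
    top_down_config={
        "bbox_cutoff": 0.5,   # min confidence score for a bbox to be used
        "max_detections": 3,  # max number of detections to return in a frame
    },
)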

dlclive/pose_estimation_pytorch/dynamic_cropping.py
Lines changed: 10 additions & 10 deletions

@@ -260,18 +260,19 @@ class TopDownDynamicCropper(DynamicCropper):

     def __init__(
         self,
-        top_down_crop_size: tuple[int, int],
-        patch_counts: tuple[int, int],
-        patch_overlap: int,
-        min_bbox_size: tuple[int, int],
-        threshold: float,
-        margin: int,
+        top_down_crop_size: tuple[int, int] = (256, 256),
+        patch_counts: tuple[int, int] = (4, 3),
+        patch_overlap: int = 50,
+        min_bbox_size: tuple[int, int] = (100, 100),
+        threshold: float = 0.6,
+        margin: int = 10,
         min_hq_keypoints: int = 2,
         bbox_from_hq: bool = False,
         store_crops: bool = False,
         **kwargs,
     ) -> None:
         super().__init__(threshold=threshold, margin=margin, **kwargs)
+        self.top_down_crop_size = top_down_crop_size
         self.min_bbox_size = min_bbox_size
         self.min_hq_keypoints = min_hq_keypoints
         self.bbox_from_hq = bbox_from_hq
@@ -280,8 +281,7 @@ def __init__(
         self._patch_overlap = patch_overlap
         self._patches = []
         self._patch_offsets = []
-        self._td_crop_size = top_down_crop_size
-        self._td_ratio = self._td_crop_size[0] / self._td_crop_size[1]
+        self._td_ratio = self.top_down_crop_size[0] / self.top_down_crop_size[1]

         self.crop_history = []
         self.store_crops = store_crops
@@ -363,7 +363,7 @@ def update(self, pose: torch.Tensor) -> torch.Tensor:
         )

         # offset and rescale the pose to the original image space
-        out_w, out_h = self._td_crop_size
+        out_w, out_h = self.top_down_crop_size
         offset_x, offset_y, w, h = self._crop
         scale_x, scale_y = w / out_w, h / out_h
         pose[..., 0] = (pose[..., 0] * scale_x) + offset_x
@@ -448,7 +448,7 @@ def _crop_bounding_box(
             The cropped and resized image.
         """
         x1, y1, w, h = bbox
-        out_w, out_h = self._td_crop_size
+        out_w, out_h = self.top_down_crop_size
         return F.resized_crop(image, y1, x1, h, w, [out_h, out_w])

     def _crop_patches(self, image: torch.Tensor) -> torch.Tensor:
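
Since every constructor argument now has a default, the cropper can be built with only the values one wants to override. A brief sketch under that assumption (import path taken from the file shown above, example values hypothetical):

from dlclive.pose_estimation_pytorch.dynamic_cropping import TopDownDynamicCropper

# Only override what differs from the defaults introduced in the diff above.
cropper = TopDownDynamicCropper(min_bbox_size=(50, 50))

# top_down_crop_size is now a public attribute, which is what lets the runner
# overwrite it with the crop size read from the model configuration file.
cropper.top_down_crop_size = (256, 256)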

dlclive/pose_estimation_pytorch/runner.py
Lines changed: 41 additions & 18 deletions

@@ -32,6 +32,15 @@ class SkipFrames:
     then the detector will only be run every `skip` frames. Between frames where the
     detector is run, bounding boxes will be computed from the pose estimated in the
     previous frame.
+
+    Every `N` frames, the detector will be run to detect bounding boxes for individuals.
+    In the "skipped" frames between the frames where the object detector is run, the
+    bounding boxes will be computed from the poses estimated in the previous frame (with
+    some margin added around the poses).
+
+    Attributes:
+        skip: The number of frames to skip between each run of the detector.
+        margin: The margin (in pixels) to use when generating bboxes.
     """

     skip: int
@@ -78,20 +87,28 @@ class TopDownConfig:
     """Configuration for top-down models.

     Attributes:
+        bbox_cutoff: The minimum score required for a bounding box to be considered.
+        max_detections: The maximum number of detections to keep in a frame. If None,
+            the `max_detections` will be set to the number of individuals in the model
+            configuration file when `read_config` is called.
         skip_frames: If defined, the detector will only be run every
             `skip_frames.skip` frames.
     """

-    bbox_cutoff: float
-    max_detections: int
+    bbox_cutoff: float = 0.6
+    max_detections: int | None = 30
     crop_size: tuple[int, int] = (256, 256)
     skip_frames: SkipFrames | None = None

-    def read_config(self, detector_cfg: dict) -> None:
-        crop = detector_cfg.get("data", {}).get("inference", {}).get("top_down_crop")
+    def read_config(self, model_cfg: dict) -> None:
+        crop = model_cfg.get("data", {}).get("inference", {}).get("top_down_crop")
         if crop is not None:
             self.crop_size = (crop["width"], crop["height"])

+        if self.max_detections is None:
+            individuals = model_cfg.get("metadata", {}).get("individuals", [])
+            self.max_detections = len(individuals)
+

 class PyTorchRunner(BaseRunner):
     """PyTorch runner for live pose estimation using DeepLabCut-Live.
@@ -242,7 +259,7 @@ def load_model(self) -> None:
             self.model = self.model.half()

         self.detector = None
-        if raw_data.get("detector") is not None:
+        if self.dynamic is None and raw_data.get("detector") is not None:
             self.detector = models.DETECTORS.build(self.cfg["detector"]["model"])
             self.detector.to(self.device)
             self.detector.load_state_dict(raw_data["detector"])
@@ -251,18 +268,23 @@ def load_model(self) -> None:
             if self.precision == "FP16":
                 self.detector = self.detector.half()

-        if self.cfg["method"] == "td" and self.detector is None:
-            crop_cfg = self.cfg["data"]["inference"]["top_down_crop"]
-            top_down_crop_size = crop_cfg["width"], crop_cfg["height"]
-            self.dynamic = dynamic_cropping.TopDownDynamicCropper(
-                top_down_crop_size,
-                patch_counts=(4, 3),
-                patch_overlap=50,
-                min_bbox_size=(250, 250),
-                threshold=0.6,
-                margin=25,
-                min_hq_keypoints=2,
-                bbox_from_hq=True,
+        if self.top_down_config is None:
+            self.top_down_config = TopDownConfig()
+
+        self.top_down_config.read_config(self.cfg)
+
+        if isinstance(self.dynamic, dynamic_cropping.TopDownDynamicCropper):
+            crop = self.cfg["data"]["inference"].get("top_down_crop", {})
+            w, h = crop.get("width", 256), crop.get("height", 256)
+            self.dynamic.top_down_crop_size = w, h
+
+        if (
+            self.cfg["method"] == "td"
+            and self.detector is None
+            and self.dynamic is None
+        ):
+            raise ValueError(
+                "Top-down models must either use a detector or a TopDownDynamicCropper."
             )

         self.transform = v2.Compose(
@@ -283,9 +305,10 @@ def read_config(self) -> dict:
     def _prepare_top_down(
         self, frame: torch.Tensor, detections: dict[str, torch.Tensor]
     ):
+        """Prepares a frame for top-down pose estimation."""
         bboxes, scores = detections["boxes"], detections["scores"]
         bboxes = bboxes[scores >= self.top_down_config.bbox_cutoff]
-        if len(bboxes) > 0:
+        if len(bboxes) > 0 and self.top_down_config.max_detections is not None:
            bboxes = bboxes[: self.top_down_config.max_detections]

         crops = []
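
For illustration, a small sketch of how the updated `TopDownConfig.read_config` fills in `crop_size` and `max_detections` from a model configuration dict. The config dict below is a hypothetical minimal example, and `SkipFrames` is assumed to be constructible from its documented `skip` and `margin` attributes.

from dlclive.pose_estimation_pytorch.runner import SkipFrames, TopDownConfig

# max_detections=None means "use the number of individuals in the model config".
config = TopDownConfig(
    bbox_cutoff=0.5,
    max_detections=None,
    skip_frames=SkipFrames(skip=5, margin=5),  # run the detector every 5 frames
)

# Hypothetical minimal model configuration with the keys read_config looks up.
model_cfg = {
    "data": {"inference": {"top_down_crop": {"width": 256, "height": 256}}},
    "metadata": {"individuals": ["animal_1", "animal_2", "animal_3"]},
}

config.read_config(model_cfg)
print(config.crop_size)       # (256, 256)
print(config.max_detections)  # 3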
