
Commit 82daf43 (parent 4c07e04)

improved docs

32 files changed: +192 / -124 lines

dlclive/benchmark.py

Lines changed: 0 additions & 3 deletions
@@ -549,7 +549,6 @@ def benchmark_videos(
         )

         while True:
-
             ret, frame = cap.read()
             if not ret:
                 break
@@ -656,8 +655,6 @@ def save_poses_to_files(video_path, save_dir, bodyparts, poses, timestamp):
             writer.writerow(row)


-
-
 import argparse
 import os

dlclive/benchmark_pytorch.py

Lines changed: 0 additions & 1 deletion
@@ -214,7 +214,6 @@ def analyze_video(
     )

     while True:
-
         ret, frame = cap.read()
         if not ret:
             break

dlclive/benchmark_tf.py

Lines changed: 0 additions & 7 deletions
@@ -305,7 +305,6 @@ def benchmark(

     iterator = range(n_frames) if (print_rate) or (display) else tqdm(range(n_frames))
     for i in iterator:
-
         ret, frame = cap.read()

         if not ret:
@@ -321,7 +320,6 @@ def benchmark(
         inf_times[i] = time.time() - start_pose

         if save_video:
-
             if colors is None:
                 all_colors = getattr(cc, cmap)
                 colors = [
@@ -399,15 +397,13 @@ def benchmark(
         vwriter.release()

     if save_poses:
-
         cfg_path = os.path.normpath(f"{model_path}/pose_cfg.yaml")
         ruamel_file = ruamel.yaml.YAML()
         dlc_cfg = ruamel_file.load(open(cfg_path, "r"))
         bodyparts = dlc_cfg["all_joints_names"]
         poses = np.array(poses)

         if use_pandas:
-
             poses = poses.reshape((poses.shape[0], poses.shape[1] * poses.shape[2]))
             pdindex = pd.MultiIndex.from_product(
                 [bodyparts, ["x", "y", "likelihood"]], names=["bodyparts", "coords"]
@@ -426,7 +422,6 @@ def benchmark(
             pose_df.to_hdf(out_dlc_file, key="df_with_missing", mode="w")

         else:
-
             out_vid_base = os.path.basename(video_path)
             out_dlc_file = os.path.normpath(
                 f"{out_dir}/{os.path.splitext(out_vid_base)[0]}_DLCLIVE_POSES.npy"
@@ -614,14 +609,12 @@ def benchmark_videos(
     # loop over videos

     for v in video_path:
-
         # initialize full inference times

         inf_times = []
         im_size_out = []

         for i in range(len(resize)):
-
             print(f"\nRun {i+1} / {len(resize)}\n")

             this_inf_times, this_im_size, TFGPUinference, meta = benchmark(
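(Not part of the commit: a minimal sketch of the pose-saving layout used in the use_pandas branch above, with made-up bodyparts and a random poses array standing in for the values read from pose_cfg.yaml.)

import numpy as np
import pandas as pd

bodyparts = ["snout", "tailbase"]              # hypothetical bodypart names
poses = np.random.rand(10, len(bodyparts), 3)  # (frames, bodyparts, x/y/likelihood)

# Flatten to (frames, bodyparts * 3) and label the columns, mirroring the branch above.
flat = poses.reshape((poses.shape[0], poses.shape[1] * poses.shape[2]))
pdindex = pd.MultiIndex.from_product(
    [bodyparts, ["x", "y", "likelihood"]], names=["bodyparts", "coords"]
)
pose_df = pd.DataFrame(flat, columns=pdindex)
print(pose_df.head())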

dlclive/core/inferenceutils.py

Lines changed: 0 additions & 1 deletion
@@ -841,7 +841,6 @@ def assemble(self, chunk_size=1, n_processes=None):
         # work nicely with the GUI or interactive sessions.
         # In that case, we fall back to the serial assembly.
         if chunk_size == 0 or multiprocessing.get_start_method() == "spawn":
-
             for i, data_dict in enumerate(tqdm(self)):
                 assemblies, unique = self._assemble(data_dict, i)
                 if assemblies:

dlclive/dlclive.py

Lines changed: 45 additions & 26 deletions
@@ -28,31 +28,41 @@ class DLCLive:
     -----------

     model_path: Path
-        Full path to exported model file
+        Full path to exported model (created when `deeplabcut.export_model(...)` was
+        called). For PyTorch models, this is a single model file. For TensorFlow models,
+        this is a directory containing the model snapshots.

     model_type: string, optional
-        which model to use: 'pytorch' or 'onnx' for exported snapshot
+        Which model to use. For the PyTorch engine, options are [`pytorch`]. For the
+        TensorFlow engine, options are [`base`, `tensorrt`, `lite`].
+
+    precision: string, optional
+        Precision of model weights, for model_type "pytorch" and "tensorrt". Options
+        are, for different model_types:
+            "pytorch": {"FP32", "FP16"}
+            "tensorrt": {"FP32", "FP16", "INT8"}

     tf_config:
+        TensorFlow only. Optional ConfigProto for the TensorFlow session.

+    single_animal: bool, default=True
+        PyTorch only.

-    precision: string, optional
-        precision of model weights, for model_type='onnx' or 'pytorch'. Can be 'FP32'
-        (default) or 'FP16'
+    device: str, optional, default=None
+        PyTorch only.
+
+    top_down_config: dict, optional, default=None
+
+    top_down_dynamic: dict, optional, default=None

     cropping: list of int
-        cropping parameters in pixel number: [x1, x2, y1, y2] #A: Maybe this is the
-        dynamic cropping of each frame to speed of processing, so instead of analyzing
-        the whole frame, it analyzes only the part of the frame where the animal is
-
-    dynamic: triple containing (state, detectiontreshold, margin) #A: margin adds some
-        space so the 'bbox' isn't too narrow around the animal'. First key points are
-        predicted, then dynamic cropping is performed to 'single out' the animal, and
-        then pose is estimated, we think.
+        Cropping parameters in pixel number: [x1, x2, y1, y2]
+
+    dynamic: triple containing (state, detectiontreshold, margin)
         If the state is true, then dynamic cropping will be performed. That means that
         if an object is detected (i.e. any body part > detectiontreshold), then object
         boundaries are computed according to the smallest/largest x position and
-        smallest/largest y position of all body parts. This window is expanded by the
+        smallest/largest y position of all body parts. This window is expanded by the
         margin and from then on only the posture within this crop is analyzed (until the
         object is lost, i.e. <detectiontreshold). The current position is utilized for
         updating the crop window for the next frame (this is why the margin is important
@@ -63,8 +73,7 @@ class DLCLive:
         For example, resize=0.5 will downsize both the height and width of the image by
         a factor of 2.

-    processor: dlc pose processor object, optional #A: this is possibly the 'predictor'
-        - or is it what enables use on jetson boards?
+    processor: dlc pose processor object, optional
         User-defined processor object. Must contain two methods: process and save.
         The 'process' method takes in a pose, performs some processing, and returns
         processed pose.
@@ -80,12 +89,19 @@ class DLCLive:
         boolean flag to convert frames from BGR to RGB color scheme

     display: bool, optional
-        Display frames with DeepLabCut labels?
+        Open a display to show predicted pose in frames with DeepLabCut labels.
         This is useful for testing model accuracy and cropping parameters, but it is
         very slow.

+    pcutoff: float, default=0.5
+        Only used when display=True. The score threshold for displaying a bodypart in
+        the display.
+
+    display_radius: int, default=3
+        Only used when display=True. Radius for keypoint display in pixels, default=3
+
     display_cmap: str, optional
-        String indicating the Matplotlib colormap to use.
+        Only used when display=True. String indicating the Matplotlib colormap to use.
     """

     PARAMETERS = (
@@ -103,33 +119,36 @@ def __init__(
         self,
         model_path: str | Path,
         model_type: str = "base",
-        # tf_config: Any = None,
         precision: str = "FP32",
-        # single_animal: bool = True,
-        # device: str | None = None,
+        tf_config: Any = None,
+        single_animal: bool = True,
+        device: str | None = None,
+        top_down_config: dict | None = None,
+        top_down_dynamic: dict | None = None,
         cropping: list[int] | None = None,
         dynamic: tuple[bool, float, float] = (False, 0.5, 10),
         resize: float | None = None,
         convert2rgb: bool = True,
         processor: Processor | None = None,
         display: bool | Display = False,
         pcutoff: float = 0.5,
-        # bbox_cutoff: float = 0.6,
-        # max_detections: int = 1,
         display_radius: int = 3,
         display_cmap: str = "bmy",
-        **kwargs,
     ):
         self.path = Path(model_path)
         self.runner: BaseRunner = factory.build_runner(
             model_type,
             model_path,
-            **kwargs,
+            precision=precision,
+            tf_config=tf_config,
+            single_animal=single_animal,
+            device=device,
+            dynamic=top_down_dynamic,
+            top_down_config=top_down_config,
         )
         self.is_initialized = False

         self.model_type = model_type
-        self.precision = precision
         self.cropping = cropping
         self.dynamic = dynamic
         self.dynamic_cropping = None
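(Not part of the commit: a hedged usage sketch of the constructor documented above, assuming a made-up exported PyTorch snapshot path; the arguments you need depend on your export.)

from dlclive import DLCLive

dlc_live = DLCLive(
    model_path="exported-models/snapshot-200000.pt",  # hypothetical path
    model_type="pytorch",
    precision="FP16",            # "FP32" or "FP16" for the PyTorch runner
    device="cuda",               # PyTorch only; None lets the runner pick
    dynamic=(True, 0.5, 10),     # (state, detection threshold, margin)
    resize=0.5,
)
# Typical DLCLive workflow: initialize on the first frame, then poll per frame.
# pose = dlc_live.init_inference(first_frame)
# pose = dlc_live.get_pose(next_frame)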

dlclive/factory.py

Lines changed: 30 additions & 7 deletions
@@ -1,33 +1,56 @@
-
 """Factory to build runners for DeepLabCut-Live inference"""
+from __future__ import annotations
+
 from pathlib import Path
+from typing import Literal

 from dlclive.core.runner import BaseRunner


 def build_runner(
-    model_type: str,
+    model_type: Literal["pytorch", "tensorflow", "base", "tensorrt", "lite"],
     model_path: str | Path,
     **kwargs,
 ) -> BaseRunner:
     """

     Parameters
     ----------
-    model_type
-    model_path
-    kwargs
+    model_type: str, optional
+        Which model to use. For the PyTorch engine, options are [`pytorch`]. For the
+        TensorFlow engine, options are [`base`, `tensorrt`, `lite`].
+    model_path: str, Path
+        Full path to exported model (created when `deeplabcut.export_model(...)` was
+        called). For PyTorch models, this is a single model file. For TensorFlow models,
+        this is a directory containing the model snapshots.
+
+    kwargs: dict, optional
+        PyTorch Engine Kwargs:
+
+        TensorFlow Engine Kwargs:

     Returns
     -------

     """
     if model_type.lower() == "pytorch":
         from dlclive.pose_estimation_pytorch.runner import PyTorchRunner
-        return PyTorchRunner(model_path, **kwargs)
+
+        valid = {"device", "precision", "single_animal", "dynamic", "top_down_config"}
+        return PyTorchRunner(model_path, **filter_keys(valid, kwargs))

     elif model_type.lower() in ("tensorflow", "base", "tensorrt", "lite"):
         from dlclive.pose_estimation_tensorflow.runner import TensorFlowRunner
-        return TensorFlowRunner(model_path, model_type, **kwargs)
+
+        if model_type.lower() == "tensorflow":
+            model_type = "base"
+
+        valid = {"tf_config", "precision"}
+        return TensorFlowRunner(model_path, model_type, **filter_keys(valid, kwargs))

     raise ValueError(f"Unknown model type: {model_type}")
+
+
+def filter_keys(valid: set[str], kwargs: dict) -> dict:
+    """Filters the keys in kwargs, only keeping those in valid."""
+    return {k: v for k, v in kwargs.items() if k in valid}
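(Not part of the commit: a small sketch of the new kwarg filtering, with a placeholder model path. filter_keys drops engine-specific kwargs that don't apply to the selected runner, so passing a TensorFlow-only option alongside a PyTorch model no longer reaches the runner constructor.)

from dlclive.factory import build_runner

runner = build_runner(
    "pytorch",
    "exported-models/snapshot-200000.pt",  # placeholder path
    device="cpu",
    precision="FP32",
    tf_config=None,  # TensorFlow-only kwarg; silently filtered for the PyTorch engine
)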

dlclive/graph.py

Lines changed: 3 additions & 2 deletions
@@ -106,12 +106,13 @@ def get_output_tensors(graph):


 def get_input_tensor(graph):
-
     input_tensor = str(graph.get_operations()[0].name) + ":0"
     return input_tensor


-def extract_graph(graph, tf_config=None) -> tuple[tf.Session, tf.Tensor, list[tf.Tensor]]:
+def extract_graph(
+    graph, tf_config=None
+) -> tuple[tf.Session, tf.Tensor, list[tf.Tensor]]:
     """
     Initializes a tensorflow session with the specified graph and extracts the model's inputs and outputs

dlclive/live_inference.py

Lines changed: 0 additions & 2 deletions
@@ -197,7 +197,6 @@ def analyze_live_video(
     ]

     if save_video:
-
         # Define output video path
         output_video_path = os.path.join(
             save_dir, f"{experiment_name}_DLCLIVE_LABELLED_{timestamp}.mp4"
@@ -217,7 +216,6 @@ def analyze_live_video(
     )

     while True:
-
         ret, frame = cap.read()
         if not ret:
             break

dlclive/pose_estimation_pytorch/data/image.py

Lines changed: 1 addition & 1 deletion
@@ -109,7 +109,7 @@ def top_down_crop(

     # crop the pixels we care about
     image_crop = np.zeros((crop_h, crop_w, c), dtype=image.dtype)
-    image_crop[pad_top:pad_top + h, pad_left:pad_left + w] = image[y1:y2, x1:x2]
+    image_crop[pad_top : pad_top + h, pad_left : pad_left + w] = image[y1:y2, x1:x2]

     # resize the cropped image
     image = cv2.resize(image_crop, (out_w, out_h), interpolation=cv2.INTER_LINEAR)

dlclive/pose_estimation_pytorch/dynamic_cropping.py

Lines changed: 9 additions & 6 deletions
@@ -39,9 +39,9 @@ class DynamicCropper:
         The margin used to expand an individuals bounding box before cropping it.

     Examples:
-        >>> import deeplabcut.pose_estimation_pytorch.models as models
+        >>> import torch.nn as nn
         >>>
-        >>> model: models.PoseModel
+        >>> model: nn.Module  # pose estimation model
         >>> frames: torch.Tensor  # shape (num_frames, 3, H, W)
         >>>
         >>> dynamic = DynamicCropper(threshold=0.6, margin=25)
@@ -57,6 +57,7 @@ class DynamicCropper:
         >>> predictions.append(pose)
         >>>
     """
+
     threshold: float
     margin: int
     _crop: tuple[int, int, int, int] | None = field(default=None, repr=False)
@@ -424,16 +425,18 @@ def _prepare_bounding_box(

         input_ratio = w / h
         if input_ratio > self._td_ratio:  # h/w < h0/w0 => h' = w * h0/w0
-            h = w / self._td_ratio
+            h = w / self._td_ratio
         elif input_ratio < self._td_ratio:  # w/h < w0/h0 => w' = h * w0/h0
-            w = h * self._td_ratio
+            w = h * self._td_ratio

         x1, y1 = int(round(cx - (w / 2))), int(round(cy - (h / 2)))
         w, h = max(int(w), self.min_bbox_size[0]), max(int(h), self.min_bbox_size[1])
         return x1, y1, w, h

     def _crop_bounding_box(
-        self, image: torch.Tensor, bbox: tuple[int, int, int, int],
+        self,
+        image: torch.Tensor,
+        bbox: tuple[int, int, int, int],
     ) -> torch.Tensor:
         """Applies a top-down crop to an image given a bounding box.
@@ -487,7 +490,7 @@ def _extract_best_patch(self, pose: torch.Tensor) -> torch.Tensor:
         # set the crop to the one used for the best patch
         self._crop = self._patches[best_patch]

-        return pose[best_patch:best_patch + 1]
+        return pose[best_patch : best_patch + 1]

     def generate_patches(self) -> list[tuple[int, int, int, int]]:
         """Generates patch coordinates for splitting an image.
