Skip to content

Commit d61f892

Browse files
dikraMasrourn-poulsen
authored andcommitted
update docstrings, clean dlclive script
1 parent 9508a79 commit d61f892

File tree

5 files changed

+23
-86
lines changed

5 files changed

+23
-86
lines changed

dlclive/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from dlclive.display import Display
99
from dlclive.dlclive import DLCLive
1010
from dlclive.predictor import HeatmapPredictor
11-
from dlclive.processor import Processor
11+
from dlclive.processor.processor import Processor
1212
from dlclive.version import VERSION, __version__
1313

1414
# from dlclive.benchmark import benchmark, benchmark_videos, download_benchmarking_data

dlclive/dlclive.py

Lines changed: 21 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,7 @@
1313
from pathlib import Path
1414
from typing import List, Optional, Tuple
1515

16-
import deeplabcut as dlc
1716
import numpy as np
18-
import onnx
1917
import onnxruntime as ort
2018
import ruamel.yaml
2119
import torch
@@ -24,37 +22,11 @@
2422
from dlclive import utils
2523
from dlclive.display import Display
2624
from dlclive.exceptions import DLCLiveError, DLCLiveWarning
27-
from dlclive.pose import (argmax_pose_predict, extract_cnn_output,
28-
multi_pose_predict)
2925
from dlclive.predictor import HeatmapPredictor
3026

3127
if typing.TYPE_CHECKING:
3228
from dlclive.processor import Processor
3329

34-
35-
# TODO:
36-
# graph.py the main element to import TF model - convert to pytorch implementation
37-
# add pcutoffn to docstring
38-
39-
# Q: What is the best way to test the code as we go?
40-
# Q: if self.pose is not None: - ask Niels to go through this!
41-
42-
# Q: what exactly does model_type reference?
43-
# Q: is precision a type of qunatization?
44-
# Q: for dynamic: First key points are predicted, then dynamic cropping is performed to 'single out' the animal, and then pose is estimated, we think. What is the difference from key point prediction to pose prediction?
45-
# Q: what is the processor? see processor code F12 from init file - what is the 'user defined process' - could it be that if mouse = standing, perform some action? or is the process the prediction of a certain pose/set of keypoints
46-
# Q: why have the convert2rgb function, is the stream coming from the camera different from the input needed to DLC live?
47-
# Q: what is the parameter 'cfg'?
48-
49-
# What do these do?
50-
# self.inputs = None
51-
# self.outputs = None
52-
# self.tflite_interpreter = None
53-
# self.pose = None
54-
# self.is_initialized = False
55-
# self.sess = None
56-
57-
5830
class DLCLive(object):
5931
"""
6032
Object that loads a DLC network and performs inference on single images (e.g. images captured from a camera feed)
@@ -66,10 +38,10 @@ class DLCLive(object):
6638
Full path to exported model directory
6739
6840
model_type: string, optional
69-
which model to use: 'base', 'tensorrt' for tensorrt optimized graph, 'lite' for tensorflow lite optimized graph
41+
which model to use: 'pytorch' or 'onnx' for exported snapshot
7042
7143
precision : string, optional
72-
precision of model weights, only for model_type='tensorrt'. Can be 'FP16' (default), 'FP32', or 'INT8'
44+
precision of model weights, only for model_type='onnx'. Can be 'FP32' (default) or 'FP16'
7345
7446
cropping : list of int
7547
cropping parameters in pixel number: [x1, x2, y1, y2] #A: Maybe this is the dynamic cropping of each frame to speed of processing, so instead of analyzing the whole frame, it analyses only the part of the frame where the animal is
@@ -196,12 +168,7 @@ def parameterization(
196168
self,
197169
) -> (
198170
dict
199-
): # A: constructs a dictionary based on the object attributes based on the list of parameters
200-
"""
201-
Return
202-
Returns
203-
-------
204-
"""
171+
):
205172
return {param: getattr(self, param) for param in self.PARAMETERS}
206173

207174
def process_frame(self, frame):
@@ -219,8 +186,6 @@ def process_frame(self, frame):
219186
processed frame: convert type, crop, convert color
220187
"""
221188

222-
# ! NORMALISATION ??
223-
224189
if self.cropping:
225190
frame = frame[
226191
self.cropping[2] : self.cropping[3], self.cropping[0] : self.cropping[1]
@@ -230,9 +195,7 @@ def process_frame(self, frame):
230195
if self.pose is not None:
231196
detected = self.pose["poses"][0][0][:, 2] > self.dynamic[1]
232197

233-
# if np.any(detected.numpy()):
234198
if torch.any(detected):
235-
# if detected.any(): # Use PyTorch's any() method
236199

237200
x = self.pose["poses"][0][0][detected, 0]
238201
y = self.pose["poses"][0][0][detected, 1]
@@ -263,7 +226,7 @@ def process_frame(self, frame):
263226

264227
def load_model(self):
265228
if self.model_type == "pytorch":
266-
# Requires DLC 3.0 to be imported
229+
# Requires DLC 3.0 to be imported !
267230
model_path = os.path.join(self.path, self.snapshot)
268231
if not os.path.isfile(model_path):
269232
raise FileNotFoundError(
@@ -278,12 +241,7 @@ def load_model(self):
278241
elif self.model_type == "onnx":
279242
model_paths = glob.glob(os.path.normpath(self.path + "/*.onnx"))
280243
if self.precision == "FP16":
281-
model_path = [
282-
model_paths[i]
283-
for i in range(len(model_paths))
284-
if "fp16" in model_paths[i]
285-
][0]
286-
print(model_path)
244+
model_path = [model_paths[i] for i in range(len(model_paths)) if "fp16" in model_paths[i]][0]
287245
else:
288246
model_path = model_paths[0]
289247
opts = ort.SessionOptions()
@@ -292,23 +250,19 @@ def load_model(self):
292250
self.sess = ort.InferenceSession(
293251
model_path, opts, providers=["CUDAExecutionProvider"]
294252
)
295-
print(self.sess)
296253
elif self.device == "cpu":
297254
self.sess = ort.InferenceSession(
298255
model_path, opts, providers=["CPUExecutionProvider"]
299256
)
300-
# ! TODO implement if statements for choice of tensorrt engine options (precision, and caching)
257+
301258
elif self.device == "tensorrt":
302-
provider = [
303-
(
304-
"TensorrtExecutionProvider",
305-
{
306-
"trt_engine_cache_enable": True,
307-
"trt_engine_cache_path": "./trt_engines",
308-
},
309-
)
310-
]
311-
self.sess = ort.InferenceSession(model_path, opts, providers=provider)
259+
provider = [("TensorrtExecutionProvider", {
260+
"trt_engine_cache_enable": True,
261+
"trt_engine_cache_path": "./trt_engines"
262+
})]
263+
self.sess = ort.InferenceSession(
264+
model_path, opts, providers=provider
265+
)
312266
self.predictor = HeatmapPredictor.build(self.cfg)
313267

314268
if not os.path.isfile(model_path):
@@ -336,13 +290,15 @@ def init_inference(self, frame=None, **kwargs):
336290
--------
337291
pose :class:`numpy.ndarray`
338292
the pose estimated by DeepLabCut for the input image
293+
inf_time:class: `float`
294+
the pose inference time
339295
"""
340296

341297
# load model
342298
self.load_model()
343299

344-
inf_time = 0.0
345-
# get pose of first frame (first inference is often very slow)
300+
inf_time = 0.
301+
# get pose of first frame (first inference is very slow)
346302
if frame is not None:
347303
pose, inf_time = self.get_pose(frame, **kwargs)
348304
else:
@@ -363,8 +319,11 @@ def get_pose(self, frame=None, **kwargs):
363319
--------
364320
pose :class:`numpy.ndarray`
365321
the pose estimated by DeepLabCut for the input image
322+
inf_time:class: `float`
323+
the pose inference time
366324
"""
367-
inf_time = 0.0
325+
326+
inf_time = 0.
368327
if frame is None:
369328
raise DLCLiveError("No frame provided for live pose estimation")
370329

@@ -437,18 +396,7 @@ def get_pose(self, frame=None, **kwargs):
437396
self.pose["poses"][0][0][:, 1] += self.dynamic_cropping[2]
438397

439398
# process the pose
440-
441399
if self.processor:
442400
self.pose = self.processor.process(self.pose, **kwargs)
443401

444402
return self.pose, inf_time
445-
446-
# def close(self):
447-
# """ Close tensorflow session
448-
# """
449-
450-
# self.sess.close()
451-
# self.sess = None
452-
# self.is_initialized = False
453-
# if self.display is not None:
454-
# self.display.destroy()

dlclive/live_inference.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -175,7 +175,7 @@ def analyze_live_video(
175175
# Load video
176176
cap = cv2.VideoCapture(camera)
177177
if not cap.isOpened():
178-
print(f"Error: Could not open video file {camera}")
178+
print(f"Error: Could not open camera {camera}")
179179
return
180180

181181
# Start empty dict to save poses to for each frame

dlclive/processor/__init__.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,3 @@
44
55
Licensed under GNU Lesser General Public License v3.0
66
"""
7-
8-
from dlclive.processor.kalmanfilter import KalmanFilterPredictor
9-
from dlclive.processor.processor import Processor

test.py

Lines changed: 0 additions & 8 deletions
This file was deleted.

0 commit comments

Comments
 (0)