Skip to content

Commit 29c8122

Browse files
dikraMasrourn-poulsen
authored andcommitted
fix live inference and display, black and isort
1 parent e6a914a commit 29c8122

File tree

17 files changed

+147
-132
lines changed

17 files changed

+147
-132
lines changed

dlclive/__init__.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,10 @@
55
Licensed under GNU Lesser General Public License v3.0
66
"""
77

8-
from dlclive.version import __version__, VERSION
9-
from dlclive.dlclive import DLCLive
108
from dlclive.display import Display
11-
from dlclive.processor import Processor
9+
from dlclive.dlclive import DLCLive
1210
from dlclive.predictor import HeatmapPredictor
11+
from dlclive.processor import Processor
12+
from dlclive.version import VERSION, __version__
13+
1314
# from dlclive.benchmark import benchmark, benchmark_videos, download_benchmarking_data

dlclive/benchmark.py

Lines changed: 17 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -5,33 +5,31 @@
55
Licensed under GNU Lesser General Public License v3.0
66
"""
77

8-
9-
import platform
108
import os
11-
import time
12-
import sys
13-
import warnings
9+
import pickle
10+
import platform
1411
import subprocess
12+
import sys
13+
import time
1514
import typing
16-
import pickle
15+
import warnings
16+
1717
import colorcet as cc
18-
from PIL import ImageColor
1918
import ruamel
19+
from PIL import ImageColor
2020

2121
try:
2222
from pip._internal.operations import freeze
2323
except ImportError:
2424
from pip.operations import freeze
2525

26-
from tqdm import tqdm
26+
import cv2
2727
import numpy as np
2828
import tensorflow as tf
29-
import cv2
29+
from tqdm import tqdm
3030

31-
from dlclive import DLCLive
32-
from dlclive import VERSION
31+
from dlclive import VERSION, DLCLive
3332
from dlclive import __file__ as dlcfile
34-
3533
from dlclive.utils import decode_fourcc
3634

3735

@@ -42,8 +40,9 @@ def download_benchmarking_data(
4240
"""
4341
Downloads a DeepLabCut-Live benchmarking Data (videos & DLC models).
4442
"""
45-
import urllib.request
4643
import tarfile
44+
import urllib.request
45+
4746
from tqdm import tqdm
4847

4948
def show_progress(count, block_size, total_size):
@@ -75,7 +74,7 @@ def tarfilenamecutting(tarf):
7574

7675

7776
def get_system_info() -> dict:
78-
""" Return summary info for system running benchmark
77+
"""Return summary info for system running benchmark
7978
Returns
8079
-------
8180
dict
@@ -165,7 +164,7 @@ def benchmark(
165164
save_video=False,
166165
output=None,
167166
) -> typing.Tuple[np.ndarray, tuple, bool, dict]:
168-
""" Analyze DeepLabCut-live exported model on a video:
167+
"""Analyze DeepLabCut-live exported model on a video:
169168
Calculate inference time,
170169
display keypoints, or
171170
get poses/create a labeled video
@@ -193,7 +192,7 @@ def benchmark(
193192
n_frames : int, optional
194193
number of frames to run inference on, by default 1000
195194
print_rate : bool, optional
196-
flat to print inference rate frame by frame, by default False
195+
flag to print inference rate frame by frame, by default False
197196
display : bool, optional
198197
flag to display keypoints on images. Useful for checking the accuracy of exported models.
199198
pcutoff : float, optional
@@ -440,7 +439,7 @@ def benchmark(
440439
def save_inf_times(
441440
sys_info, inf_times, im_size, TFGPUinference, model=None, meta=None, output=None
442441
):
443-
""" Save inference time data collected using :function:`benchmark` with system information to a pickle file.
442+
"""Save inference time data collected using :function:`benchmark` with system information to a pickle file.
444443
This is primarily used through :function:`benchmark_videos`
445444
446445
@@ -666,8 +665,7 @@ def benchmark_videos(
666665

667666

668667
def main():
669-
"""Provides a command line interface :function:`benchmark_videos`
670-
"""
668+
"""Provides a command line interface :function:`benchmark_videos`"""
671669

672670
import argparse
673671

dlclive/benchmark_pytorch.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,8 @@
1313
from PIL import ImageColor
1414
from pip._internal.operations import freeze
1515

16-
from dlclive import VERSION, DLCLive
16+
from dlclive import DLCLive
17+
from dlclive.version import VERSION
1718

1819

1920
def get_system_info() -> dict:
@@ -482,8 +483,3 @@ def main():
482483

483484
if __name__ == "__main__":
484485
main()
485-
486-
487-
# Example how to run in command line:
488-
# python benchmark_pytorch.py /path/to/model /path/to/video DLC cuda -p FP32 -d -c 0.5 -dr 5 -r 0.5 -x 10 630 10 470 --save-poses --save-video --draw-keypoint-names --cmap bmy --save-dir
489-
# python benchmark_pytorch.py /Users/annastuckert/Documents/DLC_AI_Residency/DLC_AI2024/DeepLabCut-live/dlc-live-dummy/ventral-gait/resnet.onnx /Users/annastuckert/Documents/DLC_AI_Residency/DLC_AI2024/DeepLabCut-live/dlc-live-dummy/ventral-gait/1_20cms_0degUP_first_1s.avi DLC cuda -p FP32 -d -r 0.5 --save-poses --save-video --draw-keypoint-names --cmap bmy --save-dir /Users/annastuckert/Documents/DLC_AI_Residency/DLC_AI2024/DeepLabCut-live/dlc-live-dummy/ventral-gait/out

dlclive/check_install/check_install.py

Lines changed: 29 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -5,19 +5,16 @@
55
Licensed under GNU Lesser General Public License v3.0
66
"""
77

8-
9-
import sys
8+
import argparse
109
import shutil
11-
import warnings
12-
13-
from dlclive import benchmark_videos
10+
import sys
1411
import urllib.request
15-
import argparse
12+
import warnings
1613
from pathlib import Path
17-
from dlclibrary.dlcmodelzoo.modelzoo_download import (
18-
download_huggingface_model,
19-
)
2014

15+
from dlclibrary.dlcmodelzoo.modelzoo_download import download_huggingface_model
16+
17+
from dlclive import benchmark_videos
2118

2219
MODEL_NAME = "superanimal_quadruped"
2320
SNAPSHOT_NAME = "snapshot-700000.pb"
@@ -27,28 +24,33 @@ def urllib_pbar(count, blockSize, totalSize):
2724
percent = int(count * blockSize * 100 / totalSize)
2825
outstr = f"{round(percent)}%"
2926
sys.stdout.write(outstr)
30-
sys.stdout.write("\b"*len(outstr))
27+
sys.stdout.write("\b" * len(outstr))
3128
sys.stdout.flush()
3229

3330

3431
def main():
3532
parser = argparse.ArgumentParser(
36-
description="Test DLC-Live installation by downloading and evaluating a demo DLC project!")
37-
parser.add_argument('--nodisplay', action='store_false', help="Run the test without displaying tracking")
33+
description="Test DLC-Live installation by downloading and evaluating a demo DLC project!"
34+
)
35+
parser.add_argument(
36+
"--nodisplay",
37+
action="store_false",
38+
help="Run the test without displaying tracking",
39+
)
3840
args = parser.parse_args()
3941
display = args.nodisplay
4042

4143
if not display:
42-
print('Running without displaying video')
44+
print("Running without displaying video")
4345

4446
# make temporary directory in $HOME
4547
# TODO: why create this temp directory in $HOME?
4648
print("\nCreating temporary directory...\n")
47-
tmp_dir = Path().home() / 'dlc-live-tmp'
48-
tmp_dir.mkdir(mode=0o775,exist_ok=True)
49+
tmp_dir = Path().home() / "dlc-live-tmp"
50+
tmp_dir.mkdir(mode=0o775, exist_ok=True)
4951

50-
video_file = str(tmp_dir / 'dog_clip.avi')
51-
model_dir = tmp_dir / 'DLC_Dog_resnet_50_iteration-0_shuffle-0'
52+
video_file = str(tmp_dir / "dog_clip.avi")
53+
model_dir = tmp_dir / "DLC_Dog_resnet_50_iteration-0_shuffle-0"
5254

5355
# download dog test video from github:
5456
# TODO: Should check if the video's already there before downloading it (should have been cloned with the files)
@@ -58,25 +60,31 @@ def main():
5860

5961
# download model from the DeepLabCut Model Zoo
6062
if Path(model_dir / SNAPSHOT_NAME).exists():
61-
print('Model already downloaded, using cached version')
63+
print("Model already downloaded, using cached version")
6264
else:
6365
print("Downloading full_dog model from the DeepLabCut Model Zoo...")
6466
download_huggingface_model(MODEL_NAME, model_dir)
6567

6668
# assert these things exist so we can give informative error messages
6769
assert Path(video_file).exists(), f"Missing video file {video_file}"
68-
assert Path(model_dir / SNAPSHOT_NAME).exists(), f"Missing model file {model_dir / SNAPSHOT_NAME}"
70+
assert Path(
71+
model_dir / SNAPSHOT_NAME
72+
).exists(), f"Missing model file {model_dir / SNAPSHOT_NAME}"
6973

7074
# run benchmark videos
7175
print("\n Running inference...\n")
72-
benchmark_videos(str(model_dir), video_file, display=display, resize=0.5, pcutoff=0.25)
76+
benchmark_videos(
77+
str(model_dir), video_file, display=display, resize=0.5, pcutoff=0.25
78+
)
7379

7480
# deleting temporary files
7581
print("\n Deleting temporary files...\n")
7682
try:
7783
shutil.rmtree(tmp_dir)
7884
except PermissionError:
79-
warnings.warn(f'Could not delete temporary directory {str(tmp_dir)} due to a permissions error, but otherwise dlc-live seems to be working fine!')
85+
warnings.warn(
86+
f"Could not delete temporary directory {str(tmp_dir)} due to a permissions error, but otherwise dlc-live seems to be working fine!"
87+
)
8088

8189
print("\nDone!\n")
8290

dlclive/display.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,10 @@
99

1010
import colorcet as cc
1111
import numpy as np
12-
from dlclive import utils
1312
from PIL import Image, ImageDraw, ImageTk
1413

14+
from dlclive import utils
15+
1516

1617
class Display(object):
1718
"""

dlclive/dlclive.py

Lines changed: 23 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,8 @@
2424
from dlclive import utils
2525
from dlclive.display import Display
2626
from dlclive.exceptions import DLCLiveError, DLCLiveWarning
27-
from dlclive.pose import argmax_pose_predict, extract_cnn_output, multi_pose_predict
27+
from dlclive.pose import (argmax_pose_predict, extract_cnn_output,
28+
multi_pose_predict)
2829
from dlclive.predictor import HeatmapPredictor
2930

3031
if typing.TYPE_CHECKING:
@@ -277,7 +278,11 @@ def load_model(self):
277278
elif self.model_type == "onnx":
278279
model_paths = glob.glob(os.path.normpath(self.path + "/*.onnx"))
279280
if self.precision == "FP16":
280-
model_path = [model_paths[i] for i in range(len(model_paths)) if "fp16" in model_paths[i]][0]
281+
model_path = [
282+
model_paths[i]
283+
for i in range(len(model_paths))
284+
if "fp16" in model_paths[i]
285+
][0]
281286
print(model_path)
282287
else:
283288
model_path = model_paths[0]
@@ -294,13 +299,16 @@ def load_model(self):
294299
)
295300
# ! TODO implement if statements for choice of tensorrt engine options (precision, and caching)
296301
elif self.device == "tensorrt":
297-
provider = [("TensorrtExecutionProvider", {
298-
"trt_engine_cache_enable": True,
299-
"trt_engine_cache_path": "./trt_engines"
300-
})]
301-
self.sess = ort.InferenceSession(
302-
model_path, opts, providers=provider
303-
)
302+
provider = [
303+
(
304+
"TensorrtExecutionProvider",
305+
{
306+
"trt_engine_cache_enable": True,
307+
"trt_engine_cache_path": "./trt_engines",
308+
},
309+
)
310+
]
311+
self.sess = ort.InferenceSession(model_path, opts, providers=provider)
304312
self.predictor = HeatmapPredictor.build(self.cfg)
305313

306314
if not os.path.isfile(model_path):
@@ -333,7 +341,7 @@ def init_inference(self, frame=None, **kwargs):
333341
# load model
334342
self.load_model()
335343

336-
inf_time = 0.
344+
inf_time = 0.0
337345
# get pose of first frame (first inference is often very slow)
338346
if frame is not None:
339347
pose, inf_time = self.get_pose(frame, **kwargs)
@@ -356,7 +364,7 @@ def get_pose(self, frame=None, **kwargs):
356364
pose :class:`numpy.ndarray`
357365
the pose estimated by DeepLabCut for the input image
358366
"""
359-
inf_time = 0.
367+
inf_time = 0.0
360368
if frame is None:
361369
raise DLCLiveError("No frame provided for live pose estimation")
362370

@@ -381,8 +389,10 @@ def get_pose(self, frame=None, **kwargs):
381389
self.pose = self.pose["bodypart"]
382390

383391
elif self.model_type == "onnx":
384-
if self.precision == "FP32": frame = processed_frame.astype(np.float32)
385-
elif self.precision == "FP16": frame = processed_frame.astype(np.float16)
392+
if self.precision == "FP32":
393+
frame = processed_frame.astype(np.float32)
394+
elif self.precision == "FP16":
395+
frame = processed_frame.astype(np.float16)
386396

387397
frame = np.transpose(frame, (2, 0, 1))
388398
frame = np.expand_dims(frame, axis=0)

dlclive/exceptions.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,12 @@
77

88

99
class DLCLiveError(Exception):
10-
""" Generic error type for incorrect use of the DLCLive class """
10+
"""Generic error type for incorrect use of the DLCLive class"""
1111

1212
pass
1313

1414

1515
class DLCLiveWarning(Warning):
16-
""" Generic warning for incorrect use of the DLCLive class """
16+
"""Generic warning for incorrect use of the DLCLive class"""
1717

1818
pass

dlclive/graph.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
Licensed under GNU Lesser General Public License v3.0
66
"""
77

8-
98
import tensorflow as tf
109

1110
vers = (tf.__version__).split(".")

0 commit comments

Comments
 (0)