Skip to content

Commit a29415d

Browse files
committed
Benchmark pytorch: introduce n_frames and progress bar
1 parent 9b5d8f1 commit a29415d

File tree

1 file changed

+13
-1
lines changed

1 file changed

+13
-1
lines changed

dlclive/benchmark_pytorch.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
from PIL import ImageColor
1313
from pip._internal.operations import freeze
1414
import torch
15+
from tqdm import tqdm
16+
1517
# torch import needs to switch order with "from pip._internal.operations import freeze" because of crash
1618
# see https://github.com/pytorch/pytorch/issues/140914
1719

@@ -91,6 +93,7 @@ def benchmark(
9193
device: str,
9294
single_animal: bool,
9395
save_dir=None,
96+
n_frames=1000,
9497
precision: str = "FP32",
9598
display=True,
9699
pcutoff=0.5,
@@ -122,6 +125,8 @@ def benchmark(
122125
save_dir : str, optional
123126
Directory to save output data and labeled video.
124127
If not specified, will use the directory of video_path, by default None
128+
n_frames : int, optional
129+
Number of frames to run inference on, by default 1000
125130
precision : str, optional, default='FP32'
126131
Precision type for the model ('FP32' or 'FP16').
127132
display : bool, optional, default=True
@@ -203,7 +208,14 @@ def benchmark(
203208
poses, times = [], []
204209
frame_index = 0
205210

206-
while True:
211+
total_n_frames = cap.get(cv2.CAP_PROP_FRAME_COUNT)
212+
n_frames = int(
213+
n_frames
214+
if (n_frames > 0) and n_frames < total_n_frames
215+
else total_n_frames
216+
)
217+
iterator = range(n_frames) if display else tqdm(range(n_frames))
218+
for i in iterator:
207219
ret, frame = cap.read()
208220
if not ret:
209221
break

0 commit comments

Comments
 (0)