Skip to content

Commit f013557

Browse files
committed
WIP refactor benchmarking: resize and pixels
1 parent 6402380 commit f013557

File tree

1 file changed

+31
-12
lines changed

1 file changed

+31
-12
lines changed

dlclive/benchmark_pytorch.py

Lines changed: 31 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -92,14 +92,15 @@ def benchmark(
9292
model_path: str,
9393
model_type: str,
9494
device: str,
95-
single_animal: bool,
95+
resize: float | None = None,
96+
pixels: int | None = None,
97+
single_animal: bool = True,
9698
save_dir=None,
9799
n_frames=1000,
98100
precision: str = "FP32",
99101
display=True,
100102
pcutoff=0.5,
101103
display_radius=3,
102-
resize=None,
103104
cropping=None, # Adding cropping to the function parameters
104105
dynamic=(False, 0.5, 10),
105106
save_poses=False,
@@ -121,7 +122,12 @@ def benchmark(
121122
Type of the model (e.g., 'onnx').
122123
device : str
123124
Device to run the model on ('cpu' or 'cuda').
124-
single_animal: bool
125+
resize : float or None, optional
126+
Scale factor applied to video frames, e.g. if resize = 0.5, the video will be processed at half the original width and height. If None, no resizing is applied.
127+
pixels : int, optional
128+
Downsize the image to approximately this total number of pixels, maintaining the aspect ratio.
129+
Can only use one of resize or pixels. If both are provided, will use pixels.
130+
single_animal: bool, optional, default=True
125131
Whether the video contains only one animal (True) or multiple animals (False).
126132
save_dir : str, optional
127133
Directory to save output data and labeled video.
@@ -136,8 +142,6 @@ def benchmark(
136142
Probability cutoff below which keypoints are not visualized.
137143
display_radius : int, optional, default=3
138144
Radius of circles drawn for keypoints on video frames.
139-
resize : tuple of int (width, height) or None, optional
140-
Resize dimensions for video frames. e.g. if resize = 0.5, the video will be processed in half the original size. If None, no resizing is applied.
141145
cropping : list of int or None, optional
142146
Cropping parameters [x1, x2, y1, y2] in pixels. If None, no cropping is applied.
143147
dynamic : tuple, optional, default=(False, 0.5, 10), i.e. (True/False, p cutoff, margin)
@@ -160,6 +164,17 @@ def benchmark(
160164
- poses (list of dict): List of pose data for each frame.
161165
- times (list of float): List of inference times for each frame.
162166
"""
167+
# Load video
168+
cap = cv2.VideoCapture(video_path)
169+
if not cap.isOpened():
170+
print(f"Error: Could not open video file {video_path}")
171+
return
172+
im_size = (int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)), int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)))
173+
174+
if pixels is not None:
175+
resize = np.sqrt(pixels / (im_size[0] * im_size[1]))
176+
if resize is not None:
177+
im_size = (int(im_size[0] * resize), int(im_size[1] * resize))
163178

164179
# Create the DLCLive object with cropping
165180
dlc_live = DLCLive(
@@ -185,12 +200,6 @@ def benchmark(
185200
# Get the current date and time as a string
186201
timestamp = time.strftime("%Y%m%d_%H%M%S")
187202

188-
# Load video
189-
cap = cv2.VideoCapture(video_path)
190-
if not cap.isOpened():
191-
print(f"Error: Could not open video file {video_path}")
192-
return
193-
194203
# Retrieve bodypart names and number of keypoints
195204
bodyparts = dlc_live.read_config()["metadata"]["bodyparts"]
196205

@@ -202,7 +211,7 @@ def benchmark(
202211
num_keypoints=len(bodyparts),
203212
cmap=cmap,
204213
fps=cap.get(cv2.CAP_PROP_FPS),
205-
frame_size=(int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)), int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))),
214+
frame_size=im_size,
206215
)
207216

208217
# Start empty dict to save poses to for each frame
@@ -241,6 +250,7 @@ def benchmark(
241250
draw_pose_and_write(
242251
frame=frame,
243252
pose=pose,
253+
resize=resize,
244254
colors=colors,
245255
bodyparts=bodyparts,
246256
pcutoff=pcutoff,
@@ -303,6 +313,7 @@ def setup_video_writer(
303313
def draw_pose_and_write(
304314
frame: np.ndarray,
305315
pose: np.ndarray,
316+
resize: float,
306317
colors: list[tuple[int, int, int]],
307318
bodyparts: list[str],
308319
pcutoff: float,
@@ -313,6 +324,14 @@ def draw_pose_and_write(
313324
if len(pose.shape) == 2:
314325
pose = pose[None]
315326

327+
if resize is not None and resize != 1.0:
328+
# Resize the frame
329+
frame = cv2.resize(frame, None, fx=resize, fy=resize, interpolation=cv2.INTER_LINEAR)
330+
331+
# Scale pose coordinates
332+
pose = pose.copy()
333+
pose[..., :2] *= resize
334+
316335
# Visualize keypoints
317336
for i in range(pose.shape[0]):
318337
for j in range(pose.shape[1]):

0 commit comments

Comments
 (0)