@@ -92,14 +92,15 @@ def benchmark(
92
92
model_path : str ,
93
93
model_type : str ,
94
94
device : str ,
95
- single_animal : bool ,
95
+ resize : float | None = None ,
96
+ pixels : int | None = None ,
97
+ single_animal : bool = True ,
96
98
save_dir = None ,
97
99
n_frames = 1000 ,
98
100
precision : str = "FP32" ,
99
101
display = True ,
100
102
pcutoff = 0.5 ,
101
103
display_radius = 3 ,
102
- resize = None ,
103
104
cropping = None , # Adding cropping to the function parameters
104
105
dynamic = (False , 0.5 , 10 ),
105
106
save_poses = False ,
@@ -121,7 +122,12 @@ def benchmark(
121
122
Type of the model (e.g., 'onnx').
122
123
device : str
123
124
Device to run the model on ('cpu' or 'cuda').
124
- single_animal: bool
125
+ resize : float or None, optional
126
+ Resize dimensions for video frames. e.g. if resize = 0.5, the video will be processed in half the original size. If None, no resizing is applied.
127
+ pixels : int, optional
128
+ downsize image to this number of pixels, maintaining aspect ratio.
129
+ Can only use one of resize or pixels. If both are provided, will use pixels.
130
+ single_animal: bool, optional, default=True
125
131
Whether the video contains only one animal (True) or multiple animals (False).
126
132
save_dir : str, optional
127
133
Directory to save output data and labeled video.
@@ -136,8 +142,6 @@ def benchmark(
136
142
Probability cutoff below which keypoints are not visualized.
137
143
display_radius : int, optional, default=5
138
144
Radius of circles drawn for keypoints on video frames.
139
- resize : tuple of int (width, height) or None, optional
140
- Resize dimensions for video frames. e.g. if resize = 0.5, the video will be processed in half the original size. If None, no resizing is applied.
141
145
cropping : list of int or None, optional
142
146
Cropping parameters [x1, x2, y1, y2] in pixels. If None, no cropping is applied.
143
147
dynamic : tuple, optional, default=(False, 0.5, 10) (True/false), p cutoff, margin)
@@ -160,6 +164,17 @@ def benchmark(
160
164
- poses (list of dict): List of pose data for each frame.
161
165
- times (list of float): List of inference times for each frame.
162
166
"""
167
+ # Load video
168
+ cap = cv2 .VideoCapture (video_path )
169
+ if not cap .isOpened ():
170
+ print (f"Error: Could not open video file { video_path } " )
171
+ return
172
+ im_size = (int (cap .get (cv2 .CAP_PROP_FRAME_WIDTH )), int (cap .get (cv2 .CAP_PROP_FRAME_HEIGHT )))
173
+
174
+ if pixels is not None :
175
+ resize = np .sqrt (pixels / (im_size [0 ] * im_size [1 ]))
176
+ if resize is not None :
177
+ im_size = (int (im_size [0 ] * resize ), int (im_size [1 ] * resize ))
163
178
164
179
# Create the DLCLive object with cropping
165
180
dlc_live = DLCLive (
@@ -185,12 +200,6 @@ def benchmark(
185
200
# Get the current date and time as a string
186
201
timestamp = time .strftime ("%Y%m%d_%H%M%S" )
187
202
188
- # Load video
189
- cap = cv2 .VideoCapture (video_path )
190
- if not cap .isOpened ():
191
- print (f"Error: Could not open video file { video_path } " )
192
- return
193
-
194
203
# Retrieve bodypart names and number of keypoints
195
204
bodyparts = dlc_live .read_config ()["metadata" ]["bodyparts" ]
196
205
@@ -202,7 +211,7 @@ def benchmark(
202
211
num_keypoints = len (bodyparts ),
203
212
cmap = cmap ,
204
213
fps = cap .get (cv2 .CAP_PROP_FPS ),
205
- frame_size = ( int ( cap . get ( cv2 . CAP_PROP_FRAME_WIDTH )), int ( cap . get ( cv2 . CAP_PROP_FRAME_HEIGHT ))) ,
214
+ frame_size = im_size ,
206
215
)
207
216
208
217
# Start empty dict to save poses to for each frame
@@ -241,6 +250,7 @@ def benchmark(
241
250
draw_pose_and_write (
242
251
frame = frame ,
243
252
pose = pose ,
253
+ resize = resize ,
244
254
colors = colors ,
245
255
bodyparts = bodyparts ,
246
256
pcutoff = pcutoff ,
@@ -303,6 +313,7 @@ def setup_video_writer(
303
313
def draw_pose_and_write (
304
314
frame : np .ndarray ,
305
315
pose : np .ndarray ,
316
+ resize : float ,
306
317
colors : list [tuple [int , int , int ]],
307
318
bodyparts : list [str ],
308
319
pcutoff : float ,
@@ -313,6 +324,14 @@ def draw_pose_and_write(
313
324
if len (pose .shape ) == 2 :
314
325
pose = pose [None ]
315
326
327
+ if resize is not None and resize != 1.0 :
328
+ # Resize the frame
329
+ frame = cv2 .resize (frame , None , fx = resize , fy = resize , interpolation = cv2 .INTER_LINEAR )
330
+
331
+ # Scale pose coordinates
332
+ pose = pose .copy ()
333
+ pose [..., :2 ] *= resize
334
+
316
335
# Visualize keypoints
317
336
for i in range (pose .shape [0 ]):
318
337
for j in range (pose .shape [1 ]):
0 commit comments