@@ -92,14 +92,15 @@ def benchmark(
9292 model_path : str ,
9393 model_type : str ,
9494 device : str ,
95- single_animal : bool ,
95+ resize : float | None = None ,
96+ pixels : int | None = None ,
97+ single_animal : bool = True ,
9698 save_dir = None ,
9799 n_frames = 1000 ,
98100 precision : str = "FP32" ,
99101 display = True ,
100102 pcutoff = 0.5 ,
101103 display_radius = 3 ,
102- resize = None ,
103104 cropping = None , # Adding cropping to the function parameters
104105 dynamic = (False , 0.5 , 10 ),
105106 save_poses = False ,
@@ -121,7 +122,12 @@ def benchmark(
121122 Type of the model (e.g., 'onnx').
122123 device : str
123124 Device to run the model on ('cpu' or 'cuda').
124- single_animal: bool
125+ resize : float or None, optional
126+ Scale factor applied to video frames, e.g. if resize = 0.5, the video will be processed at half the original size. If None, no resizing is applied.
127+ pixels : int, optional
128+ Downsize the image to approximately this total number of pixels, maintaining aspect ratio.
129+ Can only use one of resize or pixels. If both are provided, will use pixels.
130+ single_animal: bool, optional, default=True
125131 Whether the video contains only one animal (True) or multiple animals (False).
126132 save_dir : str, optional
127133 Directory to save output data and labeled video.
@@ -136,8 +142,6 @@ def benchmark(
136142 Probability cutoff below which keypoints are not visualized.
137143 display_radius : int, optional, default=5
138144 Radius of circles drawn for keypoints on video frames.
139- resize : tuple of int (width, height) or None, optional
140- Resize dimensions for video frames. e.g. if resize = 0.5, the video will be processed in half the original size. If None, no resizing is applied.
141145 cropping : list of int or None, optional
142146 Cropping parameters [x1, x2, y1, y2] in pixels. If None, no cropping is applied.
143147 dynamic : tuple, optional, default=(False, 0.5, 10) (True/False, p cutoff, margin)
@@ -160,6 +164,17 @@ def benchmark(
160164 - poses (list of dict): List of pose data for each frame.
161165 - times (list of float): List of inference times for each frame.
162166 """
167+ # Load video
168+ cap = cv2 .VideoCapture (video_path )
169+ if not cap .isOpened ():
170+ print (f"Error: Could not open video file { video_path } " )
171+ return
172+ im_size = (int (cap .get (cv2 .CAP_PROP_FRAME_WIDTH )), int (cap .get (cv2 .CAP_PROP_FRAME_HEIGHT )))
173+
174+ if pixels is not None :
175+ resize = np .sqrt (pixels / (im_size [0 ] * im_size [1 ]))
176+ if resize is not None :
177+ im_size = (int (im_size [0 ] * resize ), int (im_size [1 ] * resize ))
163178
164179 # Create the DLCLive object with cropping
165180 dlc_live = DLCLive (
@@ -185,12 +200,6 @@ def benchmark(
185200 # Get the current date and time as a string
186201 timestamp = time .strftime ("%Y%m%d_%H%M%S" )
187202
188- # Load video
189- cap = cv2 .VideoCapture (video_path )
190- if not cap .isOpened ():
191- print (f"Error: Could not open video file { video_path } " )
192- return
193-
194203 # Retrieve bodypart names and number of keypoints
195204 bodyparts = dlc_live .read_config ()["metadata" ]["bodyparts" ]
196205
@@ -202,7 +211,7 @@ def benchmark(
202211 num_keypoints = len (bodyparts ),
203212 cmap = cmap ,
204213 fps = cap .get (cv2 .CAP_PROP_FPS ),
205- frame_size = ( int ( cap . get ( cv2 . CAP_PROP_FRAME_WIDTH )), int ( cap . get ( cv2 . CAP_PROP_FRAME_HEIGHT ))) ,
214+ frame_size = im_size ,
206215 )
207216
208217 # Start empty dict to save poses to for each frame
@@ -241,6 +250,7 @@ def benchmark(
241250 draw_pose_and_write (
242251 frame = frame ,
243252 pose = pose ,
253+ resize = resize ,
244254 colors = colors ,
245255 bodyparts = bodyparts ,
246256 pcutoff = pcutoff ,
@@ -303,6 +313,7 @@ def setup_video_writer(
303313def draw_pose_and_write (
304314 frame : np .ndarray ,
305315 pose : np .ndarray ,
316+ resize : float ,
306317 colors : list [tuple [int , int , int ]],
307318 bodyparts : list [str ],
308319 pcutoff : float ,
@@ -313,6 +324,14 @@ def draw_pose_and_write(
313324 if len (pose .shape ) == 2 :
314325 pose = pose [None ]
315326
327+ if resize is not None and resize != 1.0 :
328+ # Resize the frame
329+ frame = cv2 .resize (frame , None , fx = resize , fy = resize , interpolation = cv2 .INTER_LINEAR )
330+
331+ # Scale pose coordinates
332+ pose = pose .copy ()
333+ pose [..., :2 ] *= resize
334+
316335 # Visualize keypoints
317336 for i in range (pose .shape [0 ]):
318337 for j in range (pose .shape [1 ]):
0 commit comments