 import dlclive.pose_estimation_pytorch.models as models
 import dlclive.pose_estimation_pytorch.dynamic_cropping as dynamic_cropping
 from dlclive.core.runner import BaseRunner
+from dlclive.pose_estimation_pytorch.data.image import AutoPadToDivisor


 @dataclass
@@ -142,7 +143,8 @@ def __init__(
         self.cfg = None
         self.detector = None
         self.model = None
-        self.transform = None
+        self.detector_transform = None
+        self.pose_transform = None

         # Parse Dynamic Cropping parameters
         if isinstance(dynamic, dict):
@@ -172,13 +174,7 @@ def close(self) -> None:
     @torch.inference_mode()
     def get_pose(self, frame: np.ndarray) -> np.ndarray:
         c, h, w = frame.shape
-        frame = (
-            self.transform(torch.from_numpy(frame).permute(2, 0, 1))
-            .unsqueeze(0)
-            .to(self.device)
-        )
-        if self.precision == "FP16":
-            frame = frame.half()
+        tensor = torch.from_numpy(frame).permute(2, 0, 1)  # CHW, still on CPU

         offsets_and_scales = None
         if self.detector is not None:
@@ -187,18 +183,32 @@ def get_pose(self, frame: np.ndarray) -> np.ndarray:
             detections = self.top_down_config.skip_frames.get_detections()

             if detections is None:
-                detections = self.detector(frame)[0]
+                # Apply detector transform before inference
+                detector_input = self.detector_transform(tensor).unsqueeze(0).to(self.device)
+                if self.precision == "FP16":
+                    detector_input = detector_input.half()
+                detections = self.detector(detector_input)[0]

-            frame_batch, offsets_and_scales = self._prepare_top_down(frame, detections)
+            frame_batch, offsets_and_scales = self._prepare_top_down(tensor, detections)
             if len(frame_batch) == 0:
                 offsets_and_scales = [(0, 0), 1]
             else:
-                frame = frame_batch.to(self.device)
+                tensor = frame_batch  # still CHW, batched

         if self.dynamic is not None:
-            frame = self.dynamic.crop(frame)
+            tensor = self.dynamic.crop(tensor)
+
+        # Apply pose transform
+        model_input = self.pose_transform(tensor)
+        # Ensure 4D input: (N, C, H, W)
+        if model_input.dim() == 3:
+            model_input = model_input.unsqueeze(0)
+        # Send to device
+        model_input = model_input.to(self.device)
+        if self.precision == "FP16":
+            model_input = model_input.half()

-        outputs = self.model(frame)
+        outputs = self.model(model_input)
         batch_pose = self.model.get_predictions(outputs)["bodypart"]["poses"]

         if self.dynamic is not None:
@@ -264,15 +274,18 @@ def load_model(self) -> None:
             self.detector.to(self.device)
             self.detector.load_state_dict(raw_data["detector"])
             self.detector.eval()
-
             if self.precision == "FP16":
                 self.detector = self.detector.half()

             if self.top_down_config is None:
                 self.top_down_config = TopDownConfig()
-
             self.top_down_config.read_config(self.cfg)

+            detector_transforms = [v2.ToDtype(torch.float32, scale=True)]
+            if self.cfg["detector"]["data"]["inference"].get("normalize_images", False):
+                detector_transforms.append(v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]))
+            self.detector_transform = v2.Compose(detector_transforms)
+
         if isinstance(self.dynamic, dynamic_cropping.TopDownDynamicCropper):
             crop = self.cfg["data"]["inference"].get("top_down_crop", {})
             w, h = crop.get("width", 256), crop.get("height", 256)
@@ -287,12 +300,18 @@ def load_model(self) -> None:
                 "Top-down models must either use a detector or a TopDownDynamicCropper."
             )

-        self.transform = v2.Compose(
-            [
-                v2.ToDtype(torch.float32, scale=True),
-                v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
-            ]
-        )
+        pose_transforms = [v2.ToDtype(torch.float32, scale=True)]
+        auto_padding_cfg = self.cfg["data"]["inference"].get("auto_padding", None)
+        if auto_padding_cfg:
+            pose_transforms.append(
+                AutoPadToDivisor(
+                    pad_height_divisor=auto_padding_cfg.get("pad_height_divisor", 1),
+                    pad_width_divisor=auto_padding_cfg.get("pad_width_divisor", 1),
+                )
+            )
+        if self.cfg["data"]["inference"].get("normalize_images", False):
+            pose_transforms.append(v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]))
+        self.pose_transform = v2.Compose(pose_transforms)

     def read_config(self) -> dict:
         """Reads the configuration file"""
@@ -306,8 +325,17 @@ def _prepare_top_down(
         self, frame: torch.Tensor, detections: dict[str, torch.Tensor]
     ):
         """Prepares a frame for top-down pose estimation."""
+        # Accept unbatched frame (C, H, W) or batched frame (1, C, H, W)
+        if frame.dim() == 4:
+            if frame.size(0) != 1:
+                raise ValueError(f"Expected batch size 1, got {frame.size(0)}")
+            frame = frame[0]  # (C, H, W)
+        elif frame.dim() != 3:
+            raise ValueError(f"Expected frame of shape (C, H, W) or (1, C, H, W), got {frame.shape}")
+
         bboxes, scores = detections["boxes"], detections["scores"]
         bboxes = bboxes[scores >= self.top_down_config.bbox_cutoff]
+
         if len(bboxes) > 0 and self.top_down_config.max_detections is not None:
             bboxes = bboxes[: self.top_down_config.max_detections]

@@ -316,7 +344,7 @@ def _prepare_top_down(
         for bbox in bboxes:
             x1, y1, x2, y2 = bbox.tolist()
             cropped_frame, offset, scale = data.top_down_crop_torch(
-                frame[0],
+                frame,
                 (x1, y1, x2 - x1, y2 - y1),
                 output_size=self.top_down_config.crop_size,
                 margin=0,
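
For context, a minimal sketch (not part of the PR) of what the two config-driven pipelines resolve to at load time, assuming a model whose config enables normalize_images for both the detector and the pose model and requests auto_padding with divisors of 32; all config values below are assumed for illustration only.

import torch
from torchvision.transforms import v2

from dlclive.pose_estimation_pytorch.data.image import AutoPadToDivisor

IMAGENET_MEAN, IMAGENET_STD = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]

# Detector pipeline: scale uint8 [0, 255] -> float32 [0, 1], then normalize.
detector_transform = v2.Compose([
    v2.ToDtype(torch.float32, scale=True),
    v2.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

# Pose pipeline: same scaling, plus padding so H and W are multiples of 32.
pose_transform = v2.Compose([
    v2.ToDtype(torch.float32, scale=True),
    AutoPadToDivisor(pad_height_divisor=32, pad_width_divisor=32),
    v2.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

frame = torch.randint(0, 256, (3, 360, 640), dtype=torch.uint8)  # CHW, like get_pose builds
out = pose_transform(frame)
print(out.shape)  # expect H padded up to the next multiple of 32 (360 -> 384); W already divisible

Compared with the old single self.transform, normalization is now opt-in per network, and the pose-model transform runs after dynamic/top-down cropping rather than on the full frame before detection.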