
Commit 014b96b

Add detector_transform and pose_transform
HRNet-32 requires input images whose height and width are multiples of 32; this preprocessing step was missing. I replaced the single PyTorchRunner.transform attribute with a detector transform and a pose transform, each built from the model config. I added an AutoPadToDivisor transform based on torchvision.transforms.functional.pad().
1 parent ad62af4 commit 014b96b
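
For reference, the round-up-to-divisor arithmetic behind the new preprocessing, as a minimal standalone sketch (the 719×1283 frame size and the divisor of 32 are example values, not part of the commit):

    # Round spatial sizes up to the nearest multiple of a divisor, as AutoPadToDivisor does below.
    divisor = 32
    h, w = 719, 1283                                       # hypothetical frame size
    target_h = ((h + divisor - 1) // divisor) * divisor    # 736
    target_w = ((w + divisor - 1) // divisor) * divisor    # 1312
    print(target_h - h, target_w - w)                      # bottom/right padding: 17, 29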

File tree

2 files changed: +77 −22


dlclive/pose_estimation_pytorch/data/image.py

Lines changed: 27 additions & 0 deletions
@@ -138,3 +138,30 @@ def top_down_crop_torch(
     offset = x1, y1
     crop = F.resized_crop(image, y1, x1, h, w, [out_h, out_w])
     return crop, offset, scale
+
+
+class AutoPadToDivisor(torch.nn.Module):
+    def __init__(self, pad_height_divisor: int = 1, pad_width_divisor: int = 1):
+        super().__init__()
+        self.pad_height_divisor = pad_height_divisor
+        self.pad_width_divisor = pad_width_divisor
+
+    def forward(self, img: torch.Tensor) -> torch.Tensor:
+        # Accepts either (C, H, W) or (N, C, H, W)
+        if img.ndim == 3:
+            img = img.unsqueeze(0)  # add batch dim
+
+        assert img.ndim == 4, f"Expected 4D tensor, got shape {img.shape}"
+        _, _, h, w = img.shape
+
+        target_h = ((h + self.pad_height_divisor - 1) // self.pad_height_divisor) * self.pad_height_divisor
+        target_w = ((w + self.pad_width_divisor - 1) // self.pad_width_divisor) * self.pad_width_divisor
+
+        pad_h = target_h - h
+        pad_w = target_w - w
+
+        # Pad (left, top, right, bottom)
+        padding = (0, 0, pad_w, pad_h)
+
+        # Warning: this method returns a batched image, regardless of whether its input was batched
+        return F.pad(img, padding, padding_mode="reflect")
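
A minimal usage sketch for the new class (not part of the commit; it assumes the import path added to runner.py below, and the frame size is hypothetical):

    import torch

    from dlclive.pose_estimation_pytorch.data.image import AutoPadToDivisor

    pad = AutoPadToDivisor(pad_height_divisor=32, pad_width_divisor=32)
    frame = torch.rand(3, 483, 641)   # CHW frame whose sides are not multiples of 32
    padded = pad(frame)               # note: the output is always batched
    print(padded.shape)               # torch.Size([1, 3, 512, 672])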

dlclive/pose_estimation_pytorch/runner.py

Lines changed: 50 additions & 22 deletions
@@ -22,6 +22,7 @@
 import dlclive.pose_estimation_pytorch.models as models
 import dlclive.pose_estimation_pytorch.dynamic_cropping as dynamic_cropping
 from dlclive.core.runner import BaseRunner
+from dlclive.pose_estimation_pytorch.data.image import AutoPadToDivisor


 @dataclass
@@ -142,7 +143,8 @@ def __init__(
         self.cfg = None
         self.detector = None
         self.model = None
-        self.transform = None
+        self.detector_transform = None
+        self.pose_transform = None

         # Parse Dynamic Cropping parameters
         if isinstance(dynamic, dict):
@@ -172,13 +174,7 @@ def close(self) -> None:
     @torch.inference_mode()
     def get_pose(self, frame: np.ndarray) -> np.ndarray:
         c, h, w = frame.shape
-        frame = (
-            self.transform(torch.from_numpy(frame).permute(2, 0, 1))
-            .unsqueeze(0)
-            .to(self.device)
-        )
-        if self.precision == "FP16":
-            frame = frame.half()
+        tensor = torch.from_numpy(frame).permute(2, 0, 1)  # CHW, still on CPU

         offsets_and_scales = None
         if self.detector is not None:
@@ -187,18 +183,32 @@ def get_pose(self, frame: np.ndarray) -> np.ndarray:
                 detections = self.top_down_config.skip_frames.get_detections()

             if detections is None:
-                detections = self.detector(frame)[0]
+                # Apply detector transform before inference
+                detector_input = self.detector_transform(tensor).unsqueeze(0).to(self.device)
+                if self.precision == "FP16":
+                    detector_input = detector_input.half()
+                detections = self.detector(detector_input)[0]

-            frame_batch, offsets_and_scales = self._prepare_top_down(frame, detections)
+            frame_batch, offsets_and_scales = self._prepare_top_down(tensor, detections)
             if len(frame_batch) == 0:
                 offsets_and_scales = [(0, 0), 1]
             else:
-                frame = frame_batch.to(self.device)
+                tensor = frame_batch  # still CHW, batched

         if self.dynamic is not None:
-            frame = self.dynamic.crop(frame)
+            tensor = self.dynamic.crop(tensor)
+
+        # Apply pose transform
+        model_input = self.pose_transform(tensor)
+        # Ensure 4D input: (N, C, H, W)
+        if model_input.dim() == 3:
+            model_input = model_input.unsqueeze(0)
+        # Send to device
+        model_input = model_input.to(self.device)
+        if self.precision == "FP16":
+            model_input = model_input.half()

-        outputs = self.model(frame)
+        outputs = self.model(model_input)
         batch_pose = self.model.get_predictions(outputs)["bodypart"]["poses"]

         if self.dynamic is not None:
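
To make the reworked get_pose flow easier to follow, here is a standalone sketch of the order of operations (the transforms and the frame below are dummy stand-ins; the real detector_transform and pose_transform are built in load_model, see the hunks that follow):

    import numpy as np
    import torch
    from torchvision.transforms import v2

    # Stand-ins for the transforms built in load_model.
    detector_transform = v2.Compose([v2.ToDtype(torch.float32, scale=True)])
    pose_transform = v2.Compose([v2.ToDtype(torch.float32, scale=True)])

    frame = np.random.randint(0, 255, (480, 640, 3), dtype=np.uint8)  # hypothetical HWC frame
    tensor = torch.from_numpy(frame).permute(2, 0, 1)                 # CHW, uint8, still on CPU

    # Detector branch: transform, add a batch dim, then run detection.
    detector_input = detector_transform(tensor).unsqueeze(0)

    # Pose branch (cropping omitted): transform, ensure (N, C, H, W), then run the pose model.
    model_input = pose_transform(tensor)
    if model_input.dim() == 3:
        model_input = model_input.unsqueeze(0)
    print(detector_input.shape, model_input.shape)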
@@ -264,15 +274,18 @@ def load_model(self) -> None:
             self.detector.to(self.device)
             self.detector.load_state_dict(raw_data["detector"])
             self.detector.eval()
-
             if self.precision == "FP16":
                 self.detector = self.detector.half()

             if self.top_down_config is None:
                 self.top_down_config = TopDownConfig()
-
             self.top_down_config.read_config(self.cfg)

+            detector_transforms = [v2.ToDtype(torch.float32, scale=True)]
+            if self.cfg["detector"]["data"]["inference"].get("normalize_images", False):
+                detector_transforms.append(v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]))
+            self.detector_transform = v2.Compose(detector_transforms)
+
         if isinstance(self.dynamic, dynamic_cropping.TopDownDynamicCropper):
             crop = self.cfg["data"]["inference"].get("top_down_crop", {})
             w, h = crop.get("width", 256), crop.get("height", 256)
@@ -287,12 +300,18 @@ def load_model(self) -> None:
                 "Top-down models must either use a detector or a TopDownDynamicCropper."
             )

-        self.transform = v2.Compose(
-            [
-                v2.ToDtype(torch.float32, scale=True),
-                v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
-            ]
-        )
+        pose_transforms = [v2.ToDtype(torch.float32, scale=True)]
+        auto_padding_cfg = self.cfg["data"]["inference"].get("auto_padding", None)
+        if auto_padding_cfg:
+            pose_transforms.append(
+                AutoPadToDivisor(
+                    pad_height_divisor=auto_padding_cfg.get("pad_height_divisor", 1),
+                    pad_width_divisor=auto_padding_cfg.get("pad_width_divisor", 1),
+                )
+            )
+        if self.cfg["data"]["inference"].get("normalize_images", False):
+            pose_transforms.append(v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]))
+        self.pose_transform = v2.Compose(pose_transforms)

     def read_config(self) -> dict:
         """Reads the configuration file"""
@@ -306,8 +325,17 @@ def _prepare_top_down(
         self, frame: torch.Tensor, detections: dict[str, torch.Tensor]
     ):
         """Prepares a frame for top-down pose estimation."""
+        # Accept unbatched frame (C, H, W) or batched frame (1, C, H, W)
+        if frame.dim() == 4:
+            if frame.size(0) != 1:
+                raise ValueError(f"Expected batch size 1, got {frame.size(0)}")
+            frame = frame[0]  # (C, H, W)
+        elif frame.dim() != 3:
+            raise ValueError(f"Expected frame of shape (C, H, W) or (1, C, H, W), got {frame.shape}")
+
         bboxes, scores = detections["boxes"], detections["scores"]
         bboxes = bboxes[scores >= self.top_down_config.bbox_cutoff]
+
         if len(bboxes) > 0 and self.top_down_config.max_detections is not None:
             bboxes = bboxes[: self.top_down_config.max_detections]

@@ -316,7 +344,7 @@ def _prepare_top_down(
         for bbox in bboxes:
             x1, y1, x2, y2 = bbox.tolist()
             cropped_frame, offset, scale = data.top_down_crop_torch(
-                frame[0],
+                frame,
                 (x1, y1, x2 - x1, y2 - y1),
                 output_size=self.top_down_config.crop_size,
                 margin=0,
