In progress...
import torch
import torch.nn.functional as F
import cv2
import numpy as np
from .bbox import decode # assume decode supports vectorized inputs
def detect(net, img, device):
    # Transpose from (H, W, C) to (C, H, W).
    img = img.transpose(2, 0, 1)
    # transpose() returns a non-contiguous view; np.ascontiguousarray gives
    # torch.from_numpy a contiguous buffer. Then add a batch dimension of 1.
    img = np.expand_dims(np.ascontiguousarray(img), 0)
img = torch.from_numpy(img).to(device, dtype=torch.float32)
return batch_detect(net, img, device)
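For context, a minimal single-image usage sketch. `build_net()` and 'face.jpg' are placeholders rather than part of this snippet, so substitute your own model construction and input:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
net = build_net()  # hypothetical model factory; load your own weights here
net.to(device).eval()
img = cv2.imread('face.jpg')                # OpenCV loads BGR, shape (H, W, 3)
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # batch_detect assumes RGB input and flips to BGR itself
detections = detect(net, img, device)       # list with one (N, 5) array for the single image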
def batch_detect(net, img_batch, device):
"""
Inputs:
- img_batch: a torch.Tensor of shape (Batch size, Channels, Height, Width)
"""
# It is better to set cudnn.benchmark globally (outside the function)
# rather than on every call (if using CUDA).
    if 'cuda' in str(device):  # works whether device is a str or a torch.device
torch.backends.cudnn.benchmark = True
# Make sure img_batch is on the correct device and in float32.
img_batch = img_batch.to(device, dtype=torch.float32)
# Convert RGB (assumed input) to BGR by flipping the channel dimension.
# (Could also use explicit channel indexing like img_batch = img_batch[:, [2,1,0],:,:])
img_batch = img_batch.flip(-3)
# Subtract the mean
mean = torch.tensor([104.0, 117.0, 123.0], device=device).view(1, 3, 1, 1)
img_batch = img_batch - mean
with torch.no_grad():
olist = net(img_batch)
# Apply softmax on all classification outputs. Assuming that every even-index output
# is a classification output:
olist = [F.softmax(o, dim=1) if idx % 2 == 0 else o for idx, o in enumerate(olist)]
# Transfer outputs to the CPU and convert to numpy.
olist = [o.cpu().numpy() for o in olist]
bboxlists = get_predictions(olist, img_batch.size(0))
return bboxlists
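Batched inference is then just a matter of stacking preprocessed frames; note that all frames must share the same height and width to be stacked. A sketch, where `paths` is a hypothetical list of image files:
frames = [cv2.cvtColor(cv2.imread(p), cv2.COLOR_BGR2RGB) for p in paths]
batch = np.stack([f.transpose(2, 0, 1) for f in frames]).astype(np.float32)
batch = torch.from_numpy(batch).to(device)
all_dets = batch_detect(net, batch, device)  # list of (N_i, 5) arrays, one per frame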
def get_predictions(olist, batch_size):
"""
Vectorized version that obtains candidate detections from the network outputs.
It groups detections per batch sample.
    Returns a list of arrays, one per image in the batch, where each array has
    shape (N, 5): (x_min, y_min, x_max, y_max, score).
"""
# Create a list to hold detections for every image
detections_by_image = [[] for _ in range(batch_size)]
# Variances used in decoding
variances = [0.1, 0.2]
num_scales = len(olist) // 2
for i in range(num_scales):
# Get classification and regression results for this scale.
ocls = olist[i * 2] # shape: (batch, num_classes, H, W)
oreg = olist[i * 2 + 1] # shape: (batch, 4, H, W)
# Define the stride (note that 2**(i+2) gives 4,8,16,32,...)
stride = 2 ** (i + 2)
# Use vectorized thresholding: obtain all positions (across the batch) with score > 0.05
# Note: np.where returns a tuple (batch_inds, h_inds, w_inds)
batch_inds, h_inds, w_inds = np.where(ocls[:, 1, :, :] > 0.05)
if batch_inds.size == 0:
continue
# Compute the center coordinates based on stride.
axc = stride / 2 + w_inds * stride
ayc = stride / 2 + h_inds * stride
# Each candidate uses the same prior box dimensions at this scale.
priors = np.vstack((
axc,
ayc,
np.full_like(axc, stride * 4),
np.full_like(ayc, stride * 4)
)).T # shape: (N, 4)
# Gather the scores (expand dims for concatenation later)
scores = ocls[batch_inds, 1, h_inds, w_inds][:, None] # shape: (N, 1)
# Gather regression outputs for the same positions.
# Here, indexing is done on every detection: from oreg (batch, 4, H, W)
locs = oreg[batch_inds, :, h_inds, w_inds] # shape: (N, 4)
# Decode the location predictions using the priors and provided variances.
# (Assuming that decode is implemented to work with vectorized inputs.)
boxes = decode(locs, priors, variances) # expected shape: (N, 4)
# Concatenate the boxes with their scores.
detections = np.concatenate((boxes, scores), axis=1) # shape: (N, 5)
# Group detections by the image index
for b, det in zip(batch_inds, detections):
detections_by_image[b].append(det)
# For every image in the batch, convert list of detections into a numpy array.
for i in range(batch_size):
if detections_by_image[i]:
detections_by_image[i] = np.stack(detections_by_image[i], axis=0)
else:
# If no candidates, return an empty array with shape (0, 5)
detections_by_image[i] = np.empty((0, 5))
return detections_by_image
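The code above assumes `decode` accepts NumPy inputs. If the existing `.bbox.decode` turns out to be torch-only, a NumPy port of the standard SSD-style decoding is straightforward. This is a sketch under the usual (cx, cy, w, h) prior parameterization; verify it matches the actual implementation before swapping it in:
def decode_np(locs, priors, variances):
    # locs: (N, 4) regression offsets; priors: (N, 4) as (cx, cy, w, h).
    # Returns (N, 4) boxes as (x_min, y_min, x_max, y_max).
    centers = priors[:, :2] + locs[:, :2] * variances[0] * priors[:, 2:]
    sizes = priors[:, 2:] * np.exp(locs[:, 2:] * variances[1])
    return np.concatenate((centers - sizes / 2, centers + sizes / 2), axis=1)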
def flip_detect(net, img, device):
# Flips the image horizontally.
img = cv2.flip(img, 1)
b = detect(net, img, device)
bboxlist = np.zeros(b[0].shape) if b[0].size > 0 else np.empty((0, 5))
if bboxlist.size > 0:
        # Map boxes detected on the flipped image back to the original image's coordinates.
bboxlist[:, 0] = img.shape[1] - b[0][:, 2] # x_min
bboxlist[:, 1] = b[0][:, 1] # y_min remains the same
bboxlist[:, 2] = img.shape[1] - b[0][:, 0] # x_max
bboxlist[:, 3] = b[0][:, 3] # y_max remains the same
bboxlist[:, 4] = b[0][:, 4] # score
return bboxlist
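flip_detect is typically paired with a plain detect pass as test-time augmentation, merging both sets before non-maximum suppression. A sketch follows; `nms(dets, thresh)` returning kept indices is an assumption here, so swap in whatever NMS utility this codebase actually exposes:
dets = np.concatenate((detect(net, img, device)[0],
                       flip_detect(net, img, device)), axis=0)
keep = nms(dets, 0.3)  # hypothetical helper; returns indices of boxes to keep
dets = dets[keep]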
def pts_to_bb(pts):
# Converts a set of points to a bounding box
min_xy = np.min(pts, axis=0)
max_xy = np.max(pts, axis=0)
    return np.array([min_xy[0], min_xy[1], max_xy[0], max_xy[1]])
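A quick sanity check with three landmark points:
pts = np.array([[30.0, 40.0], [70.0, 35.0], [50.0, 80.0]])
print(pts_to_bb(pts))  # -> [30. 35. 70. 80.]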