Commit af0a907

Author: Beat Buesser
Address review comments
Signed-off-by: Beat Buesser <[email protected]>
1 parent b159a4b
1 file changed: +69 -34 lines changed

art/estimators/object_tracking/pytorch_goturn.py

Lines changed: 69 additions & 34 deletions
@@ -314,12 +314,13 @@ def loss_gradient(  # pylint: disable=W0613

         return grads

-    def _preprocess(self, img):
+    def _preprocess(self, img: "torch.Tensor") -> "torch.Tensor":
         """
-        Preprocess image before forward pass, this is the same
-        preprocessing used during training, please refer to collate function
-        in train.py for reference
-        @image: input image
+        Preprocess image before forward pass; this is the same preprocessing used during training. Please refer to
+        the collate function in train.py for reference.
+
+        :param img: Single frame of shape (nb_samples, height, width, nb_channels).
+        :return: Preprocessed frame.
         """
         import torch  # lgtm [py/repeated-import]
         from torch.nn.functional import interpolate
@@ -336,52 +337,73 @@ def _preprocess(self, img):
         img = (img - mean) / std
         return img

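To make the preprocessing concrete: GOTURN-style trackers resize each frame to a fixed network input and normalize it with per-channel constants. Below is a minimal standalone sketch of that pattern; the 227x227 input size, the mean/std values, and the exact order of the steps are illustrative assumptions rather than a transcription of pytorch_goturn.py.

    import torch
    from torch.nn.functional import interpolate

    def preprocess_frame(img: torch.Tensor) -> torch.Tensor:
        """Normalize a (H, W, C) float frame and resize it to the network input."""
        mean = torch.tensor([104.0, 117.0, 123.0])  # assumed per-channel mean
        std = torch.tensor([1.0, 1.0, 1.0])  # assumed per-channel std
        img = (img - mean) / std
        # interpolate expects (N, C, H, W); 227x227 is the classic GOTURN input size
        img = img.permute(2, 0, 1).unsqueeze(0)
        img = interpolate(img, size=(227, 227), mode="bilinear", align_corners=False)
        return img.squeeze(0)
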
-    def _track(self, curr_frame, prev_frame, rect):
-        """track current frame
-        @curr_frame: current frame
-        @prev_frame: prev frame
-        @rect: bounding box of previous frame
+    def _track_step(
+        self, curr_frame: "torch.Tensor", prev_frame: "torch.Tensor", rect: "torch.Tensor"
+    ) -> "torch.Tensor":
+        """
+        Track current frame.
+
+        :param curr_frame: Current frame.
+        :param prev_frame: Previous frame.
+        :param rect: Bounding box of the previous frame.
+        :return: Bounding box of the current frame.
         """
         import torch  # lgtm [py/repeated-import]

         prev_bbox = rect

         k_context_factor = 2

-        def compute_output_height_f(bbox_tight):
+        def compute_output_height_f(bbox_tight: "torch.Tensor") -> float:
             """
-            Height of search/target region.
+            Compute height of search/target region.
+
+            :param bbox_tight: Coordinates of bounding box [x1, y1, x2, y2].
+            :return: Output height.
             """
             bbox_height = bbox_tight[3] - bbox_tight[1]
             output_height = k_context_factor * bbox_height

             return max(1.0, output_height)

-        def compute_output_width_f(bbox_tight):
+        def compute_output_width_f(bbox_tight: "torch.Tensor") -> float:
             """
-            Width of search/target region.
+            Compute width of search/target region.
+
+            :param bbox_tight: Coordinates of bounding box [x1, y1, x2, y2].
+            :return: Output width.
             """
             bbox_width = bbox_tight[2] - bbox_tight[0]
             output_width = k_context_factor * bbox_width

             return max(1.0, output_width)

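Since k_context_factor = 2, the search/target region is always twice the tight bounding box in each dimension, floored at one pixel. A standalone worked example of the two helpers above (output_size is a hypothetical wrapper, not part of the diff):

    import torch

    K_CONTEXT_FACTOR = 2

    def output_size(bbox_tight: torch.Tensor) -> tuple:
        """Return (width, height) of the context region, each floored at 1.0."""
        width = max(1.0, K_CONTEXT_FACTOR * float(bbox_tight[2] - bbox_tight[0]))
        height = max(1.0, K_CONTEXT_FACTOR * float(bbox_tight[3] - bbox_tight[1]))
        return width, height

    bbox = torch.tensor([10.0, 20.0, 60.0, 100.0])  # [x1, y1, x2, y2]
    print(output_size(bbox))  # (100.0, 160.0): twice the 50 x 80 tight box
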
-        def get_center_x_f(bbox_tight):
+        def get_center_x_f(bbox_tight: "torch.Tensor") -> "torch.Tensor":
             """
-            x-coordinate of the bounding box center
+            Compute x-coordinate of the bounding box center.
+
+            :param bbox_tight: Coordinates of bounding box [x1, y1, x2, y2].
+            :return: x-coordinate of the bounding box center.
             """
             return (bbox_tight[0] + bbox_tight[2]) / 2.0

-        def get_center_y_f(bbox_tight):
+        def get_center_y_f(bbox_tight: "torch.Tensor") -> "torch.Tensor":
             """
-            y-coordinate of the bounding box center
+            Compute y-coordinate of the bounding box center.
+
+            :param bbox_tight: Coordinates of bounding box [x1, y1, x2, y2].
+            :return: y-coordinate of the bounding box center.
             """
             return (bbox_tight[1] + bbox_tight[3]) / 2.0

-        def compute_crop_pad_image_location(bbox_tight, image):
+        def compute_crop_pad_image_location(
+            bbox_tight: "torch.Tensor", image: "torch.Tensor"
+        ) -> Tuple[float, float, float, float]:
             """
-            Get the valid image coordinates for the context region in target
-            or search region in full image
+            Get the valid image coordinates for the context region in the target or search region in the full
+            image.
+
+            :param bbox_tight: Coordinates of bounding box [x1, y1, x2, y2].
+            :param image: Frame to be cropped and padded.
+            :return: Valid image coordinates [x1, y1, x2, y2] of the crop region.
             """

             # Center of the bounding box
@@ -424,30 +446,37 @@ def compute_crop_pad_image_location(bbox_tight, image):
             # return objPadImageLocation
             return roi_left, roi_bottom, roi_left + roi_width, roi_bottom + roi_height

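The body of compute_crop_pad_image_location is unchanged and therefore elided from the diff; conceptually it clamps the doubled context region to the image borders before returning the ROI corners seen above. A standalone sketch of that clamping idea, with hypothetical names and simplified details that may differ from the actual implementation:

    def crop_region(bbox, img_height, img_width, context_factor=2.0):
        """Clamp the context region around an [x1, y1, x2, y2] box to image bounds."""
        center_x = (bbox[0] + bbox[2]) / 2.0
        center_y = (bbox[1] + bbox[3]) / 2.0
        out_w = max(1.0, context_factor * (bbox[2] - bbox[0]))
        out_h = max(1.0, context_factor * (bbox[3] - bbox[1]))
        roi_left = max(0.0, center_x - out_w / 2.0)
        roi_bottom = max(0.0, center_y - out_h / 2.0)
        roi_width = min(img_width - roi_left, max(1.0, out_w))
        roi_height = min(img_height - roi_bottom, max(1.0, out_h))
        return roi_left, roi_bottom, roi_left + roi_width, roi_bottom + roi_height
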
-        def edge_spacing_x_f(bbox_tight):
+        def edge_spacing_x_f(bbox_tight: "torch.Tensor") -> float:
             """
-            Edge spacing X to take care of if search/target pad region goes
-            out of bound
+            Compute edge spacing X to handle the case when the search/target pad region goes out of bounds.
+
+            :param bbox_tight: Coordinates of bounding box [x1, y1, x2, y2].
+            :return: Edge spacing X.
             """
             output_width = compute_output_width_f(bbox_tight)
             bbox_center_x = get_center_x_f(bbox_tight)

             return max(0.0, (output_width / 2) - bbox_center_x)

-        def edge_spacing_y_f(bbox_tight):
+        def edge_spacing_y_f(bbox_tight: "torch.Tensor") -> float:
             """
-            Edge spacing X to take care of if search/target pad region goes
-            out of bound
+            Compute edge spacing Y to handle the case when the search/target pad region goes out of bounds.
+
+            :param bbox_tight: Coordinates of bounding box [x1, y1, x2, y2].
+            :return: Edge spacing Y.
             """
             output_height = compute_output_height_f(bbox_tight)
             bbox_center_y = get_center_y_f(bbox_tight)

             return max(0.0, (output_height / 2) - bbox_center_y)

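The edge-spacing helpers measure how far the context region would stick out past the top-left image border; that overshoot is later used as a padding offset when the crop is placed on the output canvas. A quick worked example (standalone sketch mirroring edge_spacing_x_f):

    def edge_spacing_x(bbox, context_factor=2.0):
        """How far the context region extends past x = 0 (0.0 if it does not)."""
        output_width = max(1.0, context_factor * (bbox[2] - bbox[0]))
        center_x = (bbox[0] + bbox[2]) / 2.0
        return max(0.0, output_width / 2.0 - center_x)

    # A box hugging the left edge: [x1, y1, x2, y2] = [5, 10, 45, 90].
    # center_x = 25 and output_width = 80, so the region would start at x = -15.
    print(edge_spacing_x([5.0, 10.0, 45.0, 90.0]))  # 15.0
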
-        def crop_pad_image(bbox_tight, image):
+        def crop_pad_image(
+            bbox_tight: "torch.Tensor", image: "torch.Tensor"
+        ) -> Tuple["torch.Tensor", float, float, float]:
             """
-            Around the bounding box, we define a extra context factor of 2,
-            which we will crop from the original image
+            Around the bounding box, we define an extra context factor of 2, which we will crop from the original
+            image.
+
+            :param bbox_tight: Coordinates of bounding box [x1, y1, x2, y2].
+            :param image: Frame to be cropped and padded.
+            :return: Cropped and padded image.
             """
             import math
             import torch  # lgtm [py/repeated-import]
@@ -525,8 +554,14 @@ def crop_pad_image(bbox_tight, image):

         return pred_bb

-    def track(self, x, y_init):
-        """Track"""
+    def _track(self, x: "torch.Tensor", y_init: "torch.Tensor") -> "torch.Tensor":
+        """
+        Track object across frames.
+
+        :param x: A single video of shape (nb_frames, height, width, nb_channels).
+        :param y_init: Initial bounding box around the object on the first frame of `x`.
+        :return: Predicted bounding box coordinates for all frames of shape (nb_frames, 4) in format [x1, y1, x2, y2].
+        """
         import torch  # lgtm [py/repeated-import]

         num_frames = x.shape[0]
@@ -536,7 +571,7 @@ def track(self, x, y_init):

         for i in range(1, num_frames):
             curr = x[i]
-            bbox_0 = self._track(curr, prev, bbox_0)
+            bbox_0 = self._track_step(curr, prev, bbox_0)
             bbox = bbox_0
             prev = curr

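The rename from _track to _track_step makes the recurrence explicit: each predicted box seeds the search region for the next frame. Schematically, with track_step standing in for the bound self._track_step (a sketch, not the ART implementation):

    import torch

    def track_video(frames: torch.Tensor, init_box: torch.Tensor, track_step) -> torch.Tensor:
        """Chain single-frame predictions: frame i is tracked against frame i - 1."""
        boxes = [init_box]
        prev, box = frames[0], init_box
        for i in range(1, frames.shape[0]):
            curr = frames[i]
            box = track_step(curr, prev, box)  # previous box seeds the search region
            boxes.append(box)
            prev = curr
        return torch.stack(boxes)  # (nb_frames, 4)
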
@@ -586,7 +621,7 @@ def predict(self, x: np.ndarray, batch_size: int = 128, **kwargs) -> List[Dict[s
             # Apply preprocessing
             x_i, _ = self._apply_preprocessing(x_i, y=None, fit=False)

-            y_pred = self.track(x=x_i, y_init=y_init[i])
+            y_pred = self._track(x=x_i, y_init=y_init[i])

             prediction_dict = dict()
             prediction_dict["boxes"] = y_pred.detach().cpu().numpy()
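Downstream of this hunk, predict collects one dictionary per video with the per-frame track under the "boxes" key. A hedged usage sketch; tracker, videos, and init_boxes are hypothetical names for an already constructed PyTorchGoturn estimator and its inputs:

    import numpy as np

    # videos: array of shape (nb_samples, nb_frames, height, width, nb_channels);
    # init_boxes: first-frame bounding boxes, passed through to y_init.
    predictions = tracker.predict(videos, y_init=init_boxes)
    boxes = predictions[0]["boxes"]  # (nb_frames, 4) array in [x1, y1, x2, y2] format
    print(boxes.shape, boxes[0])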
