@@ -314,12 +314,13 @@ def loss_gradient(  # pylint: disable=W0613

         return grads

-    def _preprocess(self, img):
+    def _preprocess(self, img: "torch.Tensor") -> "torch.Tensor":
         """
-        Preprocess image before forward pass, this is the same
-        preprocessing used during training, please refer to collate function
-        in train.py for reference
-        @image: input image
+        Preprocess image before forward pass. This is the same preprocessing used during training; please refer to
+        the collate function in train.py for reference.
+
+        :param img: Single frame of shape (nb_samples, height, width, nb_channels).
+        :return: Preprocessed frame.
         """
         import torch  # lgtm [py/repeated-import]
         from torch.nn.functional import interpolate
@@ -336,52 +337,76 @@ def _preprocess(self, img):
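+        # Normalize with the training-set mean and std (matches the collate preprocessing in train.py).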
         img = (img - mean) / std
         return img

-    def _track(self, curr_frame, prev_frame, rect):
-        """track current frame
-        @curr_frame: current frame
-        @prev_frame: prev frame
-        @rect: bounding box of previous frame
+    def _track_step(
+        self, curr_frame: "torch.Tensor", prev_frame: "torch.Tensor", rect: "torch.Tensor"
+    ) -> "torch.Tensor":
+        """
+        Track current frame.
+
+        :param curr_frame: Current frame.
+        :param prev_frame: Previous frame.
+        :param rect: Bounding box of the previous frame.
+        :return: Predicted bounding box of the current frame.
         """
         import torch  # lgtm [py/repeated-import]

         prev_bbox = rect

         k_context_factor = 2
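+        # Context factor of 2: the search/target crop is twice the bounding box size in each dimension.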

-        def compute_output_height_f(bbox_tight):
+        def compute_output_height_f(bbox_tight: "torch.Tensor") -> float:
             """
-            Height of search/target region.
+            Compute height of search/target region.
+
+            :param bbox_tight: Coordinates of bounding box [x1, y1, x2, y2].
+            :return: Output height.
             """
             bbox_height = bbox_tight[3] - bbox_tight[1]
             output_height = k_context_factor * bbox_height

             return max(1.0, output_height)

-        def compute_output_width_f(bbox_tight):
+        def compute_output_width_f(bbox_tight: "torch.Tensor") -> float:
             """
-            Width of search/target region.
+            Compute width of search/target region.
+
+            :param bbox_tight: Coordinates of bounding box [x1, y1, x2, y2].
+            :return: Output width.
             """
             bbox_width = bbox_tight[2] - bbox_tight[0]
             output_width = k_context_factor * bbox_width

             return max(1.0, output_width)

-        def get_center_x_f(bbox_tight):
+        def get_center_x_f(bbox_tight: "torch.Tensor") -> "torch.Tensor":
             """
-            x-coordinate of the bounding box center
+            Compute x-coordinate of the bounding box center.
+
+            :param bbox_tight: Coordinates of bounding box [x1, y1, x2, y2].
+            :return: x-coordinate of the bounding box center.
             """
             return (bbox_tight[0] + bbox_tight[2]) / 2.0

-        def get_center_y_f(bbox_tight):
+        def get_center_y_f(bbox_tight: "torch.Tensor") -> "torch.Tensor":
             """
-            y-coordinate of the bounding box center
+            Compute y-coordinate of the bounding box center.
+
+            :param bbox_tight: Coordinates of bounding box [x1, y1, x2, y2].
+            :return: y-coordinate of the bounding box center.
             """
             return (bbox_tight[1] + bbox_tight[3]) / 2.0

-        def compute_crop_pad_image_location(bbox_tight, image):
+        def compute_crop_pad_image_location(
+            bbox_tight: "torch.Tensor", image: "torch.Tensor"
+        ) -> (float, float, float, float):
             """
-            Get the valid image coordinates for the context region in target
-            or search region in full image
+            Get the valid image coordinates for the context region in target or search region in full image.
+
+            :param bbox_tight: Coordinates of bounding box [x1, y1, x2, y2].
+            :param image: Frame to be cropped and padded.
+            :return: Valid image coordinates of the crop region [x1, y1, x2, y2].
             """

             # Center of the bounding box
@@ -424,30 +446,38 @@ def compute_crop_pad_image_location(bbox_tight, image):
             # return objPadImageLocation
             return roi_left, roi_bottom, roi_left + roi_width, roi_bottom + roi_height

-        def edge_spacing_x_f(bbox_tight):
+        def edge_spacing_x_f(bbox_tight: "torch.Tensor") -> float:
             """
-            Edge spacing X to take care of if search/target pad region goes
-            out of bound
+            Edge spacing X, to account for the search/target pad region going out of bounds.
+
+            :param bbox_tight: Coordinates of bounding box [x1, y1, x2, y2].
+            :return: Edge spacing X.
             """
             output_width = compute_output_width_f(bbox_tight)
             bbox_center_x = get_center_x_f(bbox_tight)

             return max(0.0, (output_width / 2) - bbox_center_x)

-        def edge_spacing_y_f(bbox_tight):
+        def edge_spacing_y_f(bbox_tight: "torch.Tensor") -> float:
             """
-            Edge spacing X to take care of if search/target pad region goes
-            out of bound
+            Edge spacing Y, to account for the search/target pad region going out of bounds.
+
+            :param bbox_tight: Coordinates of bounding box [x1, y1, x2, y2].
+            :return: Edge spacing Y.
             """
             output_height = compute_output_height_f(bbox_tight)
             bbox_center_y = get_center_y_f(bbox_tight)

             return max(0.0, (output_height / 2) - bbox_center_y)

-        def crop_pad_image(bbox_tight, image):
+        def crop_pad_image(bbox_tight: "torch.Tensor", image: "torch.Tensor") -> ("torch.Tensor", float, float, float):
             """
-            Around the bounding box, we define a extra context factor of 2,
-            which we will crop from the original image
+            Around the bounding box, we define an extra context factor of 2, which we will crop from the original image.
+
+            :param bbox_tight: Coordinates of bounding box [x1, y1, x2, y2].
+            :param image: Frame to be cropped and padded.
+            :return: Cropped and padded image.
             """
             import math
             import torch  # lgtm [py/repeated-import]
@@ -525,8 +554,14 @@ def crop_pad_image(bbox_tight, image):

         return pred_bb

-    def track(self, x, y_init):
-        """Track"""
+    def _track(self, x: "torch.Tensor", y_init: "torch.Tensor") -> "torch.Tensor":
+        """
+        Track object across frames.
+
+        :param x: A single video of shape (nb_frames, nb_height, nb_width, nb_channels).
+        :param y_init: Initial bounding box around the object on the first frame of `x`.
+        :return: Predicted bounding box coordinates for all frames of shape (nb_frames, 4) in format [x1, y1, x2, y2].
+        """
         import torch  # lgtm [py/repeated-import]

         num_frames = x.shape[0]
@@ -536,7 +571,7 @@ def track(self, x, y_init):

         for i in range(1, num_frames):
             curr = x[i]
-            bbox_0 = self._track(curr, prev, bbox_0)
+            bbox_0 = self._track_step(curr, prev, bbox_0)
             bbox = bbox_0
             prev = curr

@@ -586,7 +621,7 @@ def predict(self, x: np.ndarray, batch_size: int = 128, **kwargs) -> List[Dict[str, np.ndarray]]:
             # Apply preprocessing
             x_i, _ = self._apply_preprocessing(x_i, y=None, fit=False)

-            y_pred = self.track(x=x_i, y_init=y_init[i])
+            y_pred = self._track(x=x_i, y_init=y_init[i])

            prediction_dict = dict()
            prediction_dict["boxes"] = y_pred.detach().cpu().numpy()
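For context, here is a minimal usage sketch of the `predict` API shown in the last hunk. It is only a sketch: the estimator class and its constructor are not part of this diff, so `tracker` stands in for an instance of the surrounding PyTorch GOTURN tracking estimator, and passing `y_init` through `**kwargs` is inferred from the loop body above.

    import numpy as np

    # A batch with one video: 10 frames of 224x224 RGB.
    video = np.random.rand(1, 10, 224, 224, 3).astype(np.float32)

    # Initial bounding box [x1, y1, x2, y2] around the object in frame 0, one row per video.
    y_init = np.array([[50.0, 60.0, 120.0, 140.0]], dtype=np.float32)

    # `tracker` is assumed to be an already-built estimator exposing the predict() diffed above.
    predictions = tracker.predict(x=video, y_init=y_init)
    boxes = predictions[0]["boxes"]  # shape (nb_frames, 4), one [x1, y1, x2, y2] box per frame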