11# Copyright (c) Facebook, Inc. and its affiliates.
22from __future__ import division
3- from typing import Any , List , Tuple
3+ from typing import Any , Dict , List , Optional , Tuple
44import torch
55from torch import device
66from torch .nn import functional as F
@@ -57,7 +57,10 @@ def device(self) -> device:
5757
5858 @staticmethod
5959 def from_tensors (
60- tensors : List [torch .Tensor ], size_divisibility : int = 0 , pad_value : float = 0.0
60+ tensors : List [torch .Tensor ],
61+ size_divisibility : int = 0 ,
62+ pad_value : float = 0.0 ,
63+ padding_constraints : Optional [Dict [str , int ]] = None ,
6164 ) -> "ImageList" :
6265 """
6366 Args:
@@ -67,7 +70,11 @@ def from_tensors(
6770 size_divisibility (int): If `size_divisibility > 0`, add padding to ensure
6871 the common height and width is divisible by `size_divisibility`.
6972 This depends on the model and many models need a divisibility of 32.
70- pad_value (float): value to pad
73+ pad_value (float): value to pad.
74+ padding_constraints (optional[Dict]): If given, it would follow the format as
75+ {"size_divisibility": int, "square": int}, where `size_divisibility` will overwrite
76+ the above one if presented and `square` indicates if require inputs to be padded to
77+ square.
7178
7279 Returns:
7380 an `ImageList`.
@@ -82,6 +89,12 @@ def from_tensors(
8289 image_sizes_tensor = [shapes_to_tensor (x ) for x in image_sizes ]
8390 max_size = torch .stack (image_sizes_tensor ).max (0 ).values
8491
92+ if padding_constraints is not None :
93+ if padding_constraints .get ("square" , 0 ) > 0 :
94+ # pad to square.
95+ max_size [0 ] = max_size [1 ] = max_size .max ()
96+ if "size_divisibility" in padding_constraints :
97+ size_divisibility = padding_constraints ["size_divisibility" ]
8598 if size_divisibility > 1 :
8699 stride = size_divisibility
87100 # the last two dims are H,W, both subject to divisibility requirement
0 commit comments