@@ -169,6 +169,7 @@ def __init__(
                        between 0 and H and 0 and W
                      - labels (Tensor[N]): the predicted labels for each image
                      - scores (Tensor[N]): the scores for each prediction
+        :param input_shape: Tuple of the form `(height, width)` of ints representing input image height and width
         :param clip_values: Tuple of the form `(min, max)` of floats or `np.ndarray` representing the minimum and
                maximum values allowed for features. If floats are provided, these will be used as the range of all
                features. If arrays are provided, each value will be considered the bound for a feature, thus
@@ -577,43 +578,10 @@ def predict(self, x: np.ndarray, batch_size: int = 128, **kwargs) -> List[Dict[s
                  - labels [N]: the labels for each image
                  - scores [N]: the scores for each prediction.
         """
-        import cv2
         import torch

-        # check if image with min, max dimensions, if not scale to 1000
-        # if is within min, max dims, but not square, resize to max of image
-        if (
-            self._input_shape[1] < self.MIN_IMAGE_SIZE
-            or self._input_shape[1] > self.MAX_IMAGE_SIZE
-            or self._input_shape[2] < self.MIN_IMAGE_SIZE
-            or self.input_shape[2] > self.MAX_IMAGE_SIZE
-        ):
-            resized_imgs = []
-            for i, _ in enumerate(x):
-                resized_imgs.append(
-                    cv2.resize(
-                        (x * 255)[i].transpose(1, 2, 0).astype(np.uint8),
-                        dsize=(1000, 1000),
-                        interpolation=cv2.INTER_CUBIC,
-                    )
-                )
-            x = (np.array(resized_imgs) / 255).transpose(0, 3, 1, 2).astype(np.float32)
-        elif self._input_shape[1] != self._input_shape[2]:
-            rescale_dim = max(self._input_shape[1], self._input_shape[2])
-            resized_imgs = []
-            for i, _ in enumerate(x):
-                resized_imgs.append(
-                    cv2.resize(
-                        (x * 255)[i].transpose(1, 2, 0).astype(np.uint8),
-                        dsize=(rescale_dim, rescale_dim),
-                        interpolation=cv2.INTER_CUBIC,
-                    )
-                )
-            x = (np.array(resized_imgs) / 255).transpose(0, 3, 1, 2).astype(np.float32)
-
-        x = x.copy()
-
         self._model.eval()
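+        # Bring the inputs to DETR-compatible dimensions before any preprocessing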
+        x, _ = self._apply_resizing(x, None)

         # Apply preprocessing
         x_preprocessed, _ = self._apply_preprocessing(x, y=None, fit=False)
@@ -633,7 +601,7 @@ def predict(self, x: np.ndarray, batch_size: int = 128, **kwargs) -> List[Dict[s
             predictions.append(
                 {
                     "boxes": rescale_bboxes(
-                        model_output["pred_boxes"][i, :, :], (self._input_shape[1], self._input_shape[2])
+                        model_output["pred_boxes"][i, :, :], (self._input_shape[2], self._input_shape[1])
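+                        # note the swap: _input_shape is channels-first, so indices 2 and 1
+                        # pass (width, height) rather than (height, width) to rescale_bboxes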
                     )
                     .detach()
                     .numpy(),
@@ -765,22 +733,8 @@ def loss_gradient(
                   - labels (Tensor[N]): the predicted labels for each image
         :return: Loss gradients of the same shape as `x`.
         """
-        import torch
-
-        _y = []
-        for target in y:
-            cxcy_norm = revert_rescale_bboxes(
-                torch.from_numpy(target["boxes"]), (self.input_shape[1], self.input_shape[2])
-            )
-            _y.append(
-                {
-                    "labels": torch.from_numpy(target["labels"]).type(torch.int64).to(self.device),
-                    "boxes": cxcy_norm.to(self.device),
-                    "scores": torch.from_numpy(target["scores"]).type(torch.float).to(self.device),
-                }
-            )
-
-        output, inputs_t, image_tensor_list_grad = self._get_losses(x=x, y=_y)
+        x, y = self._apply_resizing(x, y)
+        output, inputs_t, image_tensor_list_grad = self._get_losses(x=x, y=y)
         loss = sum(output[k] * self.weight_dict[k] for k in output.keys() if k in self.weight_dict)

         self._model.zero_grad()
@@ -833,6 +787,7 @@ def compute_losses(
           - scores (Tensor[N]): the scores for each prediction.
         :return: Dictionary of loss components.
         """
+        x, y = self._apply_resizing(x, y)
         output_tensor, _, _ = self._get_losses(x=x, y=y)
         output = {}
         for key, value in output_tensor.items():
@@ -859,6 +814,7 @@ def compute_loss( # type: ignore
859814 """
860815 import torch
861816
817+ x , y = self ._apply_resizing (x , y )
862818 output , _ , _ = self ._get_losses (x = x , y = y )
863819
864820 # Compute the gradient and return
@@ -876,6 +832,90 @@ def compute_loss( # type: ignore

         return loss.detach().cpu().numpy()

+    def _apply_resizing(
+        self,
+        x: Union[np.ndarray, "torch.Tensor"],
+        y: List[Dict[str, Union[np.ndarray, "torch.Tensor"]]],
+        height: int = 800,
+        width: int = 800,
+    ):
+        """
+        Resize the input and targets to dimensions expected by DETR.
+
+        :param x: Array or Tensor representing images of any size.
+        :param y: List of targets to be transformed.
+        :param height: Int representing the desired height; the default is compatible with DETR.
+        :param width: Int representing the desired width; the default is compatible with DETR.
+        :return: Resized images and transformed targets.
+        """
+        import cv2
+        import torch
+        import torchvision.transforms as T
+
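+        # Case 1: the configured input shape falls outside the supported size range,
+        # so resize every image to the default DETR-compatible height and width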
+        if (
+            self._input_shape[1] < self.MIN_IMAGE_SIZE
+            or self._input_shape[1] > self.MAX_IMAGE_SIZE
+            or self._input_shape[2] < self.MIN_IMAGE_SIZE
+            or self._input_shape[2] > self.MAX_IMAGE_SIZE
+        ):
+            resized_imgs = []
+            if isinstance(x, torch.Tensor):
+                x = T.Resize(size=(height, width))(x)
+            else:
+                for i, _ in enumerate(x):
+                    resized = cv2.resize(
+                        x[i].transpose(1, 2, 0),
+                        dsize=(height, width),
+                        interpolation=cv2.INTER_CUBIC,
+                    )
+                    resized_imgs.append(resized.transpose(2, 0, 1))
+                x = np.array(resized_imgs)
+
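+        # Case 2: the shape is within the supported range but not square,
+        # so resize to a square using the larger of the two dimensions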
+        elif self._input_shape[1] != self._input_shape[2]:
+            rescale_dim = max(self._input_shape[1], self._input_shape[2])
+            resized_imgs = []
+            if isinstance(x, torch.Tensor):
+                x = T.Resize(size=(rescale_dim, rescale_dim))(x)
+            else:
+                for i, _ in enumerate(x):
+                    resized = cv2.resize(
+                        x[i].transpose(1, 2, 0),
+                        dsize=(rescale_dim, rescale_dim),
+                        interpolation=cv2.INTER_CUBIC,
+                    )
+                    resized_imgs.append(resized.transpose(2, 0, 1))
+                x = np.array(resized_imgs)
+
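+        # Convert target boxes from [x1, y1, x2, y2] pixel coordinates back to the
+        # normalised cx, cy, w, h format used by the DETR loss, moving all tensors
+        # to the estimator's device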
+        targets = []
+        if y is not None:
+            if isinstance(y[0]["boxes"], torch.Tensor):
+                for target in y:
+                    cxcy_norm = revert_rescale_bboxes(target["boxes"], (self._input_shape[2], self._input_shape[1]))
+                    targets.append(
+                        {
+                            "labels": target["labels"].type(torch.int64).to(self.device),
+                            "boxes": cxcy_norm.to(self.device),
+                            "scores": target["scores"].type(torch.float).to(self.device),
+                        }
+                    )
+            else:
+                for target in y:
+                    cxcy_norm = revert_rescale_bboxes(
+                        torch.from_numpy(target["boxes"]), (self._input_shape[2], self._input_shape[1])
+                    )
+                    targets.append(
+                        {
+                            "labels": torch.from_numpy(target["labels"]).type(torch.int64).to(self.device),
+                            "boxes": cxcy_norm.to(self.device),
+                            "scores": torch.from_numpy(target["scores"]).type(torch.float).to(self.device),
+                        }
+                    )
+
+        return x, targets

 class NestedTensor:
     """