 | Paper link: https://arxiv.org/abs/2005.12872
 """
 import logging
-from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union, Any

 import numpy as np

@@ -581,10 +581,10 @@ def predict(self, x: np.ndarray, batch_size: int = 128, **kwargs) -> List[Dict[s
         import torch

         self._model.eval()
-        x, _ = self._apply_resizing(x, None)
+        x_resized, _ = self._apply_resizing(x)

         # Apply preprocessing
-        x_preprocessed, _ = self._apply_preprocessing(x, y=None, fit=False)
+        x_preprocessed, _ = self._apply_preprocessing(x_resized, y=None, fit=False)

         if self.clip_values is not None:
             norm_factor = self.clip_values[1]
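Since `clip_values[1]` drives the scaling here, a minimal standalone sketch of the normalization step may help; the sample array and the `(0, 255)` clip range are assumptions for illustration, not values from this diff.

```python
import numpy as np

# Hypothetical batch in NCHW layout with pixel values in [0, 255].
x = np.random.randint(0, 256, size=(2, 3, 800, 800)).astype(np.float32)

norm_factor = 255.0  # plays the role of self.clip_values[1] above
x_norm = x / norm_factor  # scales inputs into [0, 1] before the forward pass

assert 0.0 <= x_norm.min() and x_norm.max() <= 1.0
```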
@@ -644,6 +644,7 @@ def _get_losses(

         # Apply preprocessing
         if self.all_framework_preprocessing:
+            print(y)
             if y is not None and isinstance(y, list) and isinstance(y[0]["boxes"], np.ndarray):
                 y_tensor = []
                 for y_i in y:
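The hunk is cut off after `for y_i in y:`; a hedged sketch of the numpy-to-torch target conversion such a loop typically performs (the dictionary keys follow the `boxes`/`labels` convention visible elsewhere in this diff):

```python
import numpy as np
import torch

# Hypothetical numpy targets in the torchvision-style format used above.
y = [{"boxes": np.array([[10.0, 20.0, 200.0, 150.0]]), "labels": np.array([7])}]

y_tensor = []
for y_i in y:
    y_tensor.append(
        {
            "boxes": torch.from_numpy(y_i["boxes"]).float(),
            "labels": torch.from_numpy(y_i["labels"]).to(torch.int64),
        }
    )
```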
@@ -733,15 +734,15 @@ def loss_gradient(
                  - labels (Tensor[N]): the predicted labels for each image
         :return: Loss gradients of the same shape as the resized input `x`.
         """
-        x, y = self._apply_resizing(x, y)
-        output, inputs_t, image_tensor_list_grad = self._get_losses(x=x, y=y)
+        x_resized, y_resized = self._apply_resizing(x, y)
+        output, inputs_t, image_tensor_list_grad = self._get_losses(x=x_resized, y=y_resized)
         loss = sum(output[k] * self.weight_dict[k] for k in output.keys() if k in self.weight_dict)

         self._model.zero_grad()

         loss.backward(retain_graph=True)  # type: ignore

-        if isinstance(x, np.ndarray):
+        if isinstance(x_resized, np.ndarray):
             if image_tensor_list_grad.grad is not None:
                 grads = image_tensor_list_grad.grad.cpu().numpy().copy()
             else:
@@ -756,9 +757,7 @@ def loss_gradient(
             grads = grads / self.clip_values[1]

         if not self.all_framework_preprocessing:
-            grads = self._apply_preprocessing_gradient(x, grads)
-
-        assert grads.shape == x.shape
+            grads = self._apply_preprocessing_gradient(x_resized, grads)

         return grads

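Dropping `assert grads.shape == x.shape` is consistent with the resizing change: gradients now match the resized input rather than the caller's original array. A hypothetical one-step FGSM-style use of `loss_gradient`, assuming `detector` is an already-constructed estimator and `x` is pre-resized so shapes line up:

```python
import numpy as np

def fgsm_step(detector, x: np.ndarray, y: list, eps: float = 8 / 255) -> np.ndarray:
    """One signed-gradient ascent step on the detection loss (illustrative)."""
    # Gradients follow the resized input's shape, so x is assumed to already
    # be at the 800x800 size DETR expects.
    grads = detector.loss_gradient(x=x, y=y)
    return np.clip(x + eps * np.sign(grads), 0.0, 1.0)
```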
@@ -787,8 +786,8 @@ def compute_losses(
                  - scores (Tensor[N]): the scores for each prediction.
         :return: Dictionary of loss components.
         """
-        x, y = self._apply_resizing(x, y)
-        output_tensor, _, _ = self._get_losses(x=x, y=y)
+        x_resized, y = self._apply_resizing(x, y)
+        output_tensor, _, _ = self._get_losses(x=x_resized, y=y)
         output = {}
         for key, value in output_tensor.items():
             if key in self.attack_losses:
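The loop that follows presumably detaches each matching tensor into numpy. A sketch of the filtering with DETR's usual component names; the concrete values are made up:

```python
import torch

output_tensor = {
    "loss_ce": torch.tensor(0.7),    # classification loss
    "loss_bbox": torch.tensor(0.2),  # L1 box regression loss
    "loss_giou": torch.tensor(0.4),  # generalized IoU loss
    "cardinality_error": torch.tensor(3.0),  # diagnostic only, filtered out
}
attack_losses = ("loss_ce", "loss_bbox", "loss_giou")

output = {}
for key, value in output_tensor.items():
    if key in attack_losses:
        output[key] = value.detach().cpu().numpy()
```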
@@ -824,7 +823,6 @@ def compute_loss( # type: ignore
                 loss = output[loss_name]
             else:
                 loss = loss + output[loss_name]
-
         assert loss is not None

         if isinstance(x, torch.Tensor):
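Note the contrast with `loss_gradient` above, which weights each component by `self.weight_dict` before summing, while `compute_loss` adds the raw components. A side-by-side sketch; the weights shown are illustrative, not necessarily the ones the wrapped DETR model carries:

```python
import torch

output = {
    "loss_ce": torch.tensor(0.7),
    "loss_bbox": torch.tensor(0.2),
    "loss_giou": torch.tensor(0.4),
}
weight_dict = {"loss_ce": 1.0, "loss_bbox": 5.0, "loss_giou": 2.0}

# compute_loss: plain sum of the requested components.
unweighted = sum(output[k] for k in output)

# loss_gradient: weighted sum, mirroring how DETR combines losses in training.
weighted = sum(output[k] * weight_dict[k] for k in output if k in weight_dict)
```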
@@ -835,10 +833,10 @@ def compute_loss( # type: ignore
     def _apply_resizing(
         self,
         x: Union[np.ndarray, "torch.Tensor"],
-        y: List[Dict[str, Union[np.ndarray, "torch.Tensor"]]],
+        y: Any = None,
         height: int = 800,
         width: int = 800,
-    ):
+    ) -> Tuple[Union[np.ndarray, "torch.Tensor"], List[Any]]:
         """
         Resize the input and targets to dimensions expected by DETR.

@@ -861,9 +859,9 @@ def _apply_resizing(
             if isinstance(x, torch.Tensor):
                 x = T.Resize(size=(height, width))(x)
             else:
-                for i, _ in enumerate(x):
+                for i in x:
                     resized = cv2.resize(
-                        (x)[i].transpose(1, 2, 0),
+                        i.transpose(1, 2, 0),
                         dsize=(height, width),
                         interpolation=cv2.INTER_CUBIC,
                     )
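The transposes exist because `cv2.resize` expects HWC arrays while the estimator stores images as CHW. A self-contained version of the round trip, with an assumed input size:

```python
import cv2
import numpy as np

img_chw = np.random.rand(3, 600, 400).astype(np.float32)  # CHW, as stored

resized_hwc = cv2.resize(
    img_chw.transpose(1, 2, 0),  # CHW -> HWC for OpenCV
    dsize=(800, 800),            # (width, height); square here, so order is moot
    interpolation=cv2.INTER_CUBIC,
)
img_resized = resized_hwc.transpose(2, 0, 1)  # HWC -> back to CHW

assert img_resized.shape == (3, 800, 800)
```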
@@ -877,20 +875,23 @@ def _apply_resizing(
             if isinstance(x, torch.Tensor):
                 x = T.Resize(size=(rescale_dim, rescale_dim))(x)
             else:
-                for i, _ in enumerate(x):
+                for i in x:
                     resized = cv2.resize(
-                        (x)[i].transpose(1, 2, 0),
+                        i.transpose(1, 2, 0),
                         dsize=(rescale_dim, rescale_dim),
                         interpolation=cv2.INTER_CUBIC,
                     )
                     resized = resized.transpose(2, 0, 1)
                     resized_imgs.append(resized)
                 x = np.array(resized_imgs)

-        targets = []
+        targets: List[Any] = []
         if y is not None:
             if isinstance(y[0]["boxes"], torch.Tensor):
                 for target in y:
+                    assert isinstance(target["boxes"], torch.Tensor)
+                    assert isinstance(target["labels"], torch.Tensor)
+                    assert isinstance(target["scores"], torch.Tensor)
                     cxcy_norm = revert_rescale_bboxes(target["boxes"], (self.input_shape[2], self.input_shape[1]))
                     targets.append(
                         {
@@ -901,9 +902,8 @@ def _apply_resizing(
                     )
             else:
                 for target in y:
-                    cxcy_norm = revert_rescale_bboxes(
-                        torch.from_numpy(target["boxes"]), (self.input_shape[2], self.input_shape[1])
-                    )
+                    tensor_box = torch.from_numpy(target["boxes"])
+                    cxcy_norm = revert_rescale_bboxes(tensor_box, (self.input_shape[2], self.input_shape[1]))
                     targets.append(
                         {
                             "labels": torch.from_numpy(target["labels"]).type(torch.int64).to(self.device),
@@ -988,11 +988,9 @@ def grad_enabled_forward(self, samples: NestedTensor):
     if isinstance(samples, (list, torch.Tensor)):
         samples = nested_tensor_from_tensor_list(samples)
     features, pos = self.backbone(samples)
-
     src, mask = features[-1].decompose()
     assert mask is not None
     h_s = self.transformer(self.input_proj(src), mask, self.query_embed.weight, pos[-1])[0]
-
     outputs_class = self.class_embed(h_s)
     outputs_coord = self.bbox_embed(h_s).sigmoid()
     out = {"pred_logits": outputs_class[-1], "pred_boxes": outputs_coord[-1]}
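`grad_enabled_forward` mirrors DETR's own forward pass; one plausible way to attach it to a hub-loaded model is sketched below. The loading call and the binding are assumptions about usage, not something this diff shows.

```python
import types

import torch

# Load the reference DETR implementation from torch.hub (downloads weights).
model = torch.hub.load("facebookresearch/detr", "detr_resnet50", pretrained=True)

# Bind the replacement forward so calls to model(...) route through it.
model.forward = types.MethodType(grad_enabled_forward, model)

out = model(torch.rand(1, 3, 800, 800))
print(out["pred_logits"].shape, out["pred_boxes"].shape)
```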