@@ -167,7 +167,7 @@ def _get_losses(
         x: np.ndarray,
         y: Union[List[Dict[str, np.ndarray]], List[Dict[str, "torch.Tensor"]]],
         reduction: str = "sum",
-    ) -> Tuple[Dict[str, "torch.Tensor"], List["torch.Tensor"], List["torch.Tensor"]]:
+    ) -> Tuple[Dict[str, Union["torch.Tensor", int, List["torch.Tensor"]]], List["torch.Tensor"], List["torch.Tensor"]]:
         """
         Get the loss tensor output of the model including all preprocessing.

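The widened value type in the returned dict follows directly from Python's `sum` and the two `reduction` modes: summing a non-empty list of 0-dim tensors yields a tensor, `sum([])` yields the plain `int` 0, and `reduction="none"` passes the list through unchanged. A standalone sketch of those three cases (illustrative, not part of the patch):

```python
import torch

loss_list = [torch.tensor(0.5), torch.tensor(1.25)]

summed = sum(loss_list)   # 0-dim tensor(1.7500): the usual "sum" case
empty = sum([])           # plain int 0: why the Union includes int
per_sample = loss_list    # reduction="none": List[torch.Tensor]

print(type(summed), type(empty), type(per_sample))
```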
@@ -197,18 +197,18 @@ def _get_losses(
             y_tensor = list()
             for i, y_i in enumerate(y):
                 y_t = dict()
-                y_t["boxes"] = torch.from_numpy(y_i["boxes"]).type(torch.float).to(self._device)
+                y_t["boxes"] = torch.from_numpy(y_i["boxes"]).float().to(self._device)
                 if "labels" in y_i:
-                    y_t["labels"] = torch.from_numpy(y_i["labels"]).type(torch.int64).to(self._device)
+                    y_t["labels"] = torch.from_numpy(y_i["labels"]).int().to(self._device)
                 if "masks" in y_i:
-                    y_t["masks"] = torch.from_numpy(y_i["masks"]).type(torch.int64).to(self._device)
+                    y_t["masks"] = torch.from_numpy(y_i["masks"]).int().to(self._device)
                 y_tensor.append(y_t)
         else:
             y_tensor = y

         image_tensor_list_grad = list()
         y_preprocessed = list()
-        inputs_t = list()
+        inputs_t: List["torch.Tensor"] = list()

         for i in range(x.shape[0]):
             if self.clip_values is not None:
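`.float()` and `.int()` are the tensor-method shorthands for `.type(...)`, with one subtlety worth noting: `.float()` matches the old `.type(torch.float)` exactly (both produce `torch.float32`), but `.int()` yields `torch.int32`, whereas the replaced `.type(torch.int64)` yielded 64-bit integers. A quick dtype check (illustrative only):

```python
import numpy as np
import torch

a = torch.from_numpy(np.array([1, 2, 3]))

print(a.float().dtype)            # torch.float32, identical to a.type(torch.float)
print(a.int().dtype)              # torch.int32: narrower than the old call
print(a.type(torch.int64).dtype)  # torch.int64, identical to a.long()
```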
@@ -246,14 +246,15 @@ def _get_losses(
             loss = torch.nn.L1Loss(size_average=False)(y_pred.float(), gt_bb.float())
             loss_list.append(loss)

+        loss_dict: Dict[str, Union["torch.Tensor", int, List["torch.Tensor"]]] = dict()
         if reduction == "sum":
-            loss = {"torch.nn.L1Loss": sum(loss_list)}
+            loss_dict["torch.nn.L1Loss"] = sum(loss_list)
         elif reduction == "none":
-            loss = {"torch.nn.L1Loss": loss_list}
+            loss_dict["torch.nn.L1Loss"] = loss_list
         else:
             raise ValueError("Reduction not recognised.")

-        return loss, inputs_t, image_tensor_list_grad
+        return loss_dict, inputs_t, image_tensor_list_grad

     def loss_gradient(  # pylint: disable=W0613
         self, x: np.ndarray, y: Union[List[Dict[str, np.ndarray]], List[Dict[str, "torch.Tensor"]]], **kwargs
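Introducing a separately annotated `loss_dict` instead of rebinding `loss` (which holds a tensor earlier in the function) is the usual way to keep each name's type stable for mypy, which is likely the motivation here. A minimal illustration of the pattern, with values invented for the example:

```python
from typing import Dict, List, Union
import torch

loss_list: List[torch.Tensor] = [torch.tensor(0.5)]

# Rebinding one name first to a tensor and later to a dict trips mypy's
# "incompatible types in assignment" check; declaring the dict with its
# full value type up front lets both reduction branches assign into it.
loss_dict: Dict[str, Union[torch.Tensor, int, List[torch.Tensor]]] = {}
loss_dict["torch.nn.L1Loss"] = sum(loss_list)  # "sum" branch
loss_dict["torch.nn.L1Loss"] = loss_list       # "none" branch
```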
@@ -295,7 +296,10 @@ def loss_gradient(  # pylint: disable=W0613
         loss.backward(retain_graph=True)  # type: ignore

         for img in image_tensor_list_grad:
-            gradients = img.grad.cpu().numpy().copy()
+            if img.grad is not None:
+                gradients = img.grad.cpu().numpy().copy()
+            else:
+                gradients = None
             grad_list.append(gradients)

         grads = np.array(grad_list)
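The `None` guard reflects how autograd populates `.grad`: only leaf tensors with `requires_grad=True` that actually contribute to the loss receive a gradient; any other tensor keeps `grad is None`, which the old unconditional `.cpu()` call would crash on. A short demonstration:

```python
import torch

x = torch.tensor([1.0, 2.0], requires_grad=True)
(x * 3).sum().backward()
print(x.grad)        # tensor([3., 3.]): leaf used in the graph

unused = torch.tensor([5.0], requires_grad=True)
print(unused.grad)   # None: never touched by backward(), hence the guard
```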
@@ -325,8 +329,14 @@ def _preprocess(self, img: "torch.Tensor") -> "torch.Tensor":
         import torch  # lgtm [py/repeated-import]
         from torch.nn.functional import interpolate

-        mean_np = self.preprocessing.mean
-        std_np = self.preprocessing.std
+        from art.preprocessing.standardisation_mean_std.pytorch import StandardisationMeanStdPyTorch
+
+        if self.preprocessing is not None and isinstance(self.preprocessing, StandardisationMeanStdPyTorch):
+            mean_np = self.preprocessing.mean
+            std_np = self.preprocessing.std
+        else:
+            mean_np = np.ones((3, 1, 1))
+            std_np = np.ones((3, 1, 1))
         mean = torch.from_numpy(mean_np).reshape((3, 1, 1))
         std = torch.from_numpy(std_np).reshape((3, 1, 1))
         img = img.permute(2, 0, 1)
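The `isinstance` gate means per-channel statistics are only read when the estimator's preprocessing really is a `StandardisationMeanStdPyTorch`; otherwise all-ones placeholders keep the arithmetic well-defined. Assuming the lines that follow apply the usual `(img - mean) / std` standardisation, the broadcasting works as below (shapes and values are illustrative):

```python
import numpy as np
import torch

mean_np = np.ones((3, 1, 1), dtype=np.float32)  # placeholder stats, as in the fallback
std_np = np.ones((3, 1, 1), dtype=np.float32)

img = torch.rand(224, 224, 3)        # HWC frame, like the input to _preprocess
chw = img.permute(2, 0, 1)           # CHW, so the (3, 1, 1) stats broadcast per channel
normed = (chw - torch.from_numpy(mean_np)) / torch.from_numpy(std_np)
print(normed.shape)                  # torch.Size([3, 224, 224])
```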
@@ -353,7 +363,7 @@ def _track_step(

         k_context_factor = 2

-        def compute_output_height_f(bbox_tight: "torch.Tensor") -> float:
+        def compute_output_height_f(bbox_tight: "torch.Tensor") -> "torch.Tensor":
             """
             Compute height of search/target region.

@@ -363,9 +373,9 @@ def compute_output_height_f(bbox_tight: "torch.Tensor") -> float:
             bbox_height = bbox_tight[3] - bbox_tight[1]
             output_height = k_context_factor * bbox_height

-            return max(1.0, output_height)
+            return torch.max(torch.tensor(1.0).to(self.device), output_height)

-        def compute_output_width_f(bbox_tight: "torch.Tensor") -> float:
+        def compute_output_width_f(bbox_tight: "torch.Tensor") -> "torch.Tensor":
             """
             Compute width of search/target region.

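Swapping the builtin `max` for `torch.max` (with the constant created on `self.device`) is what makes these helpers return tensors at all: whenever the float bound wins the comparison, builtin `max` returns the plain float, dropping both the device placement and the autograd history. A small demonstration; `torch.clamp(h, min=1.0)` would be an equivalent single-tensor alternative:

```python
import torch

h = torch.tensor(0.25, requires_grad=True)  # e.g. a tiny bbox height

clamped_py = max(1.0, h)                      # plain float 1.0: graph and device lost
clamped_pt = torch.max(torch.tensor(1.0), h)  # 0-dim tensor, still autograd-tracked

print(type(clamped_py))  # <class 'float'>
print(type(clamped_pt))  # <class 'torch.Tensor'>
```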
@@ -375,7 +385,7 @@ def compute_output_width_f(bbox_tight: "torch.Tensor") -> float:
             bbox_width = bbox_tight[2] - bbox_tight[0]
             output_width = k_context_factor * bbox_width

-            return max(1.0, output_width)
+            return torch.max(torch.tensor(1.0).to(self.device), output_width)

         def get_center_x_f(bbox_tight: "torch.Tensor") -> "torch.Tensor":
             """
@@ -397,7 +407,7 @@ def get_center_y_f(bbox_tight: "torch.Tensor") -> "torch.Tensor":

         def compute_crop_pad_image_location(
             bbox_tight: "torch.Tensor", image: "torch.Tensor"
-        ) -> (float, float, float, float):
+        ) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor"]:
             """
             Get the valid image coordinates for the context region in target or search region in full image

@@ -421,32 +431,32 @@ def compute_crop_pad_image_location(
             output_width = compute_output_width_f(bbox_tight)
             output_height = compute_output_height_f(bbox_tight)

-            roi_left = max(0.0, bbox_center_x - (output_width / 2.0))
-            roi_bottom = max(0.0, bbox_center_y - (output_height / 2.0))
+            roi_left = torch.max(torch.tensor(0.0).to(self.device), bbox_center_x - (output_width / 2.0))
+            roi_bottom = torch.max(torch.tensor(0.0).to(self.device), bbox_center_y - (output_height / 2.0))

             # New ROI width
             # -------------
             # 1. left_half should not go out of bound on the left side of the
             #    image
             # 2. right_half should not go out of bound on the right side of the
             #    image
-            left_half = min(output_width / 2.0, bbox_center_x)
+            left_half = torch.min(output_width / 2.0, bbox_center_x)
             right_half = min(output_width / 2.0, image_width - bbox_center_x)
             roi_width = max(1.0, left_half + right_half)

             # New ROI height
             # Similar logic applied that is applied for 'New ROI width'
-            top_half = min(output_height / 2.0, bbox_center_y)
-            bottom_half = min(output_height / 2.0, image_height - bbox_center_y)
-            roi_height = max(1.0, top_half + bottom_half)
+            top_half = torch.min(output_height / 2.0, bbox_center_y)
+            bottom_half = torch.min(output_height / 2.0, image_height - bbox_center_y)
+            roi_height = torch.max(torch.tensor(1.0).to(self.device), top_half + bottom_half)

             # Padded image location in the original image
             # objPadImageLocation = BoundingBox(roi_left, roi_bottom, roi_left + roi_width, roi_bottom + roi_height)
             #
             # return objPadImageLocation
             return roi_left, roi_bottom, roi_left + roi_width, roi_bottom + roi_height

-        def edge_spacing_x_f(bbox_tight: "torch.Tensor") -> float:
+        def edge_spacing_x_f(bbox_tight: "torch.Tensor") -> "torch.Tensor":
             """
             Edge spacing X to take care of if search/target pad region goes out of bound.

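Numerically, this clamping keeps the context region (twice the bbox extent, per `k_context_factor = 2`) inside the frame. A worked example with made-up numbers, mirroring the mixed `torch.min`/builtin-`min` usage above:

```python
import torch

bbox = torch.tensor([75.0, 30.0, 95.0, 50.0])  # (x1, y1, x2, y2) near the right edge
image_width = 100.0
k_context_factor = 2

bbox_center_x = (bbox[0] + bbox[2]) / 2.0              # 85.0
output_width = k_context_factor * (bbox[2] - bbox[0])  # 40.0

roi_left = torch.max(torch.tensor(0.0), bbox_center_x - output_width / 2.0)  # 65.0
left_half = torch.min(output_width / 2.0, bbox_center_x)                     # 20.0
right_half = min(output_width / 2.0, image_width - bbox_center_x)            # 15.0, clipped at the border
roi_width = max(1.0, left_half + right_half)                                 # 35.0, not the full 40.0
```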
@@ -456,9 +466,9 @@ def edge_spacing_x_f(bbox_tight: "torch.Tensor") -> float:
             output_width = compute_output_width_f(bbox_tight)
             bbox_center_x = get_center_x_f(bbox_tight)

-            return max(0.0, (output_width / 2) - bbox_center_x)
+            return torch.max(torch.tensor(0.0).to(self.device), (output_width / 2) - bbox_center_x)

-        def edge_spacing_y_f(bbox_tight: "torch.Tensor") -> float:
+        def edge_spacing_y_f(bbox_tight: "torch.Tensor") -> "torch.Tensor":
             """
             Edge spacing Y to take care of if search/target pad region goes out of bound.

@@ -468,9 +478,16 @@ def edge_spacing_y_f(bbox_tight: "torch.Tensor") -> float:
             output_height = compute_output_height_f(bbox_tight)
             bbox_center_y = get_center_y_f(bbox_tight)

-            return max(0.0, (output_height / 2) - bbox_center_y)
+            return torch.max(torch.tensor(0.0).to(self.device), (output_height / 2) - bbox_center_y)

-        def crop_pad_image(bbox_tight: "torch.Tensor", image: "torch.Tensor") -> ("torch.Tensor", float, float, float):
+        def crop_pad_image(
+            bbox_tight: "torch.Tensor", image: "torch.Tensor"
+        ) -> Tuple[
+            "torch.Tensor",
+            Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor"],
+            "torch.Tensor",
+            "torch.Tensor",
+        ]:
             """
             Around the bounding box, we define an extra context factor of 2, which we will crop from the original image.
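Beyond documenting the four return values (the cropped patch, the ROI location tuple from `compute_crop_pad_image_location`, and the two edge spacings), the move to `Tuple[...]` fixes a genuine annotation bug: a parenthesized list like `-> ("torch.Tensor", float, float, float)` evaluates to an ordinary tuple object, which type checkers reject as a type. A minimal before/after, with function names invented for the demo:

```python
from typing import Tuple
import torch

def bad() -> (int, int):  # runs, but mypy flags it: syntax error in type annotation
    return 1, 2

def good() -> Tuple[torch.Tensor, torch.Tensor]:  # the typing construct for fixed tuples
    t = torch.zeros(1)
    return t, t
```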