@@ -135,14 +135,6 @@ def __init__(
135135 num_classes , matcher = matcher , weight_dict = self .weight_dict , eos_coef = eos_coef , losses = losses
136136 )
137137
138- # Set device
139- self ._device : torch .device
140- if device_type == "cpu" or not torch .cuda .is_available ():
141- self ._device = torch .device ("cpu" )
142- else : # pragma: no cover
143- cuda_idx = torch .cuda .current_device ()
144- self ._device = torch .device (f"cuda:{ cuda_idx } " )
145-
146138 self ._model .to (self ._device )
147139 self ._model .eval ()
148140 self .attack_losses : Tuple [str , ...] = attack_losses
@@ -208,7 +200,7 @@ def predict(self, x: np.ndarray, batch_size: int = 128, **kwargs) -> List[Dict[s
208200 predictions .append (
209201 {
210202 "boxes" : rescale_bboxes (
211- model_output ["pred_boxes" ][i , :, :], (self ._input_shape [2 ], self ._input_shape [1 ])
203+ model_output ["pred_boxes" ][i , :, :]. cpu () , (self ._input_shape [2 ], self ._input_shape [1 ])
212204 )
213205 .detach ()
214206 .numpy (),
@@ -217,12 +209,14 @@ def predict(self, x: np.ndarray, batch_size: int = 128, **kwargs) -> List[Dict[s
217209 .softmax (- 1 )[0 , :, :- 1 ]
218210 .max (dim = 1 )[1 ]
219211 .detach ()
212+ .cpu ()
220213 .numpy (),
221214 "scores" : model_output ["pred_logits" ][i , :, :]
222215 .unsqueeze (0 )
223216 .softmax (- 1 )[0 , :, :- 1 ]
224217 .max (dim = 1 )[0 ]
225218 .detach ()
219+ .cpu ()
226220 .numpy (),
227221 }
228222 )
@@ -278,7 +272,7 @@ def _get_losses(
278272 else :
279273 x_grad = x .to (self .device )
280274 if x_grad .shape [2 ] < x_grad .shape [0 ] and x_grad .shape [2 ] < x_grad .shape [1 ]:
281- x_grad = torch .permute (x_grad , (2 , 0 , 1 ))
275+ x_grad = torch .permute (x_grad , (2 , 0 , 1 )). to ( self . device )
282276
283277 image_tensor_list_grad = x_grad
284278 x_preprocessed , y_preprocessed = self ._apply_preprocessing (x_grad , y = y_tensor , fit = False , no_grad = False )
@@ -304,7 +298,9 @@ def _get_losses(
304298 else :
305299 y_tensor = y # type: ignore
306300
307- x_preprocessed , y_preprocessed = self ._apply_preprocessing (x , y = y_tensor , fit = False , no_grad = True )
301+ x_preprocessed , y_preprocessed = self ._apply_preprocessing (
302+ x .to (self .device ), y = y_tensor , fit = False , no_grad = True
303+ )
308304
309305 if self .clip_values is not None :
310306 norm_factor = self .clip_values [1 ]
@@ -462,7 +458,7 @@ def _apply_resizing(
462458 ):
463459 resized_imgs = []
464460 if isinstance (x , torch .Tensor ):
465- x = T .Resize (size = (height , width ))(x )
461+ x = T .Resize (size = (height , width ))(x ). to ( self . device )
466462 else :
467463 for i in x :
468464 resized = cv2 .resize (
@@ -478,7 +474,7 @@ def _apply_resizing(
478474 rescale_dim = max (self ._input_shape [1 ], self ._input_shape [2 ])
479475 resized_imgs = []
480476 if isinstance (x , torch .Tensor ):
481- x = T .Resize (size = (rescale_dim , rescale_dim ))(x )
477+ x = T .Resize (size = (rescale_dim , rescale_dim ))(x ). to ( self . device )
482478 else :
483479 for i in x :
484480 resized = cv2 .resize (
0 commit comments