 import numpy as np
 
 from art.estimators.object_detection.object_detector import ObjectDetectorMixin
+from art.estimators.object_detection.utils import cast_inputs_to_pt
 from art.estimators.pytorch import PyTorchEstimator
 
 if TYPE_CHECKING:
@@ -296,28 +297,12 @@ def _preprocess_and_convert_inputs(
             norm_factor = 1.0
 
         if self.all_framework_preprocessing:
-            if isinstance(x, np.ndarray):
-                # Convert samples into tensor
-                x_tensor = torch.from_numpy(x / norm_factor).to(self.device)
-            else:
-                x_tensor = (x / norm_factor).to(self.device)
+            # Convert samples into tensor
+            x_tensor, y_tensor = cast_inputs_to_pt(x, y)
 
             if not self.channels_first:
                 x_tensor = torch.permute(x_tensor, (0, 3, 1, 2))
-
-            # Convert targets into tensor
-            if y is not None and isinstance(y[0]["boxes"], np.ndarray):
-                y_tensor = []
-                for y_i in y:
-                    y_t = {
-                        "boxes": torch.from_numpy(y_i["boxes"]).to(device=self.device, dtype=torch.float32),
-                        "labels": torch.from_numpy(y_i["labels"]).to(device=self.device, dtype=torch.int64),
-                    }
-                    if "masks" in y_i:
-                        y_t["masks"] = torch.from_numpy(y_i["masks"]).to(device=self.device, dtype=torch.uint8)
-                    y_tensor.append(y_t)
-            else:
-                y_tensor = y  # type: ignore
+            x_tensor /= norm_factor
 
             # Set gradients
             if not no_grad:
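The inline NumPy-to-tensor conversion deleted in this hunk is now delegated to the shared `cast_inputs_to_pt` helper imported above. Its actual implementation lives in `art/estimators/object_detection/utils.py`; the sketch below is only a reconstruction from the removed lines (device placement is left to the caller, which now calls `.to(self.device)` explicitly elsewhere), so the function name and details here are illustrative rather than the library's real code.

```python
from typing import Dict, List, Optional, Tuple, Union

import numpy as np
import torch


def cast_inputs_to_pt_sketch(
    x: Union[np.ndarray, "torch.Tensor"],
    y: Optional[List[Dict[str, Union[np.ndarray, "torch.Tensor"]]]] = None,
) -> Tuple["torch.Tensor", Optional[List[Dict[str, "torch.Tensor"]]]]:
    """Illustrative stand-in for cast_inputs_to_pt: convert NumPy inputs to PyTorch tensors."""
    # Samples: convert arrays, pass tensors through unchanged
    x_tensor = torch.from_numpy(x) if isinstance(x, np.ndarray) else x

    # Targets: convert each per-image dict of boxes / labels / (optional) masks
    if y is not None and isinstance(y[0]["boxes"], np.ndarray):
        y_tensor = []
        for y_i in y:
            y_t = {
                "boxes": torch.from_numpy(y_i["boxes"]).to(dtype=torch.float32),
                "labels": torch.from_numpy(y_i["labels"]).to(dtype=torch.int64),
            }
            if "masks" in y_i:
                y_t["masks"] = torch.from_numpy(y_i["masks"]).to(dtype=torch.uint8)
            y_tensor.append(y_t)
        return x_tensor, y_tensor

    return x_tensor, y  # already tensors (or None)
```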
@@ -328,33 +313,19 @@ def _preprocess_and_convert_inputs(
 
         elif isinstance(x, np.ndarray):
             # Apply preprocessing
-            x_preprocessed, y_preprocessed = self._apply_preprocessing(x, y=y, fit=fit, no_grad=no_grad)
+            x_preprocessed, y_preprocessed = self._apply_preprocessing(x=x, y=y, fit=fit, no_grad=no_grad)
 
-            # Convert samples into tensor
-            x_preprocessed = torch.from_numpy(x_preprocessed / norm_factor).to(self.device)
+            # Convert inputs into tensor
+            x_preprocessed, y_preprocessed = cast_inputs_to_pt(x_preprocessed, y_preprocessed)
 
             if not self.channels_first:
                 x_preprocessed = torch.permute(x_preprocessed, (0, 3, 1, 2))
+            x_preprocessed /= norm_factor
 
             # Set gradients
             if not no_grad:
                 x_preprocessed.requires_grad = True
 
-            # Convert targets into tensor
-            if y_preprocessed is not None and isinstance(y_preprocessed[0]["boxes"], np.ndarray):
-                y_preprocessed_tensor = []
-                for y_i in y_preprocessed:
-                    y_preprocessed_t = {
-                        "boxes": torch.from_numpy(y_i["boxes"]).to(device=self.device, dtype=torch.float32),
-                        "labels": torch.from_numpy(y_i["labels"]).to(device=self.device, dtype=torch.int64),
-                    }
-                    if "masks" in y_i:
-                        y_preprocessed_t["masks"] = torch.from_numpy(y_i["masks"]).to(
-                            device=self.device, dtype=torch.uint8
-                        )
-                    y_preprocessed_tensor.append(y_preprocessed_t)
-                y_preprocessed = y_preprocessed_tensor
-
         else:
             raise NotImplementedError("Combination of inputs and preprocessing not supported.")
 
@@ -380,6 +351,7 @@ def _get_losses(
         x_preprocessed, y_preprocessed = self._preprocess_and_convert_inputs(x=x, y=y, fit=False, no_grad=False)
         x_grad = x_preprocessed
 
+        # Extract height and width
         if self.channels_first:
             height = self.input_shape[1]
             width = self.input_shape[2]
@@ -389,7 +361,7 @@ def _get_losses(
 
         labels_t = translate_labels_x1y1x2y2_to_xcycwh(labels_x1y1x2y2=y_preprocessed, height=height, width=width)
 
-        loss_components = self._model(x_grad, labels_t)
+        loss_components = self._model(x_grad.to(self.device), labels_t.to(self.device))
 
         return loss_components, x_grad
 
@@ -463,30 +435,34 @@ def predict(self, x: np.ndarray, batch_size: int = 128, **kwargs) -> List[Dict[s
                  - scores [N]: the scores of each prediction.
         """
         import torch
+        from torch.utils.data import TensorDataset, DataLoader
 
         # Set model to evaluation mode
         self._model.eval()
 
         # Apply preprocessing and convert to tensors
         x_preprocessed, _ = self._preprocess_and_convert_inputs(x=x, y=None, fit=False, no_grad=True)
 
-        predictions: List[Dict[str, np.ndarray]] = []
+        # Create dataloader
+        dataset = TensorDataset(x_preprocessed)
+        dataloader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=False)
 
+        # Extract height and width
         if self.channels_first:
             height = self.input_shape[1]
             width = self.input_shape[2]
         else:
             height = self.input_shape[0]
             width = self.input_shape[1]
 
-        # Run prediction
-        num_batch = int(np.ceil(len(x_preprocessed) / float(batch_size)))
-        for m in range(num_batch):
-            # Batch using indices
-            i_batch = x_preprocessed[m * batch_size : (m + 1) * batch_size]
+        predictions: List[Dict[str, np.ndarray]] = []
+        for (x_batch,) in dataloader:
+            # Move inputs to device
+            x_batch = x_batch.to(self._device)
 
+            # Run prediction
             with torch.no_grad():
-                predictions_xcycwh = self._model(i_batch)
+                predictions_xcycwh = self._model(x_batch.to(self.device))
 
             predictions_x1y1x2y2 = translate_predictions_xcycwh_to_x1y1x2y2(
                 y_pred_xcycwh=predictions_xcycwh, height=height, width=width
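The hunk above swaps `predict`'s manual slice-based batching for a standard `TensorDataset`/`DataLoader` pipeline. A minimal, self-contained sketch of the same pattern, with toy tensors and a stand-in model rather than any ART class:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

x_preprocessed = torch.rand(10, 3, 416, 416)  # toy NCHW image batch
dataloader = DataLoader(TensorDataset(x_preprocessed), batch_size=4, shuffle=False)

model = torch.nn.Identity()  # stand-in for the wrapped YOLO model
model.eval()

outputs = []
for (x_batch,) in dataloader:  # TensorDataset yields one-element tuples
    with torch.no_grad():
        outputs.append(model(x_batch))

# All 10 samples are covered; the final batch simply holds the remaining 2.
print([o.shape[0] for o in outputs])  # [4, 4, 2]
```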
@@ -533,6 +509,8 @@ def fit( # pylint: disable=W0221
        :param kwargs: Dictionary of framework-specific arguments. This parameter is not currently supported for PyTorch
                       and providing it takes no effect.
        """
+        import torch
+        from torch.utils.data import Dataset, DataLoader
 
         # Set model to train mode
         self._model.train()
@@ -541,41 +519,54 @@ def fit( # pylint: disable=W0221
             raise ValueError("An optimizer is needed to train the model, but none for provided.")
 
         # Apply preprocessing and convert to tensors
-        x_preprocessed, y_preprocessed_list = self._preprocess_and_convert_inputs(x=x, y=y, fit=True, no_grad=True)
-
-        # Cast to np.ndarray to use list indexing
-        y_preprocessed = np.asarray(y_preprocessed_list)
+        x_preprocessed, y_preprocessed = self._preprocess_and_convert_inputs(x=x, y=y, fit=True, no_grad=True)
+
+        class ObjectDetectorDataset(Dataset):
+            def __init__(self, x, y):
+                self.x = x
+                self.y = y
+
+            def __len__(self):
+                return len(self.x)
+
+            def __getitem__(self, idx):
+                return self.x[idx], self.y[idx]
+
+        # Create dataloader
+        dataset = ObjectDetectorDataset(x_preprocessed, y_preprocessed)
+        dataloader = DataLoader(
+            dataset=dataset,
+            batch_size=batch_size,
+            shuffle=True,
+            drop_last=drop_last,
+            collate_fn=lambda batch: list(zip(*batch)),
+        )
 
+        # Extract height and width
         if self.channels_first:
             height = self.input_shape[1]
             width = self.input_shape[2]
         else:
             height = self.input_shape[0]
             width = self.input_shape[1]
 
-        num_batch = len(x_preprocessed) / float(batch_size)
-        if drop_last:
-            num_batch = int(np.floor(num_batch))
-        else:
-            num_batch = int(np.ceil(num_batch))
-        ind = np.arange(len(x_preprocessed))
-
         # Start training
         for _ in range(nb_epochs):
-            # Shuffle the examples
-            np.random.shuffle(ind)
-
             # Train for one epoch
-            for m in range(num_batch):
-                i_batch = x_preprocessed[ind[m * batch_size : (m + 1) * batch_size]]
-                o_batch = y_preprocessed[ind[m * batch_size : (m + 1) * batch_size]]
+            for x_batch, y_batch in dataloader:
+                # Convert labels to YOLO
+                x_batch = torch.stack(x_batch)
+                y_batch = translate_labels_x1y1x2y2_to_xcycwh(labels_x1y1x2y2=y_batch, height=height, width=width)
+
+                # Move inputs to device
+                x_batch = x_batch.to(self.device)
+                y_batch = y_batch.to(self.device)
 
                 # Zero the parameter gradients
                 self._optimizer.zero_grad()
 
                 # Form the loss function
-                labels_t = translate_labels_x1y1x2y2_to_xcycwh(labels_x1y1x2y2=o_batch, height=height, width=width)
-                loss_components = self._model(i_batch, labels_t)
+                loss_components = self._model(x_batch, y_batch)
                 if isinstance(loss_components, dict):
                     loss = sum(loss_components.values())
                 else:
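One non-obvious choice in the new `fit` loop is `collate_fn=lambda batch: list(zip(*batch))`. Object-detection targets are per-image dicts whose box counts differ, so PyTorch's default collation, which tries to stack everything, would fail on them; zipping instead yields a tuple of images and a tuple of target dicts, and only the images are stacked afterwards with `torch.stack`. A toy illustration of that behaviour follows; every name in it is made up for the example and is not part of ART.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class ToyDetectionDataset(Dataset):
    """Images plus per-image target dicts with a variable number of boxes."""

    def __init__(self):
        self.x = [torch.rand(3, 416, 416) for _ in range(4)]
        self.y = [
            {"boxes": torch.rand(n, 4), "labels": torch.zeros(n, dtype=torch.int64)}
            for n in (1, 3, 2, 5)
        ]

    def __len__(self):
        return len(self.x)

    def __getitem__(self, idx):
        return self.x[idx], self.y[idx]


# zip(*batch) turns [(img, tgt), (img, tgt)] into ((img, img), (tgt, tgt)),
# avoiding default_collate's attempt to stack ragged target tensors.
loader = DataLoader(ToyDetectionDataset(), batch_size=2, collate_fn=lambda batch: list(zip(*batch)))

for x_batch, y_batch in loader:
    x_batch = torch.stack(x_batch)  # images share a shape, so stacking is safe
    print(x_batch.shape, [t["boxes"].shape[0] for t in y_batch])
```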