 import copy
 import logging
 import os
-import random
 import time
 from typing import Any, Dict, List, Optional, Tuple, Union, TYPE_CHECKING
 
@@ -309,26 +308,27 @@ def predict(  # pylint: disable=W0221
         :return: Array of predictions of shape `(nb_inputs, nb_classes)`.
         """
         import torch
+        from torch.utils.data import TensorDataset, DataLoader
 
         # Set model mode
         self._model.train(mode=training_mode)
 
         # Apply preprocessing
         x_preprocessed, _ = self._apply_preprocessing(x, y=None, fit=False)
 
-        results_list = []
+        # Create dataloader
+        x_tensor = torch.from_numpy(x_preprocessed)
+        dataset = TensorDataset(x_tensor)
+        dataloader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=False)
 
-        # Run prediction with batch processing
-        num_batch = int(np.ceil(len(x_preprocessed) / float(batch_size)))
-        for m in range(num_batch):
-            # Batch indexes
-            begin, end = (
-                m * batch_size,
-                min((m + 1) * batch_size, x_preprocessed.shape[0]),
-            )
+        results_list = []
+        for (x_batch,) in dataloader:
+            # Move inputs to device
+            x_batch = x_batch.to(self._device)
 
+            # Run prediction
             with torch.no_grad():
-                model_outputs = self._model(torch.from_numpy(x_preprocessed[begin:end]).to(self._device))
+                model_outputs = self._model(x_batch)
             output = model_outputs[-1]
             output = output.detach().cpu().numpy().astype(np.float32)
             if len(output.shape) == 1:
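The batching that `predict` now delegates to `torch.utils.data` can be illustrated in isolation. A minimal sketch, with a toy `nn.Linear` standing in for `self._model`; all names and shapes here are illustrative and not part of the PR:

```python
import numpy as np
import torch
from torch import nn
from torch.utils.data import TensorDataset, DataLoader

# Toy stand-in for the wrapped classifier model (hypothetical)
model = nn.Linear(4, 3)
model.eval()

x_preprocessed = np.random.rand(10, 4).astype(np.float32)

# Wrap the NumPy inputs in a TensorDataset and iterate in fixed-size batches
dataset = TensorDataset(torch.from_numpy(x_preprocessed))
dataloader = DataLoader(dataset=dataset, batch_size=4, shuffle=False)

results_list = []
for (x_batch,) in dataloader:
    with torch.no_grad():
        output = model(x_batch)
    results_list.append(output.numpy())

predictions = np.vstack(results_list)
print(predictions.shape)  # (10, 3)
```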
@@ -373,7 +373,7 @@ def fit(  # pylint: disable=W0221
         nb_epochs: int = 10,
         training_mode: bool = True,
         drop_last: bool = False,
-        scheduler: Optional[Any] = None,
+        scheduler: Optional["torch.optim.lr_scheduler._LRScheduler"] = None,
         **kwargs,
     ) -> None:
         """
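As a usage note, the tightened annotation is satisfied by the standard `torch.optim.lr_scheduler` classes. A minimal sketch, where the model, optimizer, and the commented-out `classifier.fit` call are placeholders mirroring the signature in this diff:

```python
import numpy as np
import torch
from torch import nn

# Placeholder model and optimizer; in practice these come from the wrapped classifier
model = nn.Linear(4, 3)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# StepLR derives from the base PyTorch LR scheduler class, so it matches the new type hint
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)

x_train = np.random.rand(32, 4).astype(np.float32)
y_train = np.random.randint(0, 3, size=32)
# classifier.fit(x_train, y_train, batch_size=8, nb_epochs=10, scheduler=scheduler)
```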
@@ -393,6 +393,7 @@ def fit(  # pylint: disable=W0221
                and providing it takes no effect.
         """
         import torch
+        from torch.utils.data import TensorDataset, DataLoader
 
         # Set model mode
         self._model.train(mode=training_mode)
@@ -408,32 +409,25 @@ def fit(  # pylint: disable=W0221
         # Check label shape
         y_preprocessed = self.reduce_labels(y_preprocessed)
 
-        num_batch = len(x_preprocessed) / float(batch_size)
-        if drop_last:
-            num_batch = int(np.floor(num_batch))
-        else:
-            num_batch = int(np.ceil(num_batch))
-        ind = np.arange(len(x_preprocessed))
-
-        x_preprocessed = torch.from_numpy(x_preprocessed).to(self._device)
-        y_preprocessed = torch.from_numpy(y_preprocessed).to(self._device)
+        # Create dataloader
+        x_tensor = torch.from_numpy(x_preprocessed)
+        y_tensor = torch.from_numpy(y_preprocessed)
+        dataset = TensorDataset(x_tensor, y_tensor)
+        dataloader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True, drop_last=drop_last)
 
         # Start training
         for _ in range(nb_epochs):
-            # Shuffle the examples
-            random.shuffle(ind)
-
-            # Train for one epoch
-            for m in range(num_batch):
-                i_batch = x_preprocessed[ind[m * batch_size : (m + 1) * batch_size]]
-                o_batch = y_preprocessed[ind[m * batch_size : (m + 1) * batch_size]]
+            for x_batch, y_batch in dataloader:
+                # Move inputs to device
+                x_batch = x_batch.to(self._device)
+                y_batch = y_batch.to(self._device)
 
                 # Zero the parameter gradients
                 self._optimizer.zero_grad()
 
                 # Perform prediction
                 try:
-                    model_outputs = self._model(i_batch)
+                    model_outputs = self._model(x_batch)
                 except ValueError as err:
                     if "Expected more than 1 value per channel when training" in str(err):
                         logger.exception(
@@ -443,15 +437,14 @@ def fit(  # pylint: disable=W0221
                     raise err
 
                 # Form the loss function
-                loss = self._loss(model_outputs[-1], o_batch)
+                loss = self._loss(model_outputs[-1], y_batch)
 
                 # Do training
                 if self._use_amp:  # pragma: no cover
                     from apex import amp  # pylint: disable=E0611
 
                     with amp.scale_loss(loss, self._optimizer) as scaled_loss:
                         scaled_loss.backward()
-
                 else:
                     loss.backward()
 
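For reference, the `drop_last` handling that was previously done with `np.floor`/`np.ceil` is now delegated to the `DataLoader`. A small self-contained sketch of that behaviour on toy data (not part of the PR):

```python
import numpy as np
import torch
from torch.utils.data import TensorDataset, DataLoader

# 10 samples with batch_size=4: drop_last=False yields 3 batches, drop_last=True yields 2
x = torch.from_numpy(np.arange(10, dtype=np.float32).reshape(10, 1))
y = torch.from_numpy(np.arange(10, dtype=np.int64))
dataset = TensorDataset(x, y)

for drop_last in (False, True):
    dataloader = DataLoader(dataset=dataset, batch_size=4, shuffle=True, drop_last=drop_last)
    print(drop_last, len(dataloader))  # False -> 3, True -> 2
```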