 import os
 import time
 from typing import Any, Dict, List, Optional, Tuple, Union, TYPE_CHECKING
+from tqdm.auto import tqdm

 import numpy as np
 import six
@@ -389,12 +390,14 @@ def fit(  # pylint: disable=W0221
                           the batch size. If ``False`` and the size of dataset is not divisible by the batch size, then
                           the last batch will be smaller. (default: ``False``)
         :param scheduler: Learning rate scheduler to run at the start of every epoch.
-        :param kwargs: Dictionary of framework-specific arguments. This parameter is not currently supported for PyTorch
-               and providing it takes no effect.
+        :param kwargs: Dictionary of framework-specific arguments. Currently supports "display_progress_bar" to
+               display training progress.
         """
         import torch
         from torch.utils.data import TensorDataset, DataLoader

+        display_progress_bar = kwargs.get("display_progress_bar", False)
+
         # Set model mode
         self._model.train(mode=training_mode)

@@ -416,8 +419,8 @@ def fit(  # pylint: disable=W0221
         dataloader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True, drop_last=drop_last)

         # Start training
-        for _ in range(nb_epochs):
-            for x_batch, y_batch in dataloader:
+        for _ in tqdm(range(nb_epochs), disable=not display_progress_bar, desc="Epochs"):
+            for x_batch, y_batch in tqdm(dataloader, disable=not display_progress_bar, desc="Batches"):
                 # Move inputs to device
                 x_batch = x_batch.to(self._device)
                 y_batch = y_batch.to(self._device)
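
For context, a minimal usage sketch of the new keyword argument in fit(). The import path and the toy model/data below are assumptions for illustration, not part of the patch:

    import numpy as np
    import torch

    from art.estimators.classification import PyTorchClassifier  # import path assumed

    # Toy model and classifier wrapper, purely illustrative
    model = torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(28 * 28, 10))
    classifier = PyTorchClassifier(
        model=model,
        loss=torch.nn.CrossEntropyLoss(),
        optimizer=torch.optim.Adam(model.parameters(), lr=1e-3),
        input_shape=(1, 28, 28),
        nb_classes=10,
    )

    # Random one-hot training data, only to exercise the fit() call
    x_train = np.random.rand(64, 1, 28, 28).astype(np.float32)
    y_train = np.eye(10, dtype=np.float32)[np.random.randint(0, 10, size=64)]

    # The kwarg added by this patch: tqdm epoch/batch bars appear only when it is True
    classifier.fit(x_train, y_train, batch_size=16, nb_epochs=2, display_progress_bar=True)
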
@@ -459,12 +462,14 @@ def fit_generator(self, generator: "DataGenerator", nb_epochs: int = 20, **kwarg

         :param generator: Batch generator providing `(x, y)` for each epoch.
         :param nb_epochs: Number of epochs to use for training.
-        :param kwargs: Dictionary of framework-specific arguments. This parameter is not currently supported for PyTorch
-               and providing it takes no effect.
+        :param kwargs: Dictionary of framework-specific arguments. Currently supports "display_progress_bar" to
+               display training progress.
         """
         import torch
         from art.data_generators import PyTorchDataGenerator

+        display_progress_bar = kwargs.get("display_progress_bar", False)
+
         # Put the model in the training mode
         self._model.train()

@@ -485,8 +490,8 @@ def fit_generator(self, generator: "DataGenerator", nb_epochs: int = 20, **kwarg
                 == (0, 1)
             )
         ):
-            for _ in range(nb_epochs):
-                for i_batch, o_batch in generator.iterator:
+            for _ in tqdm(range(nb_epochs), disable=not display_progress_bar, desc="Epochs"):
+                for i_batch, o_batch in tqdm(generator.iterator, disable=not display_progress_bar, desc="Batches"):
                     if isinstance(i_batch, np.ndarray):
                         i_batch = torch.from_numpy(i_batch).to(self._device)
                     else:
@@ -495,7 +500,10 @@ def fit_generator(self, generator: "DataGenerator", nb_epochs: int = 20, **kwarg
                     if isinstance(o_batch, np.ndarray):
                         o_batch = torch.argmax(torch.from_numpy(o_batch).to(self._device), dim=1)
                     else:
-                        o_batch = torch.argmax(o_batch.to(self._device), dim=1)
+                        if o_batch.dim() > 1:
+                            o_batch = torch.argmax(o_batch.to(self._device), dim=1)
+                        else:
+                            o_batch = o_batch.to(self._device)

                     # Zero the parameter gradients
                     self._optimizer.zero_grad()
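
Similarly, a sketch of the fit_generator() path that motivates the new o_batch.dim() check: a DataLoader yielding 1-D integer class labels, which the old unconditional torch.argmax(..., dim=1) could not handle. PyTorchDataGenerator is ART's own wrapper; the classifier construction and import path repeat the assumptions from the sketch above:

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    from art.data_generators import PyTorchDataGenerator
    from art.estimators.classification import PyTorchClassifier  # import path assumed

    model = torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(28 * 28, 10))
    classifier = PyTorchClassifier(
        model=model,
        loss=torch.nn.CrossEntropyLoss(),
        optimizer=torch.optim.Adam(model.parameters(), lr=1e-3),
        input_shape=(1, 28, 28),
        nb_classes=10,
    )

    # Labels here are 1-D class indices (shape (N,)), the case the patched else-branch
    # now passes through unchanged instead of calling argmax(..., dim=1) on them
    x = torch.rand(64, 1, 28, 28)
    y = torch.randint(0, 10, (64,))
    generator = PyTorchDataGenerator(
        iterator=DataLoader(TensorDataset(x, y), batch_size=16, shuffle=True),
        size=64,
        batch_size=16,
    )

    classifier.fit_generator(generator, nb_epochs=2, display_progress_bar=True)
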