@@ -559,7 +559,9 @@ def predict( # pylint: disable=W0221
559559
560560 return predictions
561561
562- def fit (self , x : np .ndarray , y : np .ndarray , batch_size : int = 128 , nb_epochs : int = 20 , ** kwargs ) -> None :
562+ def fit (
563+ self , x : np .ndarray , y : np .ndarray , batch_size : int = 128 , nb_epochs : int = 20 , verbose : bool = False , ** kwargs
564+ ) -> None :
563565 """
564566 Fit the classifier on the training set `(x, y)`.
565567
@@ -568,6 +570,7 @@ def fit(self, x: np.ndarray, y: np.ndarray, batch_size: int = 128, nb_epochs: in
568570 shape (nb_samples,).
569571 :param batch_size: Size of batches.
570572 :param nb_epochs: Number of epochs to use for training.
573+ :param verbose: Display training progress bar.
571574 :param kwargs: Dictionary of framework-specific arguments. These should be parameters supported by the
572575 `fit_generator` function in Keras and will be passed to this function as such. Including the number of
573576 epochs or the number of steps per epoch as part of this argument will result in an error.
@@ -582,18 +585,18 @@ def fit(self, x: np.ndarray, y: np.ndarray, batch_size: int = 128, nb_epochs: in
582585 if self ._reduce_labels or y_ndim == 1 :
583586 y_preprocessed = np .argmax (y_preprocessed , axis = 1 )
584587
585- if "verbose" in kwargs :
586- kwargs ["verbose" ] = int (kwargs ["verbose" ])
587-
588- self ._model .fit (x = x_preprocessed , y = y_preprocessed , batch_size = batch_size , epochs = nb_epochs , ** kwargs )
588+ self ._model .fit (
589+ x = x_preprocessed , y = y_preprocessed , batch_size = batch_size , epochs = nb_epochs , verbose = int (verbose ), ** kwargs
590+ )
589591
590- def fit_generator (self , generator : "DataGenerator" , nb_epochs : int = 20 , ** kwargs ) -> None :
592+ def fit_generator (self , generator : "DataGenerator" , nb_epochs : int = 20 , verbose : bool = False , ** kwargs ) -> None :
591593 """
592594 Fit the classifier using the generator that yields batches as specified.
593595
594596 :param generator: Batch generator providing `(x, y)` for each epoch. If the generator can be used for native
595597 training in Keras, it will.
596598 :param nb_epochs: Number of epochs to use for training.
599+ :param verbose: Display training progress bar.
597600 :param kwargs: Dictionary of framework-specific arguments. These should be parameters supported by the
598601 `fit_generator` function in Keras and will be passed to this function as such. Including the number of
599602 epochs as part of this argument will result in an error.
@@ -603,9 +606,6 @@ def fit_generator(self, generator: "DataGenerator", nb_epochs: int = 20, **kwarg
603606 # Try to use the generator as a Keras native generator, otherwise use it through the `DataGenerator` interface
604607 from art .preprocessing .standardisation_mean_std .numpy import StandardisationMeanStd
605608
606- if "verbose" in kwargs :
607- kwargs ["verbose" ] = int (kwargs ["verbose" ])
608-
609609 if isinstance (generator , KerasDataGenerator ) and (
610610 self .preprocessing is None
611611 or (
@@ -618,12 +618,12 @@ def fit_generator(self, generator: "DataGenerator", nb_epochs: int = 20, **kwarg
618618 )
619619 ):
620620 try :
621- self ._model .fit_generator (generator .iterator , epochs = nb_epochs , ** kwargs )
621+ self ._model .fit_generator (generator .iterator , epochs = nb_epochs , verbose = int ( verbose ), ** kwargs )
622622 except ValueError : # pragma: no cover
623623 logger .info ("Unable to use data generator as Keras generator. Now treating as framework-independent." )
624- super ().fit_generator (generator , nb_epochs = nb_epochs , ** kwargs )
624+ super ().fit_generator (generator , nb_epochs = nb_epochs , verbose = verbose , ** kwargs )
625625 else : # pragma: no cover
626- super ().fit_generator (generator , nb_epochs = nb_epochs , ** kwargs )
626+ super ().fit_generator (generator , nb_epochs = nb_epochs , verbose = verbose , ** kwargs )
627627
628628 def get_activations (
629629 self , x : np .ndarray , layer : Union [int , str ], batch_size : int = 128 , framework : bool = False
0 commit comments