@@ -559,7 +559,9 @@ def predict(  # pylint: disable=W0221
         return predictions
 
-    def fit(self, x: np.ndarray, y: np.ndarray, batch_size: int = 128, nb_epochs: int = 20, **kwargs) -> None:
+    def fit(
+        self, x: np.ndarray, y: np.ndarray, batch_size: int = 128, nb_epochs: int = 20, verbose: bool = False, **kwargs
+    ) -> None:
         """
         Fit the classifier on the training set `(x, y)`.
 
@@ -568,6 +570,7 @@ def fit(self, x: np.ndarray, y: np.ndarray, batch_size: int = 128, nb_epochs: in
                   shape (nb_samples,).
         :param batch_size: Size of batches.
         :param nb_epochs: Number of epochs to use for training.
+        :param verbose: Display the training progress bar.
         :param kwargs: Dictionary of framework-specific arguments. These should be parameters supported by the
                        `fit` function in Keras and will be passed to this function as such. Including the number of
                        epochs or the number of steps per epoch as part of this argument will result in an error.
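For illustration, a minimal usage sketch of the new signature. The toy model and data below are hypothetical, not part of this change, and depending on the ART/TensorFlow versions in use, eager execution may need to be disabled before constructing the classifier:

    import numpy as np
    from tensorflow.keras.models import Sequential
    from tensorflow.keras.layers import Dense
    from art.estimators.classification import KerasClassifier

    # A toy model purely for illustration.
    model = Sequential([Dense(10, activation="softmax", input_shape=(4,))])
    model.compile(loss="categorical_crossentropy", optimizer="adam")

    classifier = KerasClassifier(model=model)
    x = np.random.rand(32, 4).astype(np.float32)
    y = np.eye(10)[np.random.randint(0, 10, 32)]

    # verbose is now an explicit parameter instead of a kwargs entry;
    # True is cast to 1 before being handed to Keras' Model.fit.
    classifier.fit(x, y, batch_size=16, nb_epochs=2, verbose=True)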
@@ -582,18 +585,18 @@ def fit(self, x: np.ndarray, y: np.ndarray, batch_size: int = 128, nb_epochs: in
         if self._reduce_labels or y_ndim == 1:
             y_preprocessed = np.argmax(y_preprocessed, axis=1)
 
-        if "verbose" in kwargs:
-            kwargs["verbose"] = int(kwargs["verbose"])
-
-        self._model.fit(x=x_preprocessed, y=y_preprocessed, batch_size=batch_size, epochs=nb_epochs, **kwargs)
+        self._model.fit(
+            x=x_preprocessed, y=y_preprocessed, batch_size=batch_size, epochs=nb_epochs, verbose=int(verbose), **kwargs
+        )
 
-    def fit_generator(self, generator: "DataGenerator", nb_epochs: int = 20, **kwargs) -> None:
+    def fit_generator(self, generator: "DataGenerator", nb_epochs: int = 20, verbose: bool = False, **kwargs) -> None:
         """
         Fit the classifier using the generator that yields batches as specified.
 
         :param generator: Batch generator providing `(x, y)` for each epoch. If the generator can be used for native
                           training in Keras, it will be used directly.
         :param nb_epochs: Number of epochs to use for training.
+        :param verbose: Display the training progress bar.
         :param kwargs: Dictionary of framework-specific arguments. These should be parameters supported by the
                        `fit_generator` function in Keras and will be passed to this function as such. Including the number of
                        epochs as part of this argument will result in an error.
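The change above replaces the old pattern of mutating `kwargs` in place with an explicit named parameter. A minimal sketch of the equivalence (the variable names here are illustrative, not taken from the diff):

    # Old behaviour: callers smuggled verbose through **kwargs and it was
    # cast in place before reaching Keras.
    kwargs = {"verbose": True}
    if "verbose" in kwargs:
        kwargs["verbose"] = int(kwargs["verbose"])   # True -> 1

    # New behaviour: verbose is a named bool parameter, cast at the call site.
    verbose = True
    assert int(verbose) == kwargs["verbose"] == 1   # both paths hand Keras the integer 1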
@@ -603,9 +606,6 @@ def fit_generator(self, generator: "DataGenerator", nb_epochs: int = 20, **kwarg
         # Try to use the generator as a Keras native generator, otherwise use it through the `DataGenerator` interface
         from art.preprocessing.standardisation_mean_std.numpy import StandardisationMeanStd
 
-        if "verbose" in kwargs:
-            kwargs["verbose"] = int(kwargs["verbose"])
-
         if isinstance(generator, KerasDataGenerator) and (
             self.preprocessing is None
             or (
@@ -618,12 +618,12 @@ def fit_generator(self, generator: "DataGenerator", nb_epochs: int = 20, **kwarg
             )
         ):
             try:
-                self._model.fit_generator(generator.iterator, epochs=nb_epochs, **kwargs)
+                self._model.fit_generator(generator.iterator, epochs=nb_epochs, verbose=int(verbose), **kwargs)
             except ValueError:  # pragma: no cover
                 logger.info("Unable to use data generator as Keras generator. Now treating as framework-independent.")
-                super().fit_generator(generator, nb_epochs=nb_epochs, **kwargs)
+                super().fit_generator(generator, nb_epochs=nb_epochs, verbose=verbose, **kwargs)
         else:  # pragma: no cover
-            super().fit_generator(generator, nb_epochs=nb_epochs, **kwargs)
+            super().fit_generator(generator, nb_epochs=nb_epochs, verbose=verbose, **kwargs)
 
     def get_activations(
         self, x: np.ndarray, layer: Union[int, str], batch_size: int = 128, framework: bool = False
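A usage sketch of the generator path, reusing the hypothetical classifier from the earlier example; the generator construction below is illustrative, assuming ART's documented KerasDataGenerator constructor:

    import numpy as np
    from tensorflow.keras.preprocessing.image import ImageDataGenerator
    from art.data_generators import KerasDataGenerator

    # Wrap a native Keras generator; ART will try to feed it straight into
    # Model.fit_generator and fall back to the framework-independent path
    # only if that raises a ValueError.
    x = np.random.rand(64, 28, 28, 1).astype(np.float32)
    y = np.eye(10)[np.random.randint(0, 10, 64)]
    keras_gen = ImageDataGenerator().flow(x, y, batch_size=16)
    art_gen = KerasDataGenerator(keras_gen, size=64, batch_size=16)

    # Both the native and the fallback path now receive the verbose flag.
    classifier.fit_generator(art_gen, nb_epochs=2, verbose=True)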