@@ -266,7 +266,9 @@ def predict( # pylint: disable=W0221
266266
267267 return predictions
268268
269- def fit (self , x : np .ndarray , y : np .ndarray , batch_size : int = 128 , nb_epochs : int = 10 , ** kwargs ) -> None :
269+ def fit (
270+ self , x : np .ndarray , y : np .ndarray , batch_size : int = 128 , nb_epochs : int = 10 , verbose : bool = False , ** kwargs
271+ ) -> None :
270272 """
271273 Fit the classifier on the training set `(x, y)`.
272274
@@ -275,6 +277,7 @@ def fit(self, x: np.ndarray, y: np.ndarray, batch_size: int = 128, nb_epochs: in
275277 shape (nb_samples,).
276278 :param batch_size: Size of batches.
277279 :param nb_epochs: Number of epochs to use for training.
280+ :param verbose: If to display the progress bar information.
278281 :param kwargs: Dictionary of framework-specific arguments. This parameter is not currently supported for
279282 TensorFlow and providing it takes no effect.
280283 """
@@ -298,12 +301,12 @@ def fit(self, x: np.ndarray, y: np.ndarray, batch_size: int = 128, nb_epochs: in
298301 ind = np .arange (len (x_preprocessed )).tolist ()
299302
300303 # Start training
301- for _ in range (nb_epochs ):
304+ for _ in tqdm ( range (nb_epochs ), disable = not verbose , desc = "Epochs" ):
302305 # Shuffle the examples
303306 random .shuffle (ind )
304307
305308 # Train for one epoch
306- for m in range (num_batch ):
309+ for m in tqdm ( range (num_batch ), disable = not verbose , desc = "Batches" ):
307310 i_batch = x_preprocessed [ind [m * batch_size : (m + 1 ) * batch_size ]]
308311 o_batch = y_preprocessed [ind [m * batch_size : (m + 1 ) * batch_size ]]
309312
@@ -314,13 +317,14 @@ def fit(self, x: np.ndarray, y: np.ndarray, batch_size: int = 128, nb_epochs: in
314317 # Run train step
315318 self ._sess .run (self .train , feed_dict = feed_dict )
316319
317- def fit_generator (self , generator : "DataGenerator" , nb_epochs : int = 20 , ** kwargs ) -> None :
320+ def fit_generator (self , generator : "DataGenerator" , nb_epochs : int = 20 , verbose : bool = False , ** kwargs ) -> None :
318321 """
319322 Fit the classifier using the generator that yields batches as specified.
320323
321324 :param generator: Batch generator providing `(x, y)` for each epoch. If the generator can be used for native
322325 training in TensorFlow, it will.
323326 :param nb_epochs: Number of epochs to use for training.
327+ :param verbose: If to display the progress bar information.
324328 :param kwargs: Dictionary of framework-specific arguments. This parameter is not currently supported for
325329 TensorFlow and providing it takes no effect.
326330 """
@@ -343,8 +347,8 @@ def fit_generator(self, generator: "DataGenerator", nb_epochs: int = 20, **kwarg
343347 == (0 , 1 )
344348 )
345349 ):
346- for _ in range (nb_epochs ):
347- for _ in range (int (generator .size / generator .batch_size )): # type: ignore
350+ for _ in tqdm ( range (nb_epochs ), disable = not verbose , desc = "Epochs" ):
351+ for _ in tqdm ( range (int (generator .size / generator .batch_size )), disable = not verbose , desc = "Batches" ): # type: ignore
348352 i_batch , o_batch = generator .get_batch ()
349353
350354 if self ._reduce_labels :
@@ -953,7 +957,9 @@ def _predict_framework(self, x: "tf.Tensor", training_mode: bool = False) -> "tf
953957
954958 return self ._model (x_preprocessed , training = training_mode )
955959
956- def fit (self , x : np .ndarray , y : np .ndarray , batch_size : int = 128 , nb_epochs : int = 10 , ** kwargs ) -> None :
960+ def fit (
961+ self , x : np .ndarray , y : np .ndarray , batch_size : int = 128 , nb_epochs : int = 10 , verbose : bool = False , ** kwargs
962+ ) -> None :
957963 """
958964 Fit the classifier on the training set `(x, y)`.
959965
@@ -962,14 +968,13 @@ def fit(self, x: np.ndarray, y: np.ndarray, batch_size: int = 128, nb_epochs: in
962968 shape (nb_samples,).
963969 :param batch_size: Size of batches.
964970 :param nb_epochs: Number of epochs to use for training.
971+ :param verbose: If to display progress bar information.
965972 :param kwargs: Dictionary of framework-specific arguments. This parameter currently supports
966973 "scheduler" which is an optional function that will be called at the end of every
967- epoch to adjust the learning rate, and "display_progress_bar" to display training progress .
974+ epoch to adjust the learning rate.
968975 """
969976 import tensorflow as tf
970977
971- display_progress_bar = kwargs .get ("display_progress_bar" , False )
972-
973978 if self ._train_step is None : # pragma: no cover
974979 if self ._loss_object is None : # pragma: no cover
975980 raise TypeError (
@@ -1006,29 +1011,28 @@ def train_step(model, images, labels):
10061011
10071012 train_ds = tf .data .Dataset .from_tensor_slices ((x_preprocessed , y_preprocessed )).shuffle (10000 ).batch (batch_size )
10081013
1009- for epoch in tqdm (range (nb_epochs ), disable = not display_progress_bar , desc = "Epochs" ):
1010- for images , labels in tqdm (train_ds , disable = not display_progress_bar , desc = "Batches" ):
1014+ for epoch in tqdm (range (nb_epochs ), disable = not verbose , desc = "Epochs" ):
1015+ for images , labels in tqdm (train_ds , disable = not verbose , desc = "Batches" ):
10111016 train_step (self .model , images , labels )
10121017
10131018 if scheduler is not None :
10141019 scheduler (epoch )
10151020
1016- def fit_generator (self , generator : "DataGenerator" , nb_epochs : int = 20 , ** kwargs ) -> None :
1021+ def fit_generator (self , generator : "DataGenerator" , nb_epochs : int = 20 , verbose : bool = False , ** kwargs ) -> None :
10171022 """
10181023 Fit the classifier using the generator that yields batches as specified.
10191024
10201025 :param generator: Batch generator providing `(x, y)` for each epoch. If the generator can be used for native
10211026 training in TensorFlow, it will.
10221027 :param nb_epochs: Number of epochs to use for training.
1028+ :param verbose: If to display progress bar information
10231029 :param kwargs: Dictionary of framework-specific arguments. This parameter currently supports
10241030 "scheduler" which is an optional function that will be called at the end of every
1025- epoch to adjust the learning rate, and "display_progress_bar" to display training progress .
1031+ epoch to adjust the learning rate.
10261032 """
10271033 import tensorflow as tf
10281034 from art .data_generators import TensorFlowV2DataGenerator
10291035
1030- display_progress_bar = kwargs .get ("display_progress_bar" , False )
1031-
10321036 if self ._train_step is None : # pragma: no cover
10331037 if self ._loss_object is None : # pragma: no cover
10341038 raise TypeError (
@@ -1068,8 +1072,8 @@ def train_step(model, images, labels):
10681072 == (0 , 1 )
10691073 )
10701074 ):
1071- for epoch in tqdm (range (nb_epochs ), disable = not display_progress_bar , desc = "Epochs" ):
1072- for i_batch , o_batch in tqdm (generator .iterator , disable = not display_progress_bar , desc = "Batches" ):
1075+ for epoch in tqdm (range (nb_epochs ), disable = not verbose , desc = "Epochs" ):
1076+ for i_batch , o_batch in tqdm (generator .iterator , disable = not verbose , desc = "Batches" ):
10731077 if self ._reduce_labels :
10741078 o_batch = tf .math .argmax (o_batch , axis = 1 )
10751079 train_step (self ._model , i_batch , o_batch )
0 commit comments