@@ -266,7 +266,35 @@ def predict( # pylint: disable=W0221
266266
267267 return predictions
268268
269- def fit (
269+ def process_verbose (self , verbose : Optional [Union [bool , int ]] = None ) -> bool :
270+ """
271+ Function to unify the various ways implemented in ART of displaying progress bars
272+ into a single True/False output.
273+ :param verbose: If to display the progress bar information.
274+ :return: True/False if to display the progress bars.
275+ """
276+
277+ if verbose is not None :
278+ if isinstance (verbose , int ):
279+ if verbose == 0 :
280+ display_pb = False
281+ else :
282+ display_pb = True
283+ elif isinstance (verbose , bool ):
284+ display_pb = verbose
285+ else :
286+ raise ValueError ("Verbose should be True/False or a 0/1 int" )
287+ else :
288+ # Check if the verbose attribute is present in the current classifier
289+ if hasattr (self , "verbose" ):
290+ display_pb = self .verbose
291+ # else default to False
292+ else :
293+ display_pb = False
294+
295+ return display_pb
296+
297+ def fit ( # pylint: disable=W0221
270298 self ,
271299 x : np .ndarray ,
272300 y : np .ndarray ,
@@ -290,15 +318,7 @@ def fit(
290318 if self .learning is not None :
291319 self .feed_dict [self .learning ] = True
292320
293- if verbose is None :
294- display_pb = False
295- elif isinstance (verbose , int ):
296- if verbose == 0 :
297- display_pb = False
298- else :
299- display_pb = True
300- else :
301- display_pb = verbose
321+ display_pb = self .process_verbose (verbose )
302322
303323 # Check if train and output_ph available
304324 if self .train is None or self .labels_ph is None : # pragma: no cover
@@ -333,7 +353,7 @@ def fit(
333353 # Run train step
334354 self ._sess .run (self .train , feed_dict = feed_dict )
335355
336- def fit_generator (
356+ def fit_generator ( # pylint: disable=W0221
337357 self , generator : "DataGenerator" , nb_epochs : int = 20 , verbose : Optional [Union [bool , int ]] = None , ** kwargs
338358 ) -> None :
339359 """
@@ -348,15 +368,7 @@ def fit_generator(
348368 """
349369 from art .data_generators import TensorFlowDataGenerator
350370
351- if verbose is None :
352- display_pb = False
353- elif isinstance (verbose , int ):
354- if verbose == 0 :
355- display_pb = False
356- else :
357- display_pb = True
358- else :
359- display_pb = verbose
371+ display_pb = self .process_verbose (verbose )
360372
361373 if self .learning is not None :
362374 self .feed_dict [self .learning ] = True
@@ -376,8 +388,13 @@ def fit_generator(
376388 )
377389 ):
378390 for _ in tqdm (range (nb_epochs ), disable = not display_pb , desc = "Epochs" ):
379- num_bathces = int (generator .size / generator .batch_size )
380- for _ in tqdm (range (num_bathces ), disable = not display_pb , desc = "Batches" ): # type: ignore
391+ gen_size = generator .size
392+ if isinstance (gen_size , int ):
393+ num_batchcs = int (gen_size / generator .batch_size )
394+ else :
395+ raise ValueError ("Number of batches could not be determined from the generator" )
396+
397+ for _ in tqdm (range (num_batchcs ), disable = not display_pb , desc = "Batches" ):
381398 i_batch , o_batch = generator .get_batch ()
382399
383400 if self ._reduce_labels :
@@ -986,6 +1003,34 @@ def _predict_framework(self, x: "tf.Tensor", training_mode: bool = False) -> "tf
9861003
9871004 return self ._model (x_preprocessed , training = training_mode )
9881005
1006+ def process_verbose (self , verbose : Optional [Union [bool , int ]] = None ) -> bool :
1007+ """
1008+ Function to unify the various ways implemented in ART of displaying progress bars
1009+ into a single True/False output.
1010+ :param verbose: If to display the progress bar information.
1011+ :return: True/False if to display the progress bars.
1012+ """
1013+
1014+ if verbose is not None :
1015+ if isinstance (verbose , int ):
1016+ if verbose == 0 :
1017+ display_pb = False
1018+ else :
1019+ display_pb = True
1020+ elif isinstance (verbose , bool ):
1021+ display_pb = verbose
1022+ else :
1023+ raise ValueError ("Verbose should be True/False or a 0/1 int" )
1024+ else :
1025+ # Check if the verbose attribute is present in the current classifier
1026+ if hasattr (self , "verbose" ):
1027+ display_pb = self .verbose
1028+ # else default to False
1029+ else :
1030+ display_pb = False
1031+
1032+ return display_pb
1033+
9891034 def fit (
9901035 self ,
9911036 x : np .ndarray ,
@@ -1010,15 +1055,7 @@ def fit(
10101055 """
10111056 import tensorflow as tf
10121057
1013- if verbose is None :
1014- display_pb = False
1015- elif isinstance (verbose , int ):
1016- if verbose == 0 :
1017- display_pb = False
1018- else :
1019- display_pb = True
1020- else :
1021- display_pb = verbose
1058+ display_pb = self .process_verbose (verbose )
10221059
10231060 if self ._train_step is None : # pragma: no cover
10241061 if self ._loss_object is None : # pragma: no cover
@@ -1080,15 +1117,7 @@ def fit_generator(
10801117 import tensorflow as tf
10811118 from art .data_generators import TensorFlowV2DataGenerator
10821119
1083- if verbose is None :
1084- display_pb = False
1085- elif isinstance (verbose , int ):
1086- if verbose == 0 :
1087- display_pb = False
1088- else :
1089- display_pb = True
1090- else :
1091- display_pb = verbose
1120+ display_pb = self .process_verbose (verbose )
10921121
10931122 if self ._train_step is None : # pragma: no cover
10941123 if self ._loss_object is None : # pragma: no cover
0 commit comments