Skip to content

Commit 3c6458d

Browse files
committed
split verbosity processing into separate method.
Signed-off-by: GiulioZizzo <[email protected]>
1 parent 2a3290a commit 3c6458d

File tree

2 files changed

+100
-58
lines changed

2 files changed

+100
-58
lines changed

art/estimators/classification/pytorch.py

Lines changed: 31 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -366,6 +366,35 @@ def _predict_framework(
366366

367367
return output, y_preprocessed
368368

369+
def process_verbose(self, verbose: Optional[Union[bool, int]] = None) -> bool:
370+
"""
371+
Function to unify the various ways implemented in ART of displaying progress bars
372+
into a single True/False output.
373+
374+
:param verbose: If to display the progress bar information.
375+
:return: True/False if to display the progress bars.
376+
"""
377+
378+
if verbose is not None:
379+
if isinstance(verbose, int):
380+
if verbose == 0:
381+
display_pb = False
382+
else:
383+
display_pb = True
384+
elif isinstance(verbose, bool):
385+
display_pb = verbose
386+
else:
387+
raise ValueError("Verbose should be True/False or a 0/1 int")
388+
else:
389+
# Check if the verbose attribute is present in the current classifier
390+
if hasattr(self, "verbose"):
391+
display_pb = self.verbose
392+
# else default to False
393+
else:
394+
display_pb = False
395+
396+
return display_pb
397+
369398
def fit( # pylint: disable=W0221
370399
self,
371400
x: np.ndarray,
@@ -398,15 +427,7 @@ def fit( # pylint: disable=W0221
398427
import torch
399428
from torch.utils.data import TensorDataset, DataLoader
400429

401-
if verbose is None:
402-
display_pb = False
403-
elif isinstance(verbose, int):
404-
if verbose == 0:
405-
display_pb = False
406-
else:
407-
display_pb = True
408-
else:
409-
display_pb = verbose
430+
display_pb = self.process_verbose(verbose)
410431

411432
# Set model mode
412433
self._model.train(mode=training_mode)
@@ -481,15 +502,7 @@ def fit_generator(
481502
import torch
482503
from art.data_generators import PyTorchDataGenerator
483504

484-
if verbose is None:
485-
display_pb = False
486-
elif isinstance(verbose, int):
487-
if verbose == 0:
488-
display_pb = False
489-
else:
490-
display_pb = True
491-
else:
492-
display_pb = verbose
505+
display_pb = self.process_verbose(verbose)
493506

494507
# Put the model in the training mode
495508
self._model.train()

art/estimators/classification/tensorflow.py

Lines changed: 69 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -266,7 +266,35 @@ def predict( # pylint: disable=W0221
266266

267267
return predictions
268268

269-
def fit(
269+
def process_verbose(self, verbose: Optional[Union[bool, int]] = None) -> bool:
270+
"""
271+
Function to unify the various ways implemented in ART of displaying progress bars
272+
into a single True/False output.
273+
:param verbose: If to display the progress bar information.
274+
:return: True/False if to display the progress bars.
275+
"""
276+
277+
if verbose is not None:
278+
if isinstance(verbose, int):
279+
if verbose == 0:
280+
display_pb = False
281+
else:
282+
display_pb = True
283+
elif isinstance(verbose, bool):
284+
display_pb = verbose
285+
else:
286+
raise ValueError("Verbose should be True/False or a 0/1 int")
287+
else:
288+
# Check if the verbose attribute is present in the current classifier
289+
if hasattr(self, "verbose"):
290+
display_pb = self.verbose
291+
# else default to False
292+
else:
293+
display_pb = False
294+
295+
return display_pb
296+
297+
def fit( # pylint: disable=W0221
270298
self,
271299
x: np.ndarray,
272300
y: np.ndarray,
@@ -290,15 +318,7 @@ def fit(
290318
if self.learning is not None:
291319
self.feed_dict[self.learning] = True
292320

293-
if verbose is None:
294-
display_pb = False
295-
elif isinstance(verbose, int):
296-
if verbose == 0:
297-
display_pb = False
298-
else:
299-
display_pb = True
300-
else:
301-
display_pb = verbose
321+
display_pb = self.process_verbose(verbose)
302322

303323
# Check if train and output_ph available
304324
if self.train is None or self.labels_ph is None: # pragma: no cover
@@ -333,7 +353,7 @@ def fit(
333353
# Run train step
334354
self._sess.run(self.train, feed_dict=feed_dict)
335355

336-
def fit_generator(
356+
def fit_generator( # pylint: disable=W0221
337357
self, generator: "DataGenerator", nb_epochs: int = 20, verbose: Optional[Union[bool, int]] = None, **kwargs
338358
) -> None:
339359
"""
@@ -348,15 +368,7 @@ def fit_generator(
348368
"""
349369
from art.data_generators import TensorFlowDataGenerator
350370

351-
if verbose is None:
352-
display_pb = False
353-
elif isinstance(verbose, int):
354-
if verbose == 0:
355-
display_pb = False
356-
else:
357-
display_pb = True
358-
else:
359-
display_pb = verbose
371+
display_pb = self.process_verbose(verbose)
360372

361373
if self.learning is not None:
362374
self.feed_dict[self.learning] = True
@@ -376,8 +388,13 @@ def fit_generator(
376388
)
377389
):
378390
for _ in tqdm(range(nb_epochs), disable=not display_pb, desc="Epochs"):
379-
num_bathces = int(generator.size / generator.batch_size)
380-
for _ in tqdm(range(num_bathces), disable=not display_pb, desc="Batches"): # type: ignore
391+
gen_size = generator.size
392+
if isinstance(gen_size, int):
393+
num_batchcs = int(gen_size / generator.batch_size)
394+
else:
395+
raise ValueError("Number of batches could not be determined from the generator")
396+
397+
for _ in tqdm(range(num_batchcs), disable=not display_pb, desc="Batches"):
381398
i_batch, o_batch = generator.get_batch()
382399

383400
if self._reduce_labels:
@@ -986,6 +1003,34 @@ def _predict_framework(self, x: "tf.Tensor", training_mode: bool = False) -> "tf
9861003

9871004
return self._model(x_preprocessed, training=training_mode)
9881005

1006+
def process_verbose(self, verbose: Optional[Union[bool, int]] = None) -> bool:
1007+
"""
1008+
Function to unify the various ways implemented in ART of displaying progress bars
1009+
into a single True/False output.
1010+
:param verbose: If to display the progress bar information.
1011+
:return: True/False if to display the progress bars.
1012+
"""
1013+
1014+
if verbose is not None:
1015+
if isinstance(verbose, int):
1016+
if verbose == 0:
1017+
display_pb = False
1018+
else:
1019+
display_pb = True
1020+
elif isinstance(verbose, bool):
1021+
display_pb = verbose
1022+
else:
1023+
raise ValueError("Verbose should be True/False or a 0/1 int")
1024+
else:
1025+
# Check if the verbose attribute is present in the current classifier
1026+
if hasattr(self, "verbose"):
1027+
display_pb = self.verbose
1028+
# else default to False
1029+
else:
1030+
display_pb = False
1031+
1032+
return display_pb
1033+
9891034
def fit(
9901035
self,
9911036
x: np.ndarray,
@@ -1010,15 +1055,7 @@ def fit(
10101055
"""
10111056
import tensorflow as tf
10121057

1013-
if verbose is None:
1014-
display_pb = False
1015-
elif isinstance(verbose, int):
1016-
if verbose == 0:
1017-
display_pb = False
1018-
else:
1019-
display_pb = True
1020-
else:
1021-
display_pb = verbose
1058+
display_pb = self.process_verbose(verbose)
10221059

10231060
if self._train_step is None: # pragma: no cover
10241061
if self._loss_object is None: # pragma: no cover
@@ -1080,15 +1117,7 @@ def fit_generator(
10801117
import tensorflow as tf
10811118
from art.data_generators import TensorFlowV2DataGenerator
10821119

1083-
if verbose is None:
1084-
display_pb = False
1085-
elif isinstance(verbose, int):
1086-
if verbose == 0:
1087-
display_pb = False
1088-
else:
1089-
display_pb = True
1090-
else:
1091-
display_pb = verbose
1120+
display_pb = self.process_verbose(verbose)
10921121

10931122
if self._train_step is None: # pragma: no cover
10941123
if self._loss_object is None: # pragma: no cover

0 commit comments

Comments
 (0)