
Commit 39ab9cd (1 parent: 6c1bc43)

change to verbose, and add support for tf1

Signed-off-by: GiulioZizzo <[email protected]>

File tree: 3 files changed (+39, -37 lines)


art/estimators/classification/pytorch.py

Lines changed: 16 additions & 14 deletions
@@ -375,6 +375,7 @@ def fit( # pylint: disable=W0221
         training_mode: bool = True,
         drop_last: bool = False,
         scheduler: Optional["torch.optim.lr_scheduler._LRScheduler"] = None,
+        verbose: bool = False,
         **kwargs,
     ) -> None:
         """
@@ -390,14 +391,13 @@ def fit( # pylint: disable=W0221
                           the batch size. If ``False`` and the size of dataset is not divisible by the batch size, then
                           the last batch will be smaller. (default: ``False``)
         :param scheduler: Learning rate scheduler to run at the start of every epoch.
-        :param kwargs: Dictionary of framework-specific arguments. Currently supports "display_progress_bar" to
-                       display training progress.
+        :param verbose: Whether to display the progress bar information.
+        :param kwargs: Dictionary of framework-specific arguments. This parameter is not currently supported for
+                       PyTorch and providing it has no effect.
         """
         import torch
         from torch.utils.data import TensorDataset, DataLoader

-        display_progress_bar = kwargs.get("display_progress_bar", False)
-
         # Set model mode
         self._model.train(mode=training_mode)

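With this change the progress bars are requested through an explicit keyword argument instead of a kwargs entry. A minimal usage sketch, assuming a constructed PyTorchClassifier `classifier` and NumPy arrays `x_train`/`y_train` (none of which are part of this commit):

    # New style: verbose is a named parameter of fit()
    classifier.fit(x_train, y_train, batch_size=128, nb_epochs=10, verbose=True)

    # Old style, removed by this commit:
    # classifier.fit(x_train, y_train, batch_size=128, nb_epochs=10, display_progress_bar=True)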
@@ -419,8 +419,8 @@ def fit( # pylint: disable=W0221
         dataloader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True, drop_last=drop_last)

         # Start training
-        for _ in tqdm(range(nb_epochs), disable=not display_progress_bar, desc="Epochs"):
-            for x_batch, y_batch in tqdm(dataloader, disable=not display_progress_bar, desc="Batches"):
+        for _ in tqdm(range(nb_epochs), disable=not verbose, desc="Epochs"):
+            for x_batch, y_batch in tqdm(dataloader, disable=not verbose, desc="Batches"):
                 # Move inputs to device
                 x_batch = x_batch.to(self._device)
                 y_batch = y_batch.to(self._device)
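Both loops lean on tqdm's `disable` flag: with `disable=True`, tqdm renders nothing and simply yields from the underlying iterable, so `verbose=False` silences training output without changing behaviour. A self-contained sketch of the pattern:

    from tqdm.auto import tqdm

    def run_epochs(nb_epochs: int, verbose: bool = False) -> None:
        # disable=not verbose: the bar is drawn only when verbose=True,
        # otherwise tqdm degrades to a plain pass-through iterator.
        for _ in tqdm(range(nb_epochs), disable=not verbose, desc="Epochs"):
            pass

    run_epochs(3, verbose=True)   # prints an "Epochs" progress bar
    run_epochs(3, verbose=False)  # prints nothing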
@@ -456,20 +456,19 @@ def fit( # pylint: disable=W0221
             if scheduler is not None:
                 scheduler.step()

-    def fit_generator(self, generator: "DataGenerator", nb_epochs: int = 20, **kwargs) -> None:
+    def fit_generator(self, generator: "DataGenerator", nb_epochs: int = 20, verbose: bool = False, **kwargs) -> None:
         """
         Fit the classifier using the generator that yields batches as specified.

         :param generator: Batch generator providing `(x, y)` for each epoch.
         :param nb_epochs: Number of epochs to use for training.
-        :param kwargs: Dictionary of framework-specific arguments. Currently supports "display_progress_bar" to
-                       display training progress.
+        :param verbose: Whether to display the progress bar information.
+        :param kwargs: Dictionary of framework-specific arguments. This parameter is not currently supported for
+                       PyTorch and providing it has no effect.
         """
         import torch
         from art.data_generators import PyTorchDataGenerator

-        display_progress_bar = kwargs.get("display_progress_bar", False)
-
         # Put the model in the training mode
         self._model.train()

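`fit_generator` gains the same flag. A sketch of driving it with ART's PyTorchDataGenerator; the classifier is assumed, and the constructor arguments shown (iterator, size, batch_size) should be checked against your ART version:

    import torch
    from torch.utils.data import DataLoader, TensorDataset
    from art.data_generators import PyTorchDataGenerator

    # Hypothetical data: 100 MNIST-shaped inputs with 1-D index labels.
    dataset = TensorDataset(torch.rand(100, 1, 28, 28), torch.randint(0, 10, (100,)))
    loader = DataLoader(dataset, batch_size=10, shuffle=True)

    generator = PyTorchDataGenerator(iterator=loader, size=100, batch_size=10)
    classifier.fit_generator(generator, nb_epochs=5, verbose=True)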
@@ -490,8 +489,8 @@ def fit_generator(self, generator: "DataGenerator", nb_epochs: int = 20, **kwarg
                 == (0, 1)
             )
         ):
-            for _ in tqdm(range(nb_epochs), disable=not display_progress_bar, desc="Epochs"):
-                for i_batch, o_batch in tqdm(generator.iterator, disable=not display_progress_bar, desc="Batches"):
+            for _ in tqdm(range(nb_epochs), disable=not verbose, desc="Epochs"):
+                for i_batch, o_batch in tqdm(generator.iterator, disable=not verbose, desc="Batches"):
                     if isinstance(i_batch, np.ndarray):
                         i_batch = torch.from_numpy(i_batch).to(self._device)
                     else:
@@ -500,7 +499,10 @@ def fit_generator(self, generator: "DataGenerator", nb_epochs: int = 20, **kwarg
                     if isinstance(o_batch, np.ndarray):
                         o_batch = torch.argmax(torch.from_numpy(o_batch).to(self._device), dim=1)
                     else:
-                        o_batch = torch.argmax(o_batch.to(self._device), dim=1)
+                        if o_batch.dim() > 1:
+                            o_batch = torch.argmax(o_batch.to(self._device), dim=1)
+                        else:
+                            o_batch = o_batch.to(self._device)

                     # Zero the parameter gradients
                     self._optimizer.zero_grad()
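The new `dim()` guard exists because a generator may yield either one-hot labels of shape `(batch, nb_classes)` or index labels of shape `(batch,)`, and calling `argmax(dim=1)` on a 1-D tensor raises an IndexError. A small sketch of the two cases:

    import torch

    one_hot = torch.tensor([[0.0, 1.0, 0.0], [1.0, 0.0, 0.0]])  # dim() == 2
    indices = torch.tensor([1, 0])                               # dim() == 1

    # 2-D one-hot labels are reduced to class indices...
    assert torch.equal(torch.argmax(one_hot, dim=1), indices)
    # ...while 1-D labels are already indices: argmax(dim=1) would raise
    # IndexError here, so the new branch passes them through unchanged.
    assert indices.dim() == 1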

art/estimators/classification/tensorflow.py

Lines changed: 22 additions & 18 deletions
@@ -266,7 +266,9 @@ def predict( # pylint: disable=W0221

         return predictions

-    def fit(self, x: np.ndarray, y: np.ndarray, batch_size: int = 128, nb_epochs: int = 10, **kwargs) -> None:
+    def fit(
+        self, x: np.ndarray, y: np.ndarray, batch_size: int = 128, nb_epochs: int = 10, verbose: bool = False, **kwargs
+    ) -> None:
         """
         Fit the classifier on the training set `(x, y)`.

@@ -275,6 +277,7 @@ def fit(self, x: np.ndarray, y: np.ndarray, batch_size: int = 128, nb_epochs: in
                   shape (nb_samples,).
         :param batch_size: Size of batches.
         :param nb_epochs: Number of epochs to use for training.
+        :param verbose: Whether to display the progress bar information.
         :param kwargs: Dictionary of framework-specific arguments. This parameter is not currently supported for
                        TensorFlow and providing it has no effect.
         """
@@ -298,12 +301,12 @@ def fit(self, x: np.ndarray, y: np.ndarray, batch_size: int = 128, nb_epochs: in
         ind = np.arange(len(x_preprocessed)).tolist()

         # Start training
-        for _ in range(nb_epochs):
+        for _ in tqdm(range(nb_epochs), disable=not verbose, desc="Epochs"):
             # Shuffle the examples
             random.shuffle(ind)

             # Train for one epoch
-            for m in range(num_batch):
+            for m in tqdm(range(num_batch), disable=not verbose, desc="Batches"):
                 i_batch = x_preprocessed[ind[m * batch_size : (m + 1) * batch_size]]
                 o_batch = y_preprocessed[ind[m * batch_size : (m + 1) * batch_size]]

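The TensorFlow v1 loop, now wrapped in tqdm for the new `verbose` flag, still draws batches by slicing a shuffled index list. A worked sketch of that indexing scheme with small numbers:

    import random

    batch_size = 4
    ind = list(range(10))  # stands in for np.arange(len(x_preprocessed)).tolist()
    random.shuffle(ind)

    num_batch = len(ind) // batch_size
    for m in range(num_batch):
        batch_ind = ind[m * batch_size : (m + 1) * batch_size]
        # batch_ind picks batch_size shuffled rows of x_preprocessed/y_preprocessed
        print(m, batch_ind)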
@@ -314,13 +317,14 @@ def fit(self, x: np.ndarray, y: np.ndarray, batch_size: int = 128, nb_epochs: in
             # Run train step
             self._sess.run(self.train, feed_dict=feed_dict)

-    def fit_generator(self, generator: "DataGenerator", nb_epochs: int = 20, **kwargs) -> None:
+    def fit_generator(self, generator: "DataGenerator", nb_epochs: int = 20, verbose: bool = False, **kwargs) -> None:
         """
         Fit the classifier using the generator that yields batches as specified.

         :param generator: Batch generator providing `(x, y)` for each epoch. If the generator can be used for native
                           training in TensorFlow, it will.
         :param nb_epochs: Number of epochs to use for training.
+        :param verbose: Whether to display the progress bar information.
         :param kwargs: Dictionary of framework-specific arguments. This parameter is not currently supported for
                        TensorFlow and providing it has no effect.
         """
@@ -343,8 +347,8 @@ def fit_generator(self, generator: "DataGenerator", nb_epochs: int = 20, **kwarg
                 == (0, 1)
             )
         ):
-            for _ in range(nb_epochs):
-                for _ in range(int(generator.size / generator.batch_size)):  # type: ignore
+            for _ in tqdm(range(nb_epochs), disable=not verbose, desc="Epochs"):
+                for _ in tqdm(range(int(generator.size / generator.batch_size)), disable=not verbose, desc="Batches"):  # type: ignore
                     i_batch, o_batch = generator.get_batch()

                     if self._reduce_labels:
@@ -953,7 +957,9 @@ def _predict_framework(self, x: "tf.Tensor", training_mode: bool = False) -> "tf

         return self._model(x_preprocessed, training=training_mode)

-    def fit(self, x: np.ndarray, y: np.ndarray, batch_size: int = 128, nb_epochs: int = 10, **kwargs) -> None:
+    def fit(
+        self, x: np.ndarray, y: np.ndarray, batch_size: int = 128, nb_epochs: int = 10, verbose: bool = False, **kwargs
+    ) -> None:
         """
         Fit the classifier on the training set `(x, y)`.

@@ -962,14 +968,13 @@ def fit(self, x: np.ndarray, y: np.ndarray, batch_size: int = 128, nb_epochs: in
                   shape (nb_samples,).
         :param batch_size: Size of batches.
         :param nb_epochs: Number of epochs to use for training.
+        :param verbose: Whether to display progress bar information.
         :param kwargs: Dictionary of framework-specific arguments. This parameter currently supports
                        "scheduler" which is an optional function that will be called at the end of every
-                       epoch to adjust the learning rate, and "display_progress_bar" to display training progress.
+                       epoch to adjust the learning rate.
         """
         import tensorflow as tf

-        display_progress_bar = kwargs.get("display_progress_bar", False)
-
         if self._train_step is None:  # pragma: no cover
             if self._loss_object is None:  # pragma: no cover
                 raise TypeError(
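In the TensorFlow v2 estimator, `verbose` becomes a first-class parameter while `scheduler` remains a kwargs entry. A hedged sketch of combining the two; the classifier and the training arrays are assumed, not part of this commit:

    # Hypothetical scheduler: fit() calls it at the end of every epoch with
    # the epoch index; a real one would adjust the optimizer's learning rate.
    def scheduler(epoch: int) -> None:
        print(f"finished epoch {epoch}")

    classifier.fit(x_train, y_train, batch_size=128, nb_epochs=5, verbose=True, scheduler=scheduler)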
@@ -1006,29 +1011,28 @@ def train_step(model, images, labels):

         train_ds = tf.data.Dataset.from_tensor_slices((x_preprocessed, y_preprocessed)).shuffle(10000).batch(batch_size)

-        for epoch in tqdm(range(nb_epochs), disable=not display_progress_bar, desc="Epochs"):
-            for images, labels in tqdm(train_ds, disable=not display_progress_bar, desc="Batches"):
+        for epoch in tqdm(range(nb_epochs), disable=not verbose, desc="Epochs"):
+            for images, labels in tqdm(train_ds, disable=not verbose, desc="Batches"):
                 train_step(self.model, images, labels)

             if scheduler is not None:
                 scheduler(epoch)

-    def fit_generator(self, generator: "DataGenerator", nb_epochs: int = 20, **kwargs) -> None:
+    def fit_generator(self, generator: "DataGenerator", nb_epochs: int = 20, verbose: bool = False, **kwargs) -> None:
         """
         Fit the classifier using the generator that yields batches as specified.

         :param generator: Batch generator providing `(x, y)` for each epoch. If the generator can be used for native
                           training in TensorFlow, it will.
         :param nb_epochs: Number of epochs to use for training.
+        :param verbose: Whether to display progress bar information.
         :param kwargs: Dictionary of framework-specific arguments. This parameter currently supports
                        "scheduler" which is an optional function that will be called at the end of every
-                       epoch to adjust the learning rate, and "display_progress_bar" to display training progress.
+                       epoch to adjust the learning rate.
         """
         import tensorflow as tf
         from art.data_generators import TensorFlowV2DataGenerator

-        display_progress_bar = kwargs.get("display_progress_bar", False)
-
         if self._train_step is None:  # pragma: no cover
             if self._loss_object is None:  # pragma: no cover
                 raise TypeError(
@@ -1068,8 +1072,8 @@ def train_step(model, images, labels):
                 == (0, 1)
             )
         ):
-            for epoch in tqdm(range(nb_epochs), disable=not display_progress_bar, desc="Epochs"):
-                for i_batch, o_batch in tqdm(generator.iterator, disable=not display_progress_bar, desc="Batches"):
+            for epoch in tqdm(range(nb_epochs), disable=not verbose, desc="Epochs"):
+                for i_batch, o_batch in tqdm(generator.iterator, disable=not verbose, desc="Batches"):
                     if self._reduce_labels:
                         o_batch = tf.math.argmax(o_batch, axis=1)
                     train_step(self._model, i_batch, o_batch)
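For completeness, a sketch of calling the v2 `fit_generator` with ART's TensorFlowV2DataGenerator; the classifier is assumed, and the (iterator, size, batch_size) constructor shown should be checked against your ART version:

    import tensorflow as tf
    from art.data_generators import TensorFlowV2DataGenerator

    # Hypothetical dataset: 100 MNIST-shaped samples with one-hot labels.
    labels = tf.one_hot(tf.random.uniform((100,), 0, 10, tf.int32), 10)
    ds = tf.data.Dataset.from_tensor_slices((tf.random.uniform((100, 28, 28, 1)), labels)).batch(10)

    generator = TensorFlowV2DataGenerator(iterator=ds, size=100, batch_size=10)
    classifier.fit_generator(generator, nb_epochs=5, verbose=True)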

tests/estimators/classification/test_deeplearning_common.py

Lines changed: 1 addition & 5 deletions
@@ -202,11 +202,7 @@ def get_lr(_):
        # Test a valid callback
        classifier, _ = image_dl_estimator(from_logits=True)

-       # Keras fit has its own kwarg arguments
-       if framework in ["kerastf", "keras"]:
-           kwargs = {"callbacks": [LearningRateScheduler(get_lr)]}
-       else:
-           kwargs = {"callbacks": [LearningRateScheduler(get_lr)], "display_progress_bar": True}
+       kwargs = {"callbacks": [LearningRateScheduler(get_lr)], "verbose": True}
        classifier.fit(x_train_mnist, y_train_mnist, batch_size=default_batch_size, nb_epochs=1, **kwargs)

        # Test failure for invalid parameters: does not apply to many frameworks which allow arbitrary kwargs
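For reference, the `get_lr` named in the hunk header is a Keras-style schedule function; a hedged sketch of the shape such a callback takes (the real test body is outside this hunk):

    from tensorflow.keras.callbacks import LearningRateScheduler

    # Hypothetical schedule: called once per epoch, returns the learning rate.
    def get_lr(_):
        return 0.01

    kwargs = {"callbacks": [LearningRateScheduler(get_lr)], "verbose": True}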
