Commit d8bab78

progress bar development
Signed-off-by: GiulioZizzo <[email protected]>
1 parent 49acd32 commit d8bab78

File tree: 4 files changed, +32 −17 lines

.github/workflows/ci-pytorch.yml

Lines changed: 1 addition & 0 deletions
@@ -16,6 +16,7 @@ on:
     branches:
       - main
       - dev*
+      - hf_notebook_updates

   # Run scheduled CI flow daily
   schedule:

.github/workflows/ci-style-checks.yml

Lines changed: 1 addition & 0 deletions
@@ -16,6 +16,7 @@ on:
     branches:
       - main
       - dev*
+      - hf_notebook_updates

   # Run scheduled CI flow daily
   schedule:

art/estimators/classification/pytorch.py

Lines changed: 17 additions & 9 deletions
@@ -26,6 +26,7 @@
 import os
 import time
 from typing import Any, Dict, List, Optional, Tuple, Union, TYPE_CHECKING
+from tqdm.auto import tqdm

 import numpy as np
 import six
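
Note: `tqdm.auto` resolves to a widget-based bar under Jupyter and a plain console bar otherwise, which suits the notebook-focused branch this commit targets. A minimal standalone sketch of the `disable` pattern the patch relies on (not from this diff):

from tqdm.auto import tqdm

# disable=True turns the wrapper into a pass-through iterator with no output,
# which is why display_progress_bar defaults to False in the changes below.
for _ in tqdm(range(3), disable=True, desc="Epochs"):
    pass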
@@ -389,12 +390,14 @@ def fit(  # pylint: disable=W0221
             the batch size. If ``False`` and the size of dataset is not divisible by the batch size, then
             the last batch will be smaller. (default: ``False``)
         :param scheduler: Learning rate scheduler to run at the start of every epoch.
-        :param kwargs: Dictionary of framework-specific arguments. This parameter is not currently supported for PyTorch
-                       and providing it takes no effect.
+        :param kwargs: Dictionary of framework-specific arguments. Currently supports "display_progress_bar" to
+                       display training progress.
         """
         import torch
         from torch.utils.data import TensorDataset, DataLoader

+        display_progress_bar = kwargs.get("display_progress_bar", False)
+
         # Set model mode
         self._model.train(mode=training_mode)

@@ -416,8 +419,8 @@ def fit(  # pylint: disable=W0221
         dataloader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True, drop_last=drop_last)

         # Start training
-        for _ in range(nb_epochs):
-            for x_batch, y_batch in dataloader:
+        for _ in tqdm(range(nb_epochs), disable=not display_progress_bar, desc="Epochs"):
+            for x_batch, y_batch in tqdm(dataloader, disable=not display_progress_bar, desc="Batches"):
                 # Move inputs to device
                 x_batch = x_batch.to(self._device)
                 y_batch = y_batch.to(self._device)
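
A minimal usage sketch of the new kwarg; the toy model, random data, and hyperparameters here are illustrative assumptions, not part of this commit:

import numpy as np
import torch
from art.estimators.classification import PyTorchClassifier

model = torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(28 * 28, 10))
classifier = PyTorchClassifier(
    model=model,
    loss=torch.nn.CrossEntropyLoss(),
    optimizer=torch.optim.Adam(model.parameters(), lr=1e-3),
    input_shape=(1, 28, 28),
    nb_classes=10,
)
x_train = np.random.rand(128, 1, 28, 28).astype(np.float32)
y_train = np.eye(10)[np.random.randint(0, 10, 128)].astype(np.float32)

# display_progress_bar is read from kwargs and defaults to False,
# so existing callers keep the old silent behaviour.
classifier.fit(x_train, y_train, batch_size=32, nb_epochs=2, display_progress_bar=True)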
@@ -459,12 +462,14 @@ def fit_generator(self, generator: "DataGenerator", nb_epochs: int = 20, **kwarg

         :param generator: Batch generator providing `(x, y)` for each epoch.
         :param nb_epochs: Number of epochs to use for training.
-        :param kwargs: Dictionary of framework-specific arguments. This parameter is not currently supported for PyTorch
-                       and providing it takes no effect.
+        :param kwargs: Dictionary of framework-specific arguments. Currently supports "display_progress_bar" to
+                       display training progress.
         """
         import torch
         from art.data_generators import PyTorchDataGenerator

+        display_progress_bar = kwargs.get("display_progress_bar", False)
+
         # Put the model in the training mode
         self._model.train()

@@ -485,8 +490,8 @@ def fit_generator(self, generator: "DataGenerator", nb_epochs: int = 20, **kwarg
                 == (0, 1)
             )
         ):
-            for _ in range(nb_epochs):
-                for i_batch, o_batch in generator.iterator:
+            for _ in tqdm(range(nb_epochs), disable=not display_progress_bar, desc="Epochs"):
+                for i_batch, o_batch in tqdm(generator.iterator, disable=not display_progress_bar, desc="Batches"):
                     if isinstance(i_batch, np.ndarray):
                         i_batch = torch.from_numpy(i_batch).to(self._device)
                     else:
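
The same kwarg now reaches `fit_generator`; a sketch reusing the `classifier` from the `fit` example above and assuming `PyTorchDataGenerator`'s documented `(iterator, size, batch_size)` interface:

from torch.utils.data import DataLoader, TensorDataset
from art.data_generators import PyTorchDataGenerator

dataset = TensorDataset(
    torch.rand(128, 1, 28, 28),
    torch.eye(10)[torch.randint(0, 10, (128,))],  # one-hot labels, dim() == 2
)
loader = DataLoader(dataset, batch_size=32, shuffle=True)
generator = PyTorchDataGenerator(iterator=loader, size=128, batch_size=32)
classifier.fit_generator(generator, nb_epochs=2, display_progress_bar=True)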
@@ -495,7 +500,10 @@ def fit_generator(self, generator: "DataGenerator", nb_epochs: int = 20, **kwarg
                     if isinstance(o_batch, np.ndarray):
                         o_batch = torch.argmax(torch.from_numpy(o_batch).to(self._device), dim=1)
                     else:
-                        o_batch = torch.argmax(o_batch.to(self._device), dim=1)
+                        if o_batch.dim() > 1:
+                            o_batch = torch.argmax(o_batch.to(self._device), dim=1)
+                        else:
+                            o_batch = o_batch.to(self._device)

                     # Zero the parameter gradients
                     self._optimizer.zero_grad()
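
This widened `else` branch fixes label handling rather than progress bars: a generator may yield either one-hot labels (2-D) or plain class indices (1-D), and `torch.argmax(..., dim=1)` raises an IndexError on a 1-D tensor. A small standalone illustration of the two shapes:

import torch

one_hot = torch.tensor([[0.0, 1.0], [1.0, 0.0]])  # dim() == 2: argmax recovers the class indices
indices = torch.tensor([1, 0])                     # dim() == 1: already class indices
assert torch.equal(torch.argmax(one_hot, dim=1), indices)
# torch.argmax(indices, dim=1) would raise IndexError: dim 1 is out of range for a 1-D tensor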

art/estimators/classification/tensorflow.py

Lines changed: 13 additions & 8 deletions
@@ -27,6 +27,7 @@
 import shutil
 import time
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union, TYPE_CHECKING
+from tqdm.auto import tqdm

 import numpy as np
 import six
@@ -957,12 +958,14 @@ def fit(self, x: np.ndarray, y: np.ndarray, batch_size: int = 128, nb_epochs: in
                   shape (nb_samples,).
         :param batch_size: Size of batches.
         :param nb_epochs: Number of epochs to use for training.
-        :param kwargs: Dictionary of framework-specific arguments. This parameter currently only supports
+        :param kwargs: Dictionary of framework-specific arguments. This parameter currently supports
                        "scheduler" which is an optional function that will be called at the end of every
-                       epoch to adjust the learning rate.
+                       epoch to adjust the learning rate, and "display_progress_bar" to display training progress.
         """
         import tensorflow as tf

+        display_progress_bar = kwargs.get("display_progress_bar", False)
+
         if self._train_step is None:  # pragma: no cover
             if self._loss_object is None:  # pragma: no cover
                 raise TypeError(
@@ -999,8 +1002,8 @@ def train_step(model, images, labels):

         train_ds = tf.data.Dataset.from_tensor_slices((x_preprocessed, y_preprocessed)).shuffle(10000).batch(batch_size)

-        for epoch in range(nb_epochs):
-            for images, labels in train_ds:
+        for epoch in tqdm(range(nb_epochs), disable=not display_progress_bar, desc="Epochs"):
+            for images, labels in tqdm(train_ds, disable=not display_progress_bar, desc="Batches"):
                 train_step(self.model, images, labels)

             if scheduler is not None:
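
A mirror-image usage sketch for the TensorFlow estimator; the Keras model, custom `train_step`, and random data are illustrative assumptions, not part of the diff:

import numpy as np
import tensorflow as tf
from art.estimators.classification import TensorFlowV2Classifier

model = tf.keras.Sequential([tf.keras.layers.Flatten(input_shape=(28, 28)), tf.keras.layers.Dense(10)])
loss_object = tf.keras.losses.CategoricalCrossentropy(from_logits=True)
optimizer = tf.keras.optimizers.Adam()

def train_step(model, images, labels):
    # Standard custom training step: forward pass, loss, gradient update.
    with tf.GradientTape() as tape:
        loss = loss_object(labels, model(images, training=True))
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))

classifier = TensorFlowV2Classifier(
    model=model, nb_classes=10, input_shape=(28, 28), loss_object=loss_object, train_step=train_step
)
x_train = np.random.rand(128, 28, 28).astype(np.float32)
y_train = np.eye(10)[np.random.randint(0, 10, 128)].astype(np.float32)

classifier.fit(x_train, y_train, batch_size=32, nb_epochs=2, display_progress_bar=True)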
@@ -1013,13 +1016,15 @@ def fit_generator(self, generator: "DataGenerator", nb_epochs: int = 20, **kwarg
         :param generator: Batch generator providing `(x, y)` for each epoch. If the generator can be used for native
                           training in TensorFlow, it will.
         :param nb_epochs: Number of epochs to use for training.
-        :param kwargs: Dictionary of framework-specific arguments. This parameter currently only supports
+        :param kwargs: Dictionary of framework-specific arguments. This parameter currently supports
                        "scheduler" which is an optional function that will be called at the end of every
-                       epoch to adjust the learning rate.
+                       epoch to adjust the learning rate, and "display_progress_bar" to display training progress.
         """
         import tensorflow as tf
         from art.data_generators import TensorFlowV2DataGenerator

+        display_progress_bar = kwargs.get("display_progress_bar", False)
+
         if self._train_step is None:  # pragma: no cover
             if self._loss_object is None:  # pragma: no cover
                 raise TypeError(
@@ -1059,8 +1064,8 @@ def train_step(model, images, labels):
                 == (0, 1)
             )
         ):
-            for epoch in range(nb_epochs):
-                for i_batch, o_batch in generator.iterator:
+            for epoch in tqdm(range(nb_epochs), disable=not display_progress_bar, desc="Epochs"):
+                for i_batch, o_batch in tqdm(generator.iterator, disable=not display_progress_bar, desc="Batches"):
                     if self._reduce_labels:
                         o_batch = tf.math.argmax(o_batch, axis=1)
                     train_step(self._model, i_batch, o_batch)
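
And for `fit_generator`, a sketch reusing the classifier from the example above; `TensorFlowV2DataGenerator`'s `(iterator, size, batch_size)` keyword names are assumed from its documented interface:

from art.data_generators import TensorFlowV2DataGenerator

train_ds = tf.data.Dataset.from_tensor_slices((x_train, y_train)).batch(32)
generator = TensorFlowV2DataGenerator(iterator=train_ds, size=128, batch_size=32)
classifier.fit_generator(generator, nb_epochs=2, display_progress_bar=True)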

0 commit comments