Skip to content

Commit 2a3290a

Browse files
committed
mypy fixes
Signed-off-by: GiulioZizzo <[email protected]>
1 parent 5fb24d5 commit 2a3290a

File tree

2 files changed

+4
-3
lines changed

2 files changed

+4
-3
lines changed

art/estimators/certification/derandomized_smoothing/pytorch.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -438,10 +438,10 @@ def fit( # pylint: disable=W0221
438438
training_mode: bool = True,
439439
drop_last: bool = False,
440440
scheduler: Optional[Any] = None,
441+
verbose: Optional[Union[bool, int]] = None,
441442
update_batchnorm: bool = True,
442443
batchnorm_update_epochs: int = 1,
443444
transform: Optional["torchvision.transforms.transforms.Compose"] = None,
444-
verbose: Optional[Union[bool, int]] = None,
445445
**kwargs,
446446
) -> None:
447447
"""
@@ -457,13 +457,13 @@ def fit( # pylint: disable=W0221
457457
the batch size. If ``False`` and the size of dataset is not divisible by the batch size, then
458458
the last batch will be smaller. (default: ``False``)
459459
:param scheduler: Learning rate scheduler to run at the start of every epoch.
460+
:param verbose: if to display training progress bars
460461
:param update_batchnorm: ViT specific argument.
461462
If to run the training data through the model to update any batch norm statistics prior
462463
to training. Useful on small datasets when using pre-trained ViTs.
463464
:param batchnorm_update_epochs: ViT specific argument. How many times to forward pass over the training data
464465
to pre-adjust the batchnorm statistics.
465466
:param transform: ViT specific argument. Torchvision compose of relevant augmentation transformations to apply.
466-
:param verbose: if to display training progress bars
467467
:param kwargs: Dictionary of framework-specific arguments. This parameter is not currently supported for PyTorch
468468
and providing it takes no effect.
469469
"""

art/estimators/classification/tensorflow.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -376,7 +376,8 @@ def fit_generator(
376376
)
377377
):
378378
for _ in tqdm(range(nb_epochs), disable=not display_pb, desc="Epochs"):
379-
for _ in tqdm(range(int(generator.size / generator.batch_size)), disable=not display_pb, desc="Batches"): # type: ignore
379+
num_batches = int(generator.size / generator.batch_size)
380+
for _ in tqdm(range(num_batches), disable=not display_pb, desc="Batches"):  # type: ignore
380381
i_batch, o_batch = generator.get_batch()
381382

382383
if self._reduce_labels:

0 commit comments

Comments
 (0)