Skip to content

Commit 36f2c3b

Browse files
author
Beat Buesser
committed
Update typing
Signed-off-by: Beat Buesser <[email protected]>
1 parent 4f580a3 commit 36f2c3b

File tree

1 file changed

+8
-3
lines changed

1 file changed

+8
-3
lines changed

art/defences/trainer/adversarial_trainer_madry_pgd.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -109,15 +109,20 @@ def fit( # pylint: disable=W0221
109109
:param nb_epochs: Number of epochs to use for trainings. Overwrites nb_epochs defined in __init__ if not None.
110110
:param kwargs: Dictionary of framework-specific arguments.
111111
"""
112+
batch_size_fit: int
112113
if batch_size is not None:
113114
batch_size_fit = batch_size
114-
else:
115+
elif self.batch_size is not None:
115116
batch_size_fit = self.batch_size
117+
else:
118+
raise ValueError("Please provide value for `batch_size`.")
116119

117120
if nb_epochs is not None:
118-
nb_epochs_fit = nb_epochs
119-
else:
121+
nb_epochs_fit: int = nb_epochs
122+
elif self.nb_epochs is not None:
120123
nb_epochs_fit = self.nb_epochs
124+
else:
125+
raise ValueError("Please provide value for `nb_epochs`.")
121126

122127
self.trainer.fit(
123128
x, y, validation_data=validation_data, nb_epochs=nb_epochs_fit, batch_size=batch_size_fit, **kwargs

0 commit comments

Comments
 (0)