Skip to content

Commit e3f9f33

Browse files
authored
Merge pull request #1612 from Trusted-AI/development_issue_1609
Add nb_epochs and batch_size to AdversarialTrainerMadryPGD.fit
2 parents c25abb5 + b630aae commit e3f9f33

File tree

2 files changed

+30
-7
lines changed

2 files changed

+30
-7
lines changed

.github/workflows/ci-style-checks.yml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -41,9 +41,9 @@ jobs:
4141
pip install pluggy==0.13.1
4242
pip install tensorflow==2.7.0
4343
pip install keras==2.7.0
44-
python -m pip install types-six
45-
python -m pip install types-PyYAML
46-
python3 -m pip install types-setuptools
44+
pip install types-six
45+
pip install types-PyYAML
46+
pip install types-setuptools
4747
pip install click==8.0.2
4848
pip list
4949
- name: pycodestyle

art/defences/trainer/adversarial_trainer_madry_pgd.py

Lines changed: 27 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -54,8 +54,8 @@ class AdversarialTrainerMadryPGD(Trainer):
5454
def __init__(
5555
self,
5656
classifier: "CLASSIFIER_LOSS_GRADIENTS_TYPE",
57-
nb_epochs: int = 391,
58-
batch_size: int = 128,
57+
nb_epochs: Optional[int] = 391,
58+
batch_size: Optional[int] = 128,
5959
eps: Union[int, float] = 8,
6060
eps_step: Union[int, float] = 2,
6161
max_iter: int = 7,
@@ -91,18 +91,41 @@ def __init__(
9191
self.trainer = AdversarialTrainer(classifier, self.attack, ratio=1.0) # type: ignore
9292

9393
def fit( # pylint: disable=W0221
94-
self, x: np.ndarray, y: np.ndarray, validation_data: Optional[np.ndarray] = None, **kwargs
94+
self,
95+
x: np.ndarray,
96+
y: np.ndarray,
97+
validation_data: Optional[np.ndarray] = None,
98+
batch_size: Optional[int] = None,
99+
nb_epochs: Optional[int] = None,
100+
**kwargs
95101
) -> None:
96102
"""
97103
Train a model adversarially. See class documentation for more information on the exact procedure.
98104
99105
:param x: Training data.
100106
:param y: Labels for the training data.
101107
:param validation_data: Validation data.
108+
:param batch_size: Size of batches. Overwrites batch_size defined in __init__ if not None.
109+
:param nb_epochs: Number of epochs to use for trainings. Overwrites nb_epochs defined in __init__ if not None.
102110
:param kwargs: Dictionary of framework-specific arguments.
103111
"""
112+
batch_size_fit: int
113+
if batch_size is not None:
114+
batch_size_fit = batch_size
115+
elif self.batch_size is not None:
116+
batch_size_fit = self.batch_size
117+
else:
118+
raise ValueError("Please provide value for `batch_size`.")
119+
120+
if nb_epochs is not None:
121+
nb_epochs_fit: int = nb_epochs
122+
elif self.nb_epochs is not None:
123+
nb_epochs_fit = self.nb_epochs
124+
else:
125+
raise ValueError("Please provide value for `nb_epochs`.")
126+
104127
self.trainer.fit(
105-
x, y, validation_data=validation_data, nb_epochs=self.nb_epochs, batch_size=self.batch_size, **kwargs
128+
x, y, validation_data=validation_data, nb_epochs=nb_epochs_fit, batch_size=batch_size_fit, **kwargs
106129
)
107130

108131
def get_classifier(self) -> "CLASSIFIER_LOSS_GRADIENTS_TYPE":

0 commit comments

Comments
 (0)