Skip to content

Commit cb01479

Browse files
authored (author name lost in page extraction)
Merge branch 'dev_1.14.1' into development_issue_2116
2 parents 731c1d3 + 63f3501 commit cb01479

File tree

1 file changed

+13
-3
lines changed

1 file changed

+13
-3
lines changed

art/defences/trainer/ibp_certified_trainer_pytorch.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -109,8 +109,8 @@ def __init__(
109109
classifier: "IBP_CERTIFIER_TYPE",
110110
nb_epochs: Optional[int] = 20,
111111
bound: float = 0.1,
112-
loss_weighting: float = 0.1,
113112
batch_size: int = 32,
113+
loss_weighting: Optional[int] = None,
114114
use_certification_schedule: bool = True,
115115
certification_schedule: Optional[Any] = None,
116116
use_loss_weighting_schedule: bool = True,
@@ -133,9 +133,9 @@ def __init__(
133133
* *max_iter*: The maximum number of iterations.
134134
* *batch_size*: Size of the batch on which adversarial samples are generated.
135135
* *num_random_init*: Number of random initialisations within the epsilon ball.
136+
:param loss_weighting: Weighting factor for the certified loss.
136137
:param bound: The perturbation range for the interval. If the default certification schedule is used
137138
will be the upper limit.
138-
:param loss_weighting: Weighting factor for the certified loss.
139139
:param nb_epochs: Number of training epochs.
140140
:param use_certification_schedule: If to use a training schedule for the certification radius.
141141
:param certification_schedule: Schedule for gradually increasing the certification radius. Empirical studies
@@ -152,6 +152,14 @@ def __init__(
152152
"art.estimators.certification.interval.pytorch.PyTorchIBPClassifier"
153153
)
154154

155+
if not use_loss_weighting_schedule and loss_weighting is None:
156+
raise ValueError(
157+
"If a loss weighting schedule is not used then a value for loss_weighting should be supplied."
158+
)
159+
160+
if use_loss_weighting_schedule and loss_weighting is not None:
161+
raise ValueError("Using a loss weighting schedule is incompatible with a fixed loss_weighting.")
162+
155163
super().__init__(classifier=classifier)
156164
self.classifier: "IBP_CERTIFIER_TYPE"
157165
self.pgd_params: "PGDParamDict"
@@ -300,8 +308,10 @@ def fit( # pylint: disable=W0221
300308
self.loss_weighting_schedule = self.initialise_default_scheduler(
301309
initial_val=0.0, final_val=0.5, epochs=epochs
302310
)
311+
elif self.loss_weighting is not None:
312+
loss_weighting_k = self.loss_weighting
303313
else:
304-
loss_weighting_k = 0.1
314+
raise ValueError("Unable to determine loss weighting.")
305315

306316
for _ in tqdm(range(epochs)):
307317
if self.use_certification_schedule and self.certification_schedule is not None:

0 commit comments

Comments (0)