Skip to content

Commit e87ed91

Browse files
committed
formatting fixes and adding additional input checking
Signed-off-by: GiulioZizzo <[email protected]>
1 parent 75a9d98 commit e87ed91

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

art/defences/trainer/ibp_certified_trainer_pytorch.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -158,9 +158,7 @@ def __init__(
158158
)
159159

160160
if use_loss_weighting_schedule and loss_weighting is not None:
161-
raise ValueError(
162-
"Using a loss weighting schedule is incompatible with a fixed loss_weighting."
163-
)
161+
raise ValueError("Using a loss weighting schedule is incompatible with a fixed loss_weighting.")
164162

165163
super().__init__(classifier=classifier)
166164
self.classifier: "IBP_CERTIFIER_TYPE"
@@ -310,8 +308,10 @@ def fit( # pylint: disable=W0221
310308
self.loss_weighting_schedule = self.initialise_default_scheduler(
311309
initial_val=0.0, final_val=0.5, epochs=epochs
312310
)
313-
else:
311+
elif self.loss_weighting is not None:
314312
loss_weighting_k = self.loss_weighting
313+
else:
314+
raise ValueError("Unable to determine loss weighting.")
315315

316316
for _ in tqdm(range(epochs)):
317317
if self.use_certification_schedule and self.certification_schedule is not None:

0 commit comments

Comments
 (0)