@@ -109,8 +109,8 @@ def __init__(
109109 classifier : "IBP_CERTIFIER_TYPE" ,
110110 nb_epochs : Optional [int ] = 20 ,
111111 bound : float = 0.1 ,
112- loss_weighting : float = 0.1 ,
113112 batch_size : int = 32 ,
113+ loss_weighting : Optional [int ] = None ,
114114 use_certification_schedule : bool = True ,
115115 certification_schedule : Optional [Any ] = None ,
116116 use_loss_weighting_schedule : bool = True ,
@@ -133,9 +133,9 @@ def __init__(
133133 * *max_iter*: The maximum number of iterations.
134134 * *batch_size*: Size of the batch on which adversarial samples are generated.
135135 * *num_random_init*: Number of random initialisations within the epsilon ball.
136+ :param loss_weighting: Weighting factor for the certified loss.
136137 :param bound: The perturbation range for the interval. If the default certification schedule is used
137138 will be the upper limit.
138- :param loss_weighting: Weighting factor for the certified loss.
139139 :param nb_epochs: Number of training epochs.
140140 :param use_certification_schedule: If to use a training schedule for the certification radius.
141141 :param certification_schedule: Schedule for gradually increasing the certification radius. Empirical studies
@@ -152,6 +152,14 @@ def __init__(
152152 "art.estimators.certification.interval.pytorch.PyTorchIBPClassifier"
153153 )
154154
155+ if not use_loss_weighting_schedule and loss_weighting is None :
156+ raise ValueError (
157+ "If a loss weighting schedule is not used then a value for loss_weighting should be supplied."
158+ )
159+
160+ if use_loss_weighting_schedule and loss_weighting is not None :
161+ raise ValueError ("Using a loss weighting schedule is incompatible with a fixed loss_weighting." )
162+
155163 super ().__init__ (classifier = classifier )
156164 self .classifier : "IBP_CERTIFIER_TYPE"
157165 self .pgd_params : "PGDParamDict"
@@ -300,8 +308,10 @@ def fit( # pylint: disable=W0221
300308 self .loss_weighting_schedule = self .initialise_default_scheduler (
301309 initial_val = 0.0 , final_val = 0.5 , epochs = epochs
302310 )
311+ elif self .loss_weighting is not None :
312+ loss_weighting_k = self .loss_weighting
303313 else :
304- loss_weighting_k = 0.1
314+ raise ValueError ( "Unable to determine loss weighting." )
305315
306316 for _ in tqdm (range (epochs )):
307317 if self .use_certification_schedule and self .certification_schedule is not None :
0 commit comments