File tree Expand file tree Collapse file tree 1 file changed +4
-4
lines changed Expand file tree Collapse file tree 1 file changed +4
-4
lines changed Original file line number Diff line number Diff line change @@ -158,9 +158,7 @@ def __init__(
158158 )
159159
160160 if use_loss_weighting_schedule and loss_weighting is not None :
161- raise ValueError (
162- "Using a loss weighting schedule is incompatible with a fixed loss_weighting."
163- )
161+ raise ValueError ("Using a loss weighting schedule is incompatible with a fixed loss_weighting." )
164162
165163 super ().__init__ (classifier = classifier )
166164 self .classifier : "IBP_CERTIFIER_TYPE"
@@ -310,8 +308,10 @@ def fit( # pylint: disable=W0221
310308 self .loss_weighting_schedule = self .initialise_default_scheduler (
311309 initial_val = 0.0 , final_val = 0.5 , epochs = epochs
312310 )
313- else :
311+ elif self . loss_weighting is not None :
314312 loss_weighting_k = self .loss_weighting
313+ else :
314+ raise ValueError ("Unable to determine loss weighting." )
315315
316316 for _ in tqdm (range (epochs )):
317317 if self .use_certification_schedule and self .certification_schedule is not None :
You can’t perform that action at this time.
0 commit comments