Skip to content

Commit 3508b86

Browse files
committed
Improve prediction performance
Signed-off-by: Beat Buesser <[email protected]>
1 parent e6493a9 commit 3508b86

File tree

1 file changed

+8
-8
lines changed

1 file changed

+8
-8
lines changed

art/defences/detector/poison/activation_defence.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -469,18 +469,18 @@ def relabel_poison_ground_truth(
469469
from tensorflow.keras.models import clone_model
470470

471471
model = classifier._model
472-
forward_pass = classifier._forward_pass
472+
forward_pass = classifier._forward_pass # type: ignore
473473
classifier._model = None
474-
classifier._forward_pass = None
474+
classifier._forward_pass = None # type: ignore
475475

476476
curr_classifier = copy.deepcopy(classifier)
477477
curr_model = clone_model(model)
478478
curr_model.set_weights(model.get_weights())
479479
curr_classifier._model = curr_model
480-
curr_classifier._forward_pass = forward_pass
480+
curr_classifier._forward_pass = forward_pass # type: ignore
481481

482482
classifier._model = model
483-
classifier._forward_pass = forward_pass
483+
classifier._forward_pass = forward_pass # type: ignore
484484

485485
# Now train using y_fix:
486486
improve_factor, _ = train_remove_backdoor(
@@ -540,18 +540,18 @@ def relabel_poison_cross_validation(
540540
from tensorflow.keras.models import clone_model
541541

542542
model = classifier._model
543-
forward_pass = classifier._forward_pass
543+
forward_pass = classifier._forward_pass # type: ignore
544544
classifier._model = None
545-
classifier._forward_pass = None
545+
classifier._forward_pass = None # type: ignore
546546

547547
curr_classifier = copy.deepcopy(classifier)
548548
curr_model = clone_model(model)
549549
curr_model.set_weights(model.get_weights())
550550
curr_classifier._model = curr_model
551-
curr_classifier._forward_pass = forward_pass
551+
curr_classifier._forward_pass = forward_pass # type: ignore
552552

553553
classifier._model = model
554-
classifier._forward_pass = forward_pass
554+
classifier._forward_pass = forward_pass # type: ignore
555555

556556
new_improvement, fixed_classifier = train_remove_backdoor(
557557
curr_classifier,

0 commit comments

Comments
 (0)