Skip to content

Commit e6493a9

Browse files
committed
Improve prediction performance
Signed-off-by: Beat Buesser <[email protected]>
1 parent 7341bb3 commit e6493a9

File tree

3 files changed

+36
-70
lines changed

3 files changed

+36
-70
lines changed

art/defences/detector/poison/activation_defence.py

Lines changed: 35 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -26,17 +26,15 @@
2626
"""
2727
from __future__ import absolute_import, division, print_function, unicode_literals, annotations
2828

29+
import copy
2930
import logging
3031
import os
31-
import pickle
32-
import time
3332
from typing import Any, TYPE_CHECKING
3433

3534
import numpy as np
3635

3736
from sklearn.cluster import KMeans, MiniBatchKMeans
3837

39-
from art import config
4038
from art.data_generators import DataGenerator
4139
from art.defences.detector.poison.clustering_analyzer import ClusteringAnalyzer
4240
from art.defences.detector.poison.ground_truth_evaluator import GroundTruthEvaluator
@@ -468,12 +466,25 @@ def relabel_poison_ground_truth(
468466
x_train, x_test = x[:n_train], x[n_train:]
469467
y_train, y_test = y_fix[:n_train], y_fix[n_train:]
470468

471-
filename = "original_classifier" + str(time.time()) + ".p"
472-
ActivationDefence._pickle_classifier(classifier, filename)
469+
from tensorflow.keras.models import clone_model
470+
471+
model = classifier._model
472+
forward_pass = classifier._forward_pass
473+
classifier._model = None
474+
classifier._forward_pass = None
475+
476+
curr_classifier = copy.deepcopy(classifier)
477+
curr_model = clone_model(model)
478+
curr_model.set_weights(model.get_weights())
479+
curr_classifier._model = curr_model
480+
curr_classifier._forward_pass = forward_pass
481+
482+
classifier._model = model
483+
classifier._forward_pass = forward_pass
473484

474485
# Now train using y_fix:
475486
improve_factor, _ = train_remove_backdoor(
476-
classifier,
487+
curr_classifier,
477488
x_train,
478489
y_train,
479490
x_test,
@@ -485,11 +496,9 @@ def relabel_poison_ground_truth(
485496

486497
# Only update classifier if there was an improvement:
487498
if improve_factor < 0:
488-
classifier = ActivationDefence._unpickle_classifier(filename)
489499
return 0, classifier
490500

491-
ActivationDefence._remove_pickle(filename)
492-
return improve_factor, classifier
501+
return improve_factor, curr_classifier
493502

494503
@staticmethod
495504
def relabel_poison_cross_validation(
@@ -514,23 +523,35 @@ def relabel_poison_cross_validation(
514523
:param batch_epochs: Number of epochs to be trained before checking current state of model.
515524
:return: (improve_factor, classifier)
516525
"""
517-
518526
from sklearn.model_selection import KFold
519527

520528
# Train using cross validation
521529
k_fold = KFold(n_splits=n_splits)
522530
KFold(n_splits=n_splits, random_state=None, shuffle=True)
523531

524-
filename = "original_classifier" + str(time.time()) + ".p"
525-
ActivationDefence._pickle_classifier(classifier, filename)
526532
curr_improvement = 0
527533

528534
for train_index, test_index in k_fold.split(x):
529535
# Obtain partition:
530536
x_train, x_test = x[train_index], x[test_index]
531537
y_train, y_test = y_fix[train_index], y_fix[test_index]
532-
# Unpickle original model:
533-
curr_classifier = ActivationDefence._unpickle_classifier(filename)
538+
# Copy original model:
539+
540+
from tensorflow.keras.models import clone_model
541+
542+
model = classifier._model
543+
forward_pass = classifier._forward_pass
544+
classifier._model = None
545+
classifier._forward_pass = None
546+
547+
curr_classifier = copy.deepcopy(classifier)
548+
curr_model = clone_model(model)
549+
curr_model.set_weights(model.get_weights())
550+
curr_classifier._model = curr_model
551+
curr_classifier._forward_pass = forward_pass
552+
553+
classifier._model = model
554+
classifier._forward_pass = forward_pass
534555

535556
new_improvement, fixed_classifier = train_remove_backdoor(
536557
curr_classifier,
@@ -547,50 +568,8 @@ def relabel_poison_cross_validation(
547568
classifier = fixed_classifier
548569
logger.info("Selected as best model so far: %s", curr_improvement)
549570

550-
ActivationDefence._remove_pickle(filename)
551571
return curr_improvement, classifier
552572

553-
@staticmethod
554-
def _pickle_classifier(classifier: "CLASSIFIER_NEURALNETWORK_TYPE", file_name: str) -> None:
555-
"""
556-
Pickles the self.classifier and stores it using the provided file_name in folder `art.config.ART_DATA_PATH`.
557-
558-
:param classifier: Classifier to be pickled.
559-
:param file_name: Name of the file where the classifier will be pickled.
560-
"""
561-
full_path = os.path.join(config.ART_DATA_PATH, file_name)
562-
folder = os.path.split(full_path)[0]
563-
if not os.path.exists(folder):
564-
os.makedirs(folder)
565-
566-
with open(full_path, "wb") as f_classifier:
567-
pickle.dump(classifier, f_classifier)
568-
569-
@staticmethod
570-
def _unpickle_classifier(file_name: str) -> "CLASSIFIER_NEURALNETWORK_TYPE":
571-
"""
572-
Unpickles classifier using the filename provided. Function assumes that the pickle is in
573-
`art.config.ART_DATA_PATH`.
574-
575-
:param file_name: Path of the pickled classifier relative to `ART_DATA_PATH`.
576-
:return: The loaded classifier.
577-
"""
578-
full_path = os.path.join(config.ART_DATA_PATH, file_name)
579-
logger.info("Loading classifier from %s", full_path)
580-
with open(full_path, "rb") as f_classifier:
581-
loaded_classifier = pickle.load(f_classifier)
582-
return loaded_classifier
583-
584-
@staticmethod
585-
def _remove_pickle(file_name: str) -> None:
586-
"""
587-
Erases the pickle with the provided file name.
588-
589-
:param file_name: File name without directory.
590-
"""
591-
full_path = os.path.join(config.ART_DATA_PATH, file_name)
592-
os.remove(full_path)
593-
594573
def visualize_clusters(
595574
self, x_raw: np.ndarray, save: bool = True, folder: str = ".", **kwargs
596575
) -> list[list[np.ndarray]]:

tests/attacks/evasion/test_sign_opt.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ def fix_get_mnist_subset_large(get_mnist_dataset):
5454
def test_tabular(art_warning, tabular_dl_estimator, framework, get_iris_dataset, clipped_classifier, targeted):
5555
try:
5656
classifier = tabular_dl_estimator(clipped=clipped_classifier)
57-
attack = SignOPTAttack(classifier, targeted=targeted, num_trial=10, max_iter=100, query_limit=40, verbose=True)
57+
attack = SignOPTAttack(classifier, targeted=targeted, num_trial=10, max_iter=10, query_limit=40, verbose=True)
5858
if targeted:
5959
backend_targeted_tabular(attack, get_iris_dataset)
6060
else:

tests/defences/detector/poison/test_activation_defence.py

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -256,19 +256,6 @@ def test_plot_clusters(self):
256256
self.defence.plot_clusters(save=False)
257257
self.defence_gen.plot_clusters(save=False)
258258

259-
def test_pickle(self):
260-
261-
# Test pickle and unpickle:
262-
filename = "test_pickle.h5"
263-
ActivationDefence._pickle_classifier(self.classifier, filename)
264-
loaded = ActivationDefence._unpickle_classifier(filename)
265-
266-
np.testing.assert_equal(self.classifier._clip_values, loaded._clip_values)
267-
self.assertEqual(self.classifier._channels_first, loaded._channels_first)
268-
self.assertEqual(self.classifier._use_logits, loaded._use_logits)
269-
270-
ActivationDefence._remove_pickle(filename)
271-
272259
def test_fix_relabel_poison(self):
273260
(x_train, y_train), (_, _), (_, _) = self.mnist
274261
x_poison = x_train[:100]

0 commit comments

Comments
 (0)