26
26
"""
27
27
from __future__ import absolute_import , division , print_function , unicode_literals , annotations
28
28
29
+ import copy
29
30
import logging
30
31
import os
31
- import pickle
32
- import time
33
32
from typing import Any , TYPE_CHECKING
34
33
35
34
import numpy as np
36
35
37
36
from sklearn .cluster import KMeans , MiniBatchKMeans
38
37
39
- from art import config
40
38
from art .data_generators import DataGenerator
41
39
from art .defences .detector .poison .clustering_analyzer import ClusteringAnalyzer
42
40
from art .defences .detector .poison .ground_truth_evaluator import GroundTruthEvaluator
@@ -468,12 +466,25 @@ def relabel_poison_ground_truth(
468
466
x_train , x_test = x [:n_train ], x [n_train :]
469
467
y_train , y_test = y_fix [:n_train ], y_fix [n_train :]
470
468
471
- filename = "original_classifier" + str (time .time ()) + ".p"
472
- ActivationDefence ._pickle_classifier (classifier , filename )
469
+ from tensorflow .keras .models import clone_model
470
+
471
+ model = classifier ._model
472
+ forward_pass = classifier ._forward_pass
473
+ classifier ._model = None
474
+ classifier ._forward_pass = None
475
+
476
+ curr_classifier = copy .deepcopy (classifier )
477
+ curr_model = clone_model (model )
478
+ curr_model .set_weights (model .get_weights ())
479
+ curr_classifier ._model = curr_model
480
+ curr_classifier ._forward_pass = forward_pass
481
+
482
+ classifier ._model = model
483
+ classifier ._forward_pass = forward_pass
473
484
474
485
# Now train using y_fix:
475
486
improve_factor , _ = train_remove_backdoor (
476
- classifier ,
487
+ curr_classifier ,
477
488
x_train ,
478
489
y_train ,
479
490
x_test ,
@@ -485,11 +496,9 @@ def relabel_poison_ground_truth(
485
496
486
497
# Only update classifier if there was an improvement:
487
498
if improve_factor < 0 :
488
- classifier = ActivationDefence ._unpickle_classifier (filename )
489
499
return 0 , classifier
490
500
491
- ActivationDefence ._remove_pickle (filename )
492
- return improve_factor , classifier
501
+ return improve_factor , curr_classifier
493
502
494
503
@staticmethod
495
504
def relabel_poison_cross_validation (
@@ -514,23 +523,35 @@ def relabel_poison_cross_validation(
514
523
:param batch_epochs: Number of epochs to be trained before checking current state of model.
515
524
:return: (improve_factor, classifier)
516
525
"""
517
-
518
526
from sklearn .model_selection import KFold
519
527
520
528
# Train using cross validation
521
529
k_fold = KFold (n_splits = n_splits )
522
530
KFold (n_splits = n_splits , random_state = None , shuffle = True )
523
531
524
- filename = "original_classifier" + str (time .time ()) + ".p"
525
- ActivationDefence ._pickle_classifier (classifier , filename )
526
532
curr_improvement = 0
527
533
528
534
for train_index , test_index in k_fold .split (x ):
529
535
# Obtain partition:
530
536
x_train , x_test = x [train_index ], x [test_index ]
531
537
y_train , y_test = y_fix [train_index ], y_fix [test_index ]
532
- # Unpickle original model:
533
- curr_classifier = ActivationDefence ._unpickle_classifier (filename )
538
+ # Copy original model:
539
+
540
+ from tensorflow .keras .models import clone_model
541
+
542
+ model = classifier ._model
543
+ forward_pass = classifier ._forward_pass
544
+ classifier ._model = None
545
+ classifier ._forward_pass = None
546
+
547
+ curr_classifier = copy .deepcopy (classifier )
548
+ curr_model = clone_model (model )
549
+ curr_model .set_weights (model .get_weights ())
550
+ curr_classifier ._model = curr_model
551
+ curr_classifier ._forward_pass = forward_pass
552
+
553
+ classifier ._model = model
554
+ classifier ._forward_pass = forward_pass
534
555
535
556
new_improvement , fixed_classifier = train_remove_backdoor (
536
557
curr_classifier ,
@@ -547,50 +568,8 @@ def relabel_poison_cross_validation(
547
568
classifier = fixed_classifier
548
569
logger .info ("Selected as best model so far: %s" , curr_improvement )
549
570
550
- ActivationDefence ._remove_pickle (filename )
551
571
return curr_improvement , classifier
552
572
553
- @staticmethod
554
- def _pickle_classifier (classifier : "CLASSIFIER_NEURALNETWORK_TYPE" , file_name : str ) -> None :
555
- """
556
- Pickles the self.classifier and stores it using the provided file_name in folder `art.config.ART_DATA_PATH`.
557
-
558
- :param classifier: Classifier to be pickled.
559
- :param file_name: Name of the file where the classifier will be pickled.
560
- """
561
- full_path = os .path .join (config .ART_DATA_PATH , file_name )
562
- folder = os .path .split (full_path )[0 ]
563
- if not os .path .exists (folder ):
564
- os .makedirs (folder )
565
-
566
- with open (full_path , "wb" ) as f_classifier :
567
- pickle .dump (classifier , f_classifier )
568
-
569
- @staticmethod
570
- def _unpickle_classifier (file_name : str ) -> "CLASSIFIER_NEURALNETWORK_TYPE" :
571
- """
572
- Unpickles classifier using the filename provided. Function assumes that the pickle is in
573
- `art.config.ART_DATA_PATH`.
574
-
575
- :param file_name: Path of the pickled classifier relative to `ART_DATA_PATH`.
576
- :return: The loaded classifier.
577
- """
578
- full_path = os .path .join (config .ART_DATA_PATH , file_name )
579
- logger .info ("Loading classifier from %s" , full_path )
580
- with open (full_path , "rb" ) as f_classifier :
581
- loaded_classifier = pickle .load (f_classifier )
582
- return loaded_classifier
583
-
584
- @staticmethod
585
- def _remove_pickle (file_name : str ) -> None :
586
- """
587
- Erases the pickle with the provided file name.
588
-
589
- :param file_name: File name without directory.
590
- """
591
- full_path = os .path .join (config .ART_DATA_PATH , file_name )
592
- os .remove (full_path )
593
-
594
573
def visualize_clusters (
595
574
self , x_raw : np .ndarray , save : bool = True , folder : str = "." , ** kwargs
596
575
) -> list [list [np .ndarray ]]:
0 commit comments