2626"""
2727from __future__ import absolute_import , division , print_function , unicode_literals , annotations
2828
29+ import copy
2930import logging
3031import os
31- import pickle
32- import time
3332from typing import Any , TYPE_CHECKING
3433
3534import numpy as np
3635
3736from sklearn .cluster import KMeans , MiniBatchKMeans
3837
39- from art import config
4038from art .data_generators import DataGenerator
4139from art .defences .detector .poison .clustering_analyzer import ClusteringAnalyzer
4240from art .defences .detector .poison .ground_truth_evaluator import GroundTruthEvaluator
@@ -468,12 +466,25 @@ def relabel_poison_ground_truth(
468466 x_train , x_test = x [:n_train ], x [n_train :]
469467 y_train , y_test = y_fix [:n_train ], y_fix [n_train :]
470468
471- filename = "original_classifier" + str (time .time ()) + ".p"
472- ActivationDefence ._pickle_classifier (classifier , filename )
469+ from tensorflow .keras .models import clone_model
470+
471+ model = classifier ._model
472+ forward_pass = classifier ._forward_pass
473+ classifier ._model = None
474+ classifier ._forward_pass = None
475+
476+ curr_classifier = copy .deepcopy (classifier )
477+ curr_model = clone_model (model )
478+ curr_model .set_weights (model .get_weights ())
479+ curr_classifier ._model = curr_model
480+ curr_classifier ._forward_pass = forward_pass
481+
482+ classifier ._model = model
483+ classifier ._forward_pass = forward_pass
473484
474485 # Now train using y_fix:
475486 improve_factor , _ = train_remove_backdoor (
476- classifier ,
487+ curr_classifier ,
477488 x_train ,
478489 y_train ,
479490 x_test ,
@@ -485,11 +496,9 @@ def relabel_poison_ground_truth(
485496
486497 # Only update classifier if there was an improvement:
487498 if improve_factor < 0 :
488- classifier = ActivationDefence ._unpickle_classifier (filename )
489499 return 0 , classifier
490500
491- ActivationDefence ._remove_pickle (filename )
492- return improve_factor , classifier
501+ return improve_factor , curr_classifier
493502
494503 @staticmethod
495504 def relabel_poison_cross_validation (
@@ -514,23 +523,35 @@ def relabel_poison_cross_validation(
514523 :param batch_epochs: Number of epochs to be trained before checking current state of model.
515524 :return: (improve_factor, classifier)
516525 """
517-
518526 from sklearn .model_selection import KFold
519527
520528 # Train using cross validation
521529 k_fold = KFold (n_splits = n_splits )
522530 KFold (n_splits = n_splits , random_state = None , shuffle = True )
523531
524- filename = "original_classifier" + str (time .time ()) + ".p"
525- ActivationDefence ._pickle_classifier (classifier , filename )
526532 curr_improvement = 0
527533
528534 for train_index , test_index in k_fold .split (x ):
529535 # Obtain partition:
530536 x_train , x_test = x [train_index ], x [test_index ]
531537 y_train , y_test = y_fix [train_index ], y_fix [test_index ]
532- # Unpickle original model:
533- curr_classifier = ActivationDefence ._unpickle_classifier (filename )
538+ # Copy original model:
539+
540+ from tensorflow .keras .models import clone_model
541+
542+ model = classifier ._model
543+ forward_pass = classifier ._forward_pass
544+ classifier ._model = None
545+ classifier ._forward_pass = None
546+
547+ curr_classifier = copy .deepcopy (classifier )
548+ curr_model = clone_model (model )
549+ curr_model .set_weights (model .get_weights ())
550+ curr_classifier ._model = curr_model
551+ curr_classifier ._forward_pass = forward_pass
552+
553+ classifier ._model = model
554+ classifier ._forward_pass = forward_pass
534555
535556 new_improvement , fixed_classifier = train_remove_backdoor (
536557 curr_classifier ,
@@ -547,50 +568,8 @@ def relabel_poison_cross_validation(
547568 classifier = fixed_classifier
548569 logger .info ("Selected as best model so far: %s" , curr_improvement )
549570
550- ActivationDefence ._remove_pickle (filename )
551571 return curr_improvement , classifier
552572
553- @staticmethod
554- def _pickle_classifier (classifier : "CLASSIFIER_NEURALNETWORK_TYPE" , file_name : str ) -> None :
555- """
556- Pickles the self.classifier and stores it using the provided file_name in folder `art.config.ART_DATA_PATH`.
557-
558- :param classifier: Classifier to be pickled.
559- :param file_name: Name of the file where the classifier will be pickled.
560- """
561- full_path = os .path .join (config .ART_DATA_PATH , file_name )
562- folder = os .path .split (full_path )[0 ]
563- if not os .path .exists (folder ):
564- os .makedirs (folder )
565-
566- with open (full_path , "wb" ) as f_classifier :
567- pickle .dump (classifier , f_classifier )
568-
569- @staticmethod
570- def _unpickle_classifier (file_name : str ) -> "CLASSIFIER_NEURALNETWORK_TYPE" :
571- """
572- Unpickles classifier using the filename provided. Function assumes that the pickle is in
573- `art.config.ART_DATA_PATH`.
574-
575- :param file_name: Path of the pickled classifier relative to `ART_DATA_PATH`.
576- :return: The loaded classifier.
577- """
578- full_path = os .path .join (config .ART_DATA_PATH , file_name )
579- logger .info ("Loading classifier from %s" , full_path )
580- with open (full_path , "rb" ) as f_classifier :
581- loaded_classifier = pickle .load (f_classifier )
582- return loaded_classifier
583-
584- @staticmethod
585- def _remove_pickle (file_name : str ) -> None :
586- """
587- Erases the pickle with the provided file name.
588-
589- :param file_name: File name without directory.
590- """
591- full_path = os .path .join (config .ART_DATA_PATH , file_name )
592- os .remove (full_path )
593-
594573 def visualize_clusters (
595574 self , x_raw : np .ndarray , save : bool = True , folder : str = "." , ** kwargs
596575 ) -> list [list [np .ndarray ]]:
0 commit comments