@@ -61,7 +61,15 @@ class ActivationDefence(PoisonFilteringDefence):
     in general, see https://arxiv.org/abs/1902.06705
     """
 
-    defence_params = ["nb_clusters", "clustering_method", "nb_dims", "reduce", "cluster_analysis", "generator"]
+    defence_params = [
+        "nb_clusters",
+        "clustering_method",
+        "nb_dims",
+        "reduce",
+        "cluster_analysis",
+        "generator",
+        "ex_re_threshold",
+    ]
     valid_clustering = ["KMeans"]
     valid_reduce = ["PCA", "FastICA", "TSNE"]
     valid_analysis = ["smaller", "distance", "relative-size", "silhouette-scores"]
@@ -74,6 +82,7 @@ def __init__(
         x_train: np.ndarray,
         y_train: np.ndarray,
         generator: Optional[DataGenerator] = None,
+        ex_re_threshold: Optional[float] = None,
     ) -> None:
         """
         Create an :class:`.ActivationDefence` object with the provided classifier.
@@ -82,6 +91,7 @@ def __init__(
         :param x_train: A dataset used to train the classifier.
         :param y_train: Labels used to train the classifier.
         :param generator: A data generator to be used instead of `x_train` and `y_train`.
+        :param ex_re_threshold: Set to a positive value to enable exclusionary reclassification.
         """
         super().__init__(classifier, x_train, y_train)
         self.classifier: "CLASSIFIER_NEURALNETWORK_TYPE" = classifier
@@ -102,6 +112,7 @@ def __init__(
         self.confidence_level: List[float] = []
         self.poisonous_clusters: np.ndarray
         self.clusterer = MiniBatchKMeans(n_clusters=self.nb_clusters)
+        self.ex_re_threshold = ex_re_threshold
         self._check_params()
 
     def evaluate_defence(self, is_clean: np.ndarray, **kwargs) -> str:
@@ -221,6 +232,14 @@ def detect_poison(self, **kwargs) -> Tuple[Dict[str, Any], List[int]]:
                 if assignment == 1:
                     self.is_clean_lst[index_dp] = 1
 
+        if self.ex_re_threshold is not None:
+            if self.generator is not None:
+                raise RuntimeError("Currently, exclusionary reclassification cannot be used with generators")
+            if hasattr(self.classifier, "clone_for_refitting"):
+                report = self.exclusionary_reclassification(report)
+            else:
+                logger.warning("Classifier does not have clone_for_refitting method defined. Skipping")
+
         return report, self.is_clean_lst
 
     def cluster_activations(self, **kwargs) -> Tuple[List[np.ndarray], List[np.ndarray]]:
@@ -331,6 +350,86 @@ def analyze_clusters(self, **kwargs) -> Tuple[Dict[str, Any], np.ndarray]:
 
         return report, self.assigned_clean_by_class
 
+    def exclusionary_reclassification(self, report: Dict[str, Any]):
+        """
+        This function performs exclusionary reclassification. Based on the ex_re_threshold,
+        suspicious clusters will be rechecked. If they remain suspicious, the suspected source
+        class will be added to the report and the data will be relabelled. The new labels are stored
+        in self.y_train_relabelled.
+
+        :param report: A dictionary containing defence params as well as the class clusters and their suspiciousness.
+        :return: The updated report dictionary.
+        """
+        # Copy the labels for relabelling, to avoid overwriting the user's y_train
+        self.y_train_relabelled = np.copy(self.y_train)
+        is_onehot = False
+        if len(np.shape(self.y_train)) == 2:
+            is_onehot = True
+
+        logger.info("Performing Exclusionary Reclassification with a threshold of %s", self.ex_re_threshold)
+        logger.info("Data will be relabelled internally. Access the y_train_relabelled attribute to get new labels")
+        # Train a new classifier with the unsuspicious clusters
+        cloned_classifier = (
+            self.classifier.clone_for_refitting()
+        )  # Get a classifier with the same training setup, but new weights
+        filtered_x = self.x_train[np.array(self.is_clean_lst) == 1]
+        filtered_y = self.y_train[np.array(self.is_clean_lst) == 1]
+
+        if len(filtered_x) == 0:
+            logger.warning("All of the data is marked as suspicious. Unable to perform exclusionary reclassification")
+            return report
+
+        cloned_classifier.fit(filtered_x, filtered_y)
+
+        # Test on the suspicious clusters
+        n_train = len(self.x_train)
+        indices_by_class = self._segment_by_class(np.arange(n_train), self.y_train)
+        indices_by_cluster: List[List[List]] = [
+            [[] for _ in range(self.nb_clusters)] for _ in range(self.classifier.nb_classes)
+        ]
+
+        # Get all data in x_train in the right cluster
+        for n_class, cluster_assignments in enumerate(self.clusters_by_class):
+            for j, assigned_cluster in enumerate(cluster_assignments):
+                indices_by_cluster[n_class][assigned_cluster].append(indices_by_class[n_class][j])
+
+        for n_class, _ in enumerate(self.poisonous_clusters):
+            suspicious_clusters = np.where(np.array(self.poisonous_clusters[n_class]) == 1)[0]
+            for cluster in suspicious_clusters:
+                cur_indices = indices_by_cluster[n_class][cluster]
+                predictions = cloned_classifier.predict(self.x_train[cur_indices])
+
+                predicted_as_class = [
+                    np.sum(np.argmax(predictions, axis=1) == i) for i in range(self.classifier.nb_classes)
+                ]
+                n_class_pred_count = predicted_as_class[n_class]
+                predicted_as_class[n_class] = -1 * predicted_as_class[n_class]  # So argmax below picks another class
+                other_class = np.argmax(predicted_as_class)
+                other_class_pred_count = predicted_as_class[other_class]
+
+                # Check if cluster is legit. If so, mark it as such
+                if other_class_pred_count == 0 or n_class_pred_count / other_class_pred_count > self.ex_re_threshold:
+                    self.poisonous_clusters[n_class][cluster] = 0
+                    report["Class_" + str(n_class)]["cluster_" + str(cluster)]["suspicious_cluster"] = False
+                    if "suspicious_clusters" in report.keys():
+                        report["suspicious_clusters"] = report["suspicious_clusters"] - 1
+                    for ind in cur_indices:
+                        self.is_clean_lst[ind] = 1
+                # Otherwise, add the exclusionary reclassification info to the report for the suspicious cluster
+                else:
+                    report["Class_" + str(n_class)]["cluster_" + str(cluster)]["ExRe_Score"] = (
+                        n_class_pred_count / other_class_pred_count
+                    )
+                    report["Class_" + str(n_class)]["cluster_" + str(cluster)]["Suspected_Source_class"] = other_class
+                    # Also relabel the data
+                    if is_onehot:
+                        self.y_train_relabelled[cur_indices, n_class] = 0
+                        self.y_train_relabelled[cur_indices, other_class] = 1
+                    else:
+                        self.y_train_relabelled[cur_indices] = other_class
+
+        return report
+
     @staticmethod
     def relabel_poison_ground_truth(
         classifier: "CLASSIFIER_NEURALNETWORK_TYPE",
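
For intuition, here is a minimal, self-contained sketch (not part of the commit; the helper name and the numbers are illustrative) of the per-cluster test that exclusionary_reclassification applies: retrain on the data marked clean, count how the retrained model classifies the members of a suspicious cluster, and keep the cluster only if its original class still beats the strongest other class by more than ex_re_threshold to one.

import numpy as np

def exre_cluster_check(predictions: np.ndarray, source_class: int, threshold: float):
    # predictions: outputs of the retrained model on one suspicious cluster, shape (n_samples, n_classes)
    labels = np.argmax(predictions, axis=1)
    counts = np.bincount(labels, minlength=predictions.shape[1]).astype(float)
    n_class_count = counts[source_class]
    counts[source_class] = -1.0  # exclude the source class when picking the runner-up
    other_class = int(np.argmax(counts))
    other_count = counts[other_class]
    # Clean if nothing else is predicted, or the original class still dominates the
    # strongest alternative by more than `threshold` to one
    if other_count == 0 or n_class_count / other_count > threshold:
        return True, None
    return False, other_class  # suspicious; other_class is the suspected source class

# Toy example: 2 of 100 points keep label 0, 98 flip to class 7 -> cluster stays suspicious
preds = np.zeros((100, 10))
preds[:2, 0] = 1.0
preds[2:, 7] = 1.0
print(exre_cluster_check(preds, source_class=0, threshold=10.0))  # (False, 7)

In the method above, the same comparison additionally updates report, is_clean_lst and y_train_relabelled.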
@@ -572,6 +671,8 @@ def _check_params(self):
             raise ValueError("Unsupported method for cluster analysis method: " + self.cluster_analysis)
         if self.generator and not isinstance(self.generator, DataGenerator):
             raise TypeError("Generator must a an instance of DataGenerator")
+        if self.ex_re_threshold is not None and self.ex_re_threshold <= 0:
+            raise ValueError("Exclusionary reclassification threshold must be positive")
 
     def _get_activations(self, x_train: Optional[np.ndarray] = None) -> np.ndarray:
         """
@@ -596,7 +697,7 @@ def _get_activations(self, x_train: Optional[np.ndarray] = None) -> np.ndarray:
         if isinstance(activations, np.ndarray):
             nodes_last_layer = np.shape(activations)[1]
         else:
-            raise ValueError("`activations is None or tensor.")
+            raise ValueError("activations is None or tensor.")
 
         if nodes_last_layer <= self.TOO_SMALL_ACTIVATIONS:
             logger.warning(
@@ -703,7 +804,7 @@ def cluster_activations(
     if clustering_method == "KMeans":
         clusterer = KMeans(n_clusters=nb_clusters)
     else:
-        raise ValueError(clustering_method + " clustering method not supported.")
+        raise ValueError(f"{clustering_method} clustering method not supported.")
 
     for activation in separated_activations:
         # Apply dimensionality reduction
@@ -749,7 +850,7 @@ def reduce_dimensionality(activations: np.ndarray, nb_dims: int = 10, reduce: st
     elif reduce == "PCA":
         projector = PCA(n_components=nb_dims)
     else:
-        raise ValueError(reduce + " dimensionality reduction method not supported.")
+        raise ValueError(f"{reduce} dimensionality reduction method not supported.")
 
     reduced_activations = projector.fit_transform(activations)
     return reduced_activations
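
For context, here is a usage sketch (not part of this diff) of the new ex_re_threshold argument. It assumes classifier, x_train and y_train already exist, where classifier is a fitted ART neural-network classifier that provides clone_for_refitting; the threshold value and the detect_poison keyword arguments are illustrative.

import numpy as np
from art.defences.detector.poison import ActivationDefence

# `classifier`, `x_train`, `y_train` are assumed to exist: a fitted ART neural-network
# classifier supporting clone_for_refitting, plus its (possibly poisoned) training data
defence = ActivationDefence(classifier, x_train, y_train, ex_re_threshold=10.0)

# Activation clustering runs as before; with ex_re_threshold set, detect_poison() also
# retrains a fresh copy of the model on the clusters marked clean and rechecks the
# suspicious clusters against it
report, is_clean_lst = defence.detect_poison(nb_clusters=2, nb_dims=10, reduce="PCA")

print(int(np.sum(np.array(is_clean_lst) == 0)), "samples still flagged as poisonous")

# Clusters that remain suspicious carry an ExRe_Score and a Suspected_Source_class in the
# report, and the relabelled targets are exposed on the defence object
new_labels = defence.y_train_relabelled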