Skip to content

Commit f4980b5

Browse files
authored
Merge pull request #1738 from keykholt/ex_reclass
Exclusionary Reclassification
2 parents 73c41b5 + 702e053 commit f4980b5

File tree

5 files changed

+452
-121
lines changed

5 files changed

+452
-121
lines changed

art/defences/detector/poison/activation_defence.py

Lines changed: 105 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,15 @@ class ActivationDefence(PoisonFilteringDefence):
6161
in general, see https://arxiv.org/abs/1902.06705
6262
"""
6363

64-
defence_params = ["nb_clusters", "clustering_method", "nb_dims", "reduce", "cluster_analysis", "generator"]
64+
defence_params = [
65+
"nb_clusters",
66+
"clustering_method",
67+
"nb_dims",
68+
"reduce",
69+
"cluster_analysis",
70+
"generator",
71+
"ex_re_threshold",
72+
]
6573
valid_clustering = ["KMeans"]
6674
valid_reduce = ["PCA", "FastICA", "TSNE"]
6775
valid_analysis = ["smaller", "distance", "relative-size", "silhouette-scores"]
@@ -74,6 +82,7 @@ def __init__(
7482
x_train: np.ndarray,
7583
y_train: np.ndarray,
7684
generator: Optional[DataGenerator] = None,
85+
ex_re_threshold: Optional[float] = None,
7786
) -> None:
7887
"""
7988
Create an :class:`.ActivationDefence` object with the provided classifier.
@@ -82,6 +91,7 @@ def __init__(
8291
:param x_train: A dataset used to train the classifier.
8392
:param y_train: Labels used to train the classifier.
8493
:param generator: A data generator to be used instead of `x_train` and `y_train`.
94+
:param ex_re_threshold: Set to a positive value to enable exclusionary reclassification
8595
"""
8696
super().__init__(classifier, x_train, y_train)
8797
self.classifier: "CLASSIFIER_NEURALNETWORK_TYPE" = classifier
@@ -102,6 +112,7 @@ def __init__(
102112
self.confidence_level: List[float] = []
103113
self.poisonous_clusters: np.ndarray
104114
self.clusterer = MiniBatchKMeans(n_clusters=self.nb_clusters)
115+
self.ex_re_threshold = ex_re_threshold
105116
self._check_params()
106117

107118
def evaluate_defence(self, is_clean: np.ndarray, **kwargs) -> str:
@@ -221,6 +232,14 @@ def detect_poison(self, **kwargs) -> Tuple[Dict[str, Any], List[int]]:
221232
if assignment == 1:
222233
self.is_clean_lst[index_dp] = 1
223234

235+
if self.ex_re_threshold is not None:
236+
if self.generator is not None:
237+
raise RuntimeError("Currently, exclusionary reclassification cannot be used with generators")
238+
if hasattr(self.classifier, "clone_for_refitting"):
239+
report = self.exclusionary_reclassification(report)
240+
else:
241+
logger.warning("Classifier does not have clone_for_refitting method defined. Skipping")
242+
224243
return report, self.is_clean_lst
225244

226245
def cluster_activations(self, **kwargs) -> Tuple[List[np.ndarray], List[np.ndarray]]:
@@ -331,6 +350,86 @@ def analyze_clusters(self, **kwargs) -> Tuple[Dict[str, Any], np.ndarray]:
331350

332351
return report, self.assigned_clean_by_class
333352

353+
def exclusionary_reclassification(self, report: Dict[str, Any]) -> Dict[str, Any]:
    """
    Perform exclusionary reclassification on the clusters currently flagged as suspicious.

    A freshly initialised clone of the classifier is trained only on the samples currently marked clean in
    `self.is_clean_lst` and is then asked to classify the members of every suspicious cluster. For a cluster
    of class `n_class`, the number of members predicted as `n_class` is divided by the number predicted as
    the most frequently predicted other class. If that ratio exceeds `self.ex_re_threshold` (or no member is
    predicted as any other class) the cluster is cleared: it is unflagged in `self.poisonous_clusters`, in
    `self.is_clean_lst` and in the report. Otherwise the ratio and the suspected source class are added to
    the report and the cluster members are relabelled to that class in `self.y_train_relabelled`.

    :param report: A dictionary containing defence params as well as the class clusters and their suspiciousness.
    :return: The same `report` dict, updated in place.
    """
    # Copy the labels so that relabelling never overwrites the arrays supplied by the user.
    self.y_train_relabelled = np.copy(self.y_train)
    # Two-dimensional labels are relabelled column-wise (one-hot); anything else is relabelled by index.
    is_onehot = len(np.shape(self.y_train)) == 2

    logger.info("Performing Exclusionary Reclassification with a threshold of %s", self.ex_re_threshold)
    logger.info("Data will be relabelled internally. Access the y_train_relabelled attribute to get new labels")

    # Build the clean mask once and bail out before cloning if there is nothing to train on.
    clean_mask = np.array(self.is_clean_lst) == 1
    filtered_x = self.x_train[clean_mask]
    filtered_y = self.y_train[clean_mask]

    if len(filtered_x) == 0:
        logger.warning("All of the data is marked as suspicious. Unable to perform exclusionary reclassification")
        return report

    # Train a new classifier (same training setup, new weights) on the unsuspicious data only.
    cloned_classifier = self.classifier.clone_for_refitting()
    cloned_classifier.fit(filtered_x, filtered_y)

    # Map every (class, cluster) pair back to the indices of its members in x_train.
    nb_classes = self.classifier.nb_classes
    n_train = len(self.x_train)
    indices_by_class = self._segment_by_class(np.arange(n_train), self.y_train)
    indicies_by_cluster: List[List[List]] = [[[] for _ in range(self.nb_clusters)] for _ in range(nb_classes)]

    for n_class, cluster_assignments in enumerate(self.clusters_by_class):
        for j, assigned_cluster in enumerate(cluster_assignments):
            indicies_by_cluster[n_class][assigned_cluster].append(indices_by_class[n_class][j])

    # Test the refitted classifier on the suspicious clusters.
    for n_class, _ in enumerate(self.poisonous_clusters):
        suspicious_clusters = np.where(np.array(self.poisonous_clusters[n_class]) == 1)[0]
        for cluster in suspicious_clusters:
            cur_indicies = indicies_by_cluster[n_class][cluster]

            if cur_indicies:
                predictions = cloned_classifier.predict(self.x_train[cur_indicies])
                # One argmax pass, then a histogram of predicted labels restricted to the known classes.
                predicted_labels = np.argmax(predictions, axis=1)
                predicted_as_class = np.bincount(predicted_labels, minlength=nb_classes)[:nb_classes]
            else:
                # An empty cluster has nothing to predict on; all-zero counts send it down the "legit" path
                # below without handing the classifier an empty batch.
                predicted_as_class = np.zeros(nb_classes, dtype=int)

            n_class_pred_count = predicted_as_class[n_class]
            # Negate the own-class count so that argmax picks the most frequently predicted *other* class.
            predicted_as_class[n_class] = -n_class_pred_count
            other_class = int(np.argmax(predicted_as_class))  # plain int keeps the report free of NumPy ints
            other_class_pred_count = predicted_as_class[other_class]

            cluster_report = report[f"Class_{n_class}"][f"cluster_{cluster}"]

            # Check if cluster is legit. If so, mark it as such. The zero test comes first so the ratio is
            # never computed with a zero denominator.
            if other_class_pred_count == 0 or n_class_pred_count / other_class_pred_count > self.ex_re_threshold:
                self.poisonous_clusters[n_class][cluster] = 0
                cluster_report["suspicious_cluster"] = False
                if "suspicious_clusters" in report:
                    report["suspicious_clusters"] = report["suspicious_clusters"] - 1
                for ind in cur_indicies:
                    self.is_clean_lst[ind] = 1
            # Otherwise, add the exclusionary reclassification info to the report for the suspicious cluster
            else:
                cluster_report["ExRe_Score"] = n_class_pred_count / other_class_pred_count
                cluster_report["Suspected_Source_class"] = other_class
                # Also relabel the data
                if is_onehot:
                    self.y_train_relabelled[cur_indicies, n_class] = 0
                    self.y_train_relabelled[cur_indicies, other_class] = 1
                else:
                    self.y_train_relabelled[cur_indicies] = other_class

    return report
432+
334433
@staticmethod
335434
def relabel_poison_ground_truth(
336435
classifier: "CLASSIFIER_NEURALNETWORK_TYPE",
@@ -572,6 +671,8 @@ def _check_params(self):
572671
raise ValueError("Unsupported method for cluster analysis method: " + self.cluster_analysis)
573672
if self.generator and not isinstance(self.generator, DataGenerator):
574673
raise TypeError("Generator must a an instance of DataGenerator")
674+
if self.ex_re_threshold is not None and self.ex_re_threshold <= 0:
675+
raise ValueError("Exclusionary reclassification threshold must be positive")
575676

576677
def _get_activations(self, x_train: Optional[np.ndarray] = None) -> np.ndarray:
577678
"""
@@ -596,7 +697,7 @@ def _get_activations(self, x_train: Optional[np.ndarray] = None) -> np.ndarray:
596697
if isinstance(activations, np.ndarray):
597698
nodes_last_layer = np.shape(activations)[1]
598699
else:
599-
raise ValueError("`activations is None or tensor.")
700+
raise ValueError("activations is None or tensor.")
600701

601702
if nodes_last_layer <= self.TOO_SMALL_ACTIVATIONS:
602703
logger.warning(
@@ -703,7 +804,7 @@ def cluster_activations(
703804
if clustering_method == "KMeans":
704805
clusterer = KMeans(n_clusters=nb_clusters)
705806
else:
706-
raise ValueError(clustering_method + " clustering method not supported.")
807+
raise ValueError(f"{clustering_method} clustering method not supported.")
707808

708809
for activation in separated_activations:
709810
# Apply dimensionality reduction
@@ -749,7 +850,7 @@ def reduce_dimensionality(activations: np.ndarray, nb_dims: int = 10, reduce: st
749850
elif reduce == "PCA":
750851
projector = PCA(n_components=nb_dims)
751852
else:
752-
raise ValueError(reduce + " dimensionality reduction method not supported.")
853+
raise ValueError(f"{reduce} dimensionality reduction method not supported.")
753854

754855
reduced_activations = projector.fit_transform(activations)
755856
return reduced_activations

art/estimators/classification/classifier.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
This module implements mixin abstract base classes defining properties for all classifiers in ART.
2020
"""
2121
from abc import ABC, ABCMeta, abstractmethod
22-
from typing import List, Optional, Union
22+
from typing import List, Optional, TYPE_CHECKING, Union
2323

2424
import numpy as np
2525

@@ -30,6 +30,9 @@
3030
DecisionTreeMixin,
3131
)
3232

33+
if TYPE_CHECKING:
34+
from art.utils import CLASSIFIER_TYPE
35+
3336

3437
class InputFilter(ABCMeta):
3538
"""
@@ -117,6 +120,12 @@ def nb_classes(self, nb_classes: int):
117120

118121
self._nb_classes = nb_classes
119122

123+
def clone_for_refitting(self) -> "CLASSIFIER_TYPE":
    """
    Return a copy of this classifier that can be fitted again.

    This implementation is a placeholder: it does no work and always raises.

    :return: Nothing; this implementation never returns.
    :raises NotImplementedError: Always.
    """
    raise NotImplementedError
128+
120129

121130
class ClassGradientsMixin(ABC):
122131
"""

art/estimators/classification/keras.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -715,6 +715,52 @@ def custom_loss_gradient(self, nn_function, tensors, input_values, name="default
715715
outputs = self._custom_loss_func[name]
716716
return outputs(input_values)
717717

718+
def clone_for_refitting(
    self,
) -> "KerasClassifier":  # lgtm [py/inheritance/incorrect-overridden-signature]
    """
    Build a new classifier of the same type around a clone of the wrapped Keras model so that it can be
    fitted again. The clone is compiled with the original model's optimizer, loss, metric names, loss weights,
    weighted metrics and eager-execution setting, and receives this classifier's parameters (except `model`).

    :return: new estimator
    :raises ValueError: If `keras.models.clone_model` rejects the wrapped model with a `ValueError`.
    """
    import tensorflow as tf  # lgtm [py/repeated-import]
    import keras  # lgtm [py/repeated-import]

    source_model = self.model

    try:
        # only works for functionally defined models
        fresh_model = keras.models.clone_model(source_model, input_tensors=source_model.inputs)
    except ValueError as error:
        raise ValueError("Cannot clone custom models") from error

    # NOTE(review): this is the optimizer object held by the original model, not a copy, so zeroing its
    # variables below also resets the optimizer state of the classifier being cloned -- confirm intended.
    shared_optimizer = source_model.optimizer
    for variable in shared_optimizer.variables():
        variable.assign(tf.zeros_like(variable))

    # Loss weights and weighted metrics are read from the compiled containers when those are set.
    loss_weights = (
        source_model.compiled_loss._loss_weights if source_model.compiled_loss else None  # pylint: disable=W0212
    )
    weighted_metrics = (
        source_model.compiled_metrics._weighted_metrics  # pylint: disable=W0212
        if source_model.compiled_metrics
        else None
    )

    compile_options = {
        "optimizer": shared_optimizer,
        "loss": source_model.loss,
        "metrics": [metric.name for metric in source_model.metrics],  # Need to copy metrics this way for keras
        "loss_weights": loss_weights,
        "weighted_metrics": weighted_metrics,
        "run_eagerly": source_model.run_eagerly,
    }
    fresh_model.compile(**compile_options)

    # Wrap the new model in the same classifier type and carry over every parameter except the model itself.
    twin = type(self)(fresh_model)
    settings = self.get_params()
    settings.pop("model")
    twin.set_params(**settings)
    return twin
763+
718764
def _init_class_gradients(self, label: Optional[Union[int, List[int], np.ndarray]] = None) -> None:
719765
# pylint: disable=E0401
720766
if self.is_tensorflow:

0 commit comments

Comments
 (0)