
Commit 38a9f55

Format corrections, PerformanceMonitor cleanup, notebook inclusion and requirements cleanup.
Signed-off-by: Álvaro Bacca Peña <[email protected]>
1 parent e56b47a commit 38a9f55

File tree

6 files changed: +96 -341 lines changed

art/defences/detector/poison/clustering_centroid_analysis.py

Lines changed: 77 additions & 38 deletions
@@ -109,11 +109,21 @@ def __init__(
         :param benign_indices: array data points' indices known to be benign
         :param final_feature_layer_name: the name of the final layer that builds feature representation. Must
             be a ReLu layer
+        :param misclassification_threshold: maximum misclassification threshold to consider a cluster as clean
+        :param reducer: dimensionality reducer used to reduce the feature space. UMAP is used by default
+        :param clusterer: clustering algorithm used to cluster the features. DBSCAN is used by default
+        :param feature_extraction_batch_size: batch size for feature extraction.
+            Use lower values in case of low GPU memory availability
+        :param misclassification_batch_size: batch size for misclassification.
+            Use lower values in case of low GPU memory
         """
+        # TensorFlow is imported dynamically to keep ART framework-agnostic
         try:
             import tensorflow as tf_runtime
             from tensorflow.keras import Model, Sequential
 
+            # Keep runtime and function attributes to prevent garbage collection on these during the
+            # algorithm lifespan
             self._tf_runtime = tf_runtime
             self._KerasModel = Model
             self._KerasSequential = Sequential
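For context, a hedged usage sketch of the constructor parameters documented above. The class name `ClusteringCentroidAnalysis` and the `classifier`/`x_train`/`y_train`/`benign_indices` argument names are assumptions inferred from this diff, not a confirmed ART API; the `reducer` and `clusterer` values mirror the defaults set in this constructor, and `detect_poison` returns the `(report, is_clean)` pair declared in its signature below.

    from sklearn.cluster import DBSCAN
    from umap import UMAP

    # Hypothetical instantiation; only the parameters documented in the
    # docstring above come from the source, the rest are assumptions.
    defence = ClusteringCentroidAnalysis(
        classifier=classifier,                      # trained Keras-based classifier (assumed)
        x_train=x_train,
        y_train=y_train,
        benign_indices=benign_indices,              # indices known to be benign
        final_feature_layer_name="dense_relu",      # must be a ReLU layer (assumed name)
        misclassification_threshold=0.05,           # max rate to consider a cluster clean
        reducer=UMAP(n_neighbors=5, min_dist=0),    # default recommended by the original authors
        clusterer=DBSCAN(eps=0.8, min_samples=20),  # default recommended by the original authors
        feature_extraction_batch_size=32,           # lower this on low GPU memory
        misclassification_batch_size=32,            # lower this on low GPU memory
    )
    report, is_clean = defence.detect_poison()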
@@ -126,6 +136,7 @@ def __init__(
             from sklearn.base import ClusterMixin
             from sklearn.cluster import DBSCAN
 
+            # Default clusterer, recommended by the original authors
             self.clusterer = DBSCAN(eps=0.8, min_samples=20)
         except ImportError as e:
             raise ImportError(
@@ -136,6 +147,7 @@ def __init__(
         try:
             from umap import UMAP
 
+            # Default reducer, recommended by the original authors
             self.reducer = UMAP(n_neighbors=5, min_dist=0)
         except ImportError as e:
             raise ImportError(
@@ -158,53 +170,77 @@ def __init__(
         self.feature_extraction_batch_size = feature_extraction_batch_size
         self.misclassification_batch_size = misclassification_batch_size
 
+        # tf.data.Datasets are used as they are more efficient for TF-specific operations
         # Create x_train_dataset for feature extraction
         # Ensure data type is float32 for TensorFlow
-        self.x_train_dataset = self._tf_runtime.data.Dataset.from_tensor_slices(
-            self.x_train.astype(self._tf_runtime.float32.as_numpy_dtype)
-        ).batch(self.feature_extraction_batch_size).prefetch(self._tf_runtime.data.AUTOTUNE)
+        self.x_train_dataset = (
+            self._tf_runtime.data.Dataset.from_tensor_slices(
+                self.x_train.astype(self._tf_runtime.float32.as_numpy_dtype)
+            )
+            .batch(self.feature_extraction_batch_size)
+            .prefetch(self._tf_runtime.data.AUTOTUNE)
+        )
 
         # Create x_benign_dataset for misclassification rate calculation
         # Ensure data type is float32 for TensorFlow
-        self.x_benign_dataset = self._tf_runtime.data.Dataset.from_tensor_slices(
-            self.x_benign_np.astype(self._tf_runtime.float32.as_numpy_dtype)
-        ).batch(self.misclassification_batch_size).prefetch(self._tf_runtime.data.AUTOTUNE)
-        # ---------------------------------------------
+        self.x_benign_dataset = (
+            self._tf_runtime.data.Dataset.from_tensor_slices(
+                self.x_benign_np.astype(self._tf_runtime.float32.as_numpy_dtype)
+            )
+            .batch(self.misclassification_batch_size)
+            .prefetch(self._tf_runtime.data.AUTOTUNE)
+        )
 
         self.feature_representation_model, self.classifying_submodel = self._extract_submodels(
             final_feature_layer_name
         )
 
         self.misclassification_threshold = np.float64(misclassification_threshold)
 
-        # --- Get feature_shape and deviation_shape BEFORE wrapping functions ---
-        if hasattr(self.feature_representation_model, 'output_shape') and self.feature_representation_model.output_shape[1:] is not None:
+        # Get feature_shape and deviation_shape before wrapping functions;
+        # they are required for the tf.function input signatures
+        if (
+            hasattr(self.feature_representation_model, "output_shape")
+            and self.feature_representation_model.output_shape[1:] is not None
+        ):
             # Assumes output_shape is something like (None, 5) or (None, 28, 28, 3)
             self.feature_output_shape = self.feature_representation_model.output_shape[1:]
             # Handle case where output_shape[1:] might be an empty tuple if model outputs a scalar.
             # Unlikely for feature extractor, but good to be robust.
             if not self.feature_output_shape:
-                logging.warning("feature_representation_model.output_shape[1:] is empty. Attempting dummy inference.")
+                logging.warning(
+                    "feature_representation_model.output_shape[1:] is empty. Attempting dummy inference."
+                )
                 # Fallback to dummy inference if output_shape is unexpectedly empty
-                dummy_input_shape = (1, *self.x_train.shape[1:]) # Use actual classifier input shape
-                dummy_input = self._tf_runtime.random.uniform(shape=dummy_input_shape, dtype=self._tf_runtime.float32)
+                dummy_input_shape = (
+                    1,
+                    *self.x_train.shape[1:],
+                )  # Use actual classifier input shape
+                dummy_input = self._tf_runtime.random.uniform(
+                    shape=dummy_input_shape, dtype=self._tf_runtime.float32
+                )
                 dummy_features = self.feature_representation_model(dummy_input)
                 self.feature_output_shape = dummy_features.shape[1:]
         else:
             # Fallback if output_shape attribute is not present or usable
-            logging.warning("feature_representation_model.output_shape not directly available. Performing dummy inference for shape.")
-            dummy_input_shape = (1, *self.x_train.shape[1:]) # Use actual classifier input shape
-            dummy_input = self._tf_runtime.random.uniform(shape=dummy_input_shape, dtype=self._tf_runtime.float32)
+            logging.warning(
+                "feature_representation_model.output_shape not directly available. "
+                "Performing dummy inference for shape."
+            )
+            dummy_input_shape = (1, *self.x_train.shape[1:])  # Use actual classifier input shape
+            dummy_input = self._tf_runtime.random.uniform(
+                shape=dummy_input_shape, dtype=self._tf_runtime.float32
+            )
             dummy_features = self.feature_representation_model(dummy_input)
             self.feature_output_shape = dummy_features.shape[1:]
 
         # If after all attempts, feature_output_shape is still empty/invalid, raise an error
         if not self.feature_output_shape or any(d is None for d in self.feature_output_shape):
             raise RuntimeError(
-                f"Could not determine a valid feature_output_shape ({self.feature_output_shape}) for tf.function input_signature. "
+                f"Could not determine a valid feature_output_shape ({self.feature_output_shape}) for "
+                f"tf.function input_signature. "
                 "Ensure feature_representation_model is built and outputs a known shape."
             )
-        # --- END SHAPE DETERMINATION ---
 
         # Dynamic @tf.function wrapping
         self._calculate_centroid_tf = self._tf_runtime.function(
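For context, a minimal standalone sketch of the slice, batch, prefetch pattern that the refactor above adopts; shapes and batch size are illustrative, and plain `tensorflow` stands in for the dynamically imported `tf_runtime`.

    import numpy as np
    import tensorflow as tf

    x = np.random.rand(256, 28, 28, 1).astype(np.float32)

    # Parenthesized method chain, matching the style introduced by this commit
    dataset = (
        tf.data.Dataset.from_tensor_slices(x)
        .batch(32)
        .prefetch(tf.data.AUTOTUNE)
    )

    for batch in dataset:
        print(batch.shape)  # (32, 28, 28, 1)
        break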
@@ -215,8 +251,12 @@ def __init__(
         self._predict_with_deviation = self._tf_runtime.function(
             self._predict_with_deviation_unwrapped,
             input_signature=[
-                self._tf_runtime.TensorSpec(shape=[None, *self.feature_output_shape], dtype=self._tf_runtime.float32),
-                self._tf_runtime.TensorSpec(shape=self.feature_output_shape, dtype=self._tf_runtime.float32),
+                self._tf_runtime.TensorSpec(
+                    shape=[None, *self.feature_output_shape], dtype=self._tf_runtime.float32
+                ),
+                self._tf_runtime.TensorSpec(
+                    shape=self.feature_output_shape, dtype=self._tf_runtime.float32
+                ),
             ],
             reduce_retracing=True,
         )
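For context, a minimal sketch of the wrapping pattern above under an assumed flat feature shape of (5,): pinning an input_signature fixes rank and dtype so varying batch sizes reuse one trace, and reduce_retracing=True further limits recompilation.

    import tensorflow as tf

    feature_shape = (5,)

    def _predict_unwrapped(features, deviation):
        # Same deviation trick as _predict_with_deviation_unwrapped,
        # without the classifying submodel
        return tf.nn.relu(features + deviation)

    predict = tf.function(
        _predict_unwrapped,
        input_signature=[
            tf.TensorSpec(shape=[None, *feature_shape], dtype=tf.float32),
            tf.TensorSpec(shape=feature_shape, dtype=tf.float32),
        ],
        reduce_retracing=True,
    )

    out = predict(tf.zeros([8, 5]), tf.ones([5]))  # batch size may vary freely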
@@ -288,7 +328,7 @@ def _feature_extraction(
         # Iterate directly over the provided input_dataset
         for batch_data_tf in x_train:
             batch_features = self._calculate_features(feature_representation_model, batch_data_tf)
-            features_np_list.append(batch_features.numpy()) # Convert to NumPy immediately
+            features_np_list.append(batch_features.numpy())  # Convert to NumPy immediately
 
         # Concatenate all batches of numpy arrays on CPU (system RAM)
         final_features_np: np.ndarray = np.concatenate(features_np_list, axis=0)
@@ -441,10 +481,10 @@ def evaluate_defence(self, is_clean: np.ndarray, **kwargs) -> str:
 
         return confusion_matrix_json
 
-    def _predict_with_deviation_unwrapped(self, features: tf_types.Tensor, deviation: tf_types.Tensor) -> tf_types.Tensor:
-        # Add deviation to features and pass through ReLu to keep in latent space
+    def _predict_with_deviation_unwrapped(
+        self, features: tf_types.Tensor, deviation: tf_types.Tensor
+    ) -> tf_types.Tensor:
         deviated_features = self._tf_runtime.nn.relu(features + deviation)
-        # Get predictions from classifying submodel
         return self.classifying_submodel(deviated_features, training=False)
 
     def _calculate_misclassification_rate(
@@ -462,7 +502,6 @@ def _calculate_misclassification_rate(
 
         total_elements = 0
         misclassified_elements = 0
-        all_features_np: list[np.ndarray] = []
 
         # Get all classes except the current one
         other_classes = self.unique_classes - {class_label}
@@ -476,27 +515,27 @@ def _calculate_misclassification_rate(
             if len(other_class_data) == 0:
                 continue
 
-            class_x_dataset = self._tf_runtime.data.Dataset.from_tensor_slices(
-                other_class_data.astype(self._tf_runtime.float32.as_numpy_dtype)
-            ).batch(self.misclassification_batch_size).prefetch(self._tf_runtime.data.AUTOTUNE)
+            class_x_dataset = (
+                self._tf_runtime.data.Dataset.from_tensor_slices(
+                    other_class_data.astype(self._tf_runtime.float32.as_numpy_dtype)
+                )
+                .batch(self.misclassification_batch_size)
+                .prefetch(self._tf_runtime.data.AUTOTUNE)
+            )
 
             total_elements += len(other_class_data)
             class_misclassified = 0
 
+            # Batches of the class are processed with deviation to determine misclassification
             for batch_data_tf in class_x_dataset:
-                # Extract features
                 features_tf = self._calculate_features(
                     self.feature_representation_model, batch_data_tf
                 )
-                all_features_np.append(features_tf.numpy())
 
-                # Get predictions with deviation
                 predictions = self._predict_with_deviation(features_tf, deviation_tf)
 
-                # Convert predictions to class indices
                 pred_classes_np = self._tf_runtime.argmax(predictions, axis=1).numpy()
 
-                # Count misclassifications (predicted as class_label)
                 batch_misclassified = np.sum(pred_classes_np == class_label)
                 class_misclassified += batch_misclassified
 
@@ -506,11 +545,9 @@ def _calculate_misclassification_rate(
         if total_elements == 0:
             return np.float64(0.0)
 
-        all_f_vectors_np = np.concatenate(all_features_np, axis=0)
         logging.debug(
-            "MR --> %s , |f| = %s: %s / %s = %s",
+            "MR --> %s : %s / %s = %s",
             class_label,
-            np.linalg.norm(np.mean(all_f_vectors_np, axis=0)),
             misclassified_elements,
             total_elements,
             np.float64(misclassified_elements) / np.float64(total_elements),
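For context, a worked example of the rate logged above, with illustrative predictions: three of six samples from other classes land on the candidate class, so MR = 3 / 6 = 0.5.

    import numpy as np

    class_label = 3
    pred_classes_np = np.array([3, 1, 3, 2, 3, 0])  # illustrative predictions

    misclassified_elements = np.sum(pred_classes_np == class_label)
    total_elements = len(pred_classes_np)
    print(np.float64(misclassified_elements) / np.float64(total_elements))  # 0.5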
@@ -519,14 +556,16 @@ def _calculate_misclassification_rate(
         return np.float64(misclassified_elements) / np.float64(total_elements)
 
     def detect_poison(self, **kwargs) -> tuple[dict, list[int]]:
-        # saves important information about the algorithm execution for further analysis
+        # Saves important information about the algorithm execution for further analysis
         report: dict[str, Any] = {}
 
         self.is_clean_np = np.ones(len(self.y_train))
 
-        self.features = self._feature_extraction(self.x_train_dataset, self.feature_representation_model)
+        self.features = self._feature_extraction(
+            self.x_train_dataset, self.feature_representation_model
+        )
 
-        # FIXME: temporal fix to test other layers
+        # Small fix to add flexibility to use CCAUD in non-flattened scenarios. Not recommended, but available
         if len(self.features.shape) > 2:
             num_samples = self.features.shape[0]
             self.features = self.features.reshape(num_samples, -1) # Flattening
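For context, the flattening fallback in isolation, with illustrative conv-style shapes: feature maps of shape (N, H, W, C) are reshaped to (N, H*W*C) so the reducer and clusterer can consume them.

    import numpy as np

    features = np.random.rand(100, 7, 7, 64)  # e.g. output of a convolutional feature layer
    if len(features.shape) > 2:
        num_samples = features.shape[0]
        features = features.reshape(num_samples, -1)  # Flattening
    print(features.shape)  # (100, 3136)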
