Commit c1e3f5d

Improved feature extraction with TF Dataset
Signed-off-by: Álvaro Bacca Peña <[email protected]>
1 parent 7bf76df commit c1e3f5d

File tree

2 files changed: +160 -121 lines changed

art/defences/detector/poison/clustering_centroid_analysis.py

Lines changed: 88 additions & 66 deletions
@@ -25,7 +25,7 @@
 
 import logging
 import warnings
-from typing import TYPE_CHECKING, Any
+from typing import TYPE_CHECKING, Any, Callable
 
 import numpy as np
 
@@ -46,6 +46,7 @@
 tf_types = Any
 ClassifierType = Any
 
+
 def _encode_labels(y: np.ndarray) -> tuple[np.ndarray, set, np.ndarray, dict]:
     """
     Given the target column, it generates the label encoding and the reverse mapping to use in the
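
The hunk above only adds a blank line, but the visible signature of `_encode_labels` summarises its job: map raw labels to integer codes and return the encoded column, the set of class codes, and the forward/reverse mappings. Its body is not part of this diff; the following is a minimal hypothetical sketch of how such a helper is typically written, not the commit's actual code:

import numpy as np

def encode_labels(y: np.ndarray) -> tuple[np.ndarray, set, dict, dict]:
    # Hypothetical stand-in for _encode_labels; the real body is not shown in this diff.
    classes, encoded = np.unique(y, return_inverse=True)  # stable integer codes per label
    class_mapping = {label: code for code, label in enumerate(classes)}
    reverse_class_mapping = {code: label for code, label in enumerate(classes)}
    return encoded, set(range(len(classes))), class_mapping, reverse_class_mapping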
@@ -87,15 +88,17 @@ class ClusteringCentroidAnalysisTensorFlowV2(PoisonFilteringDefence):
     valid_reduce = ["UMAP"]
 
     def __init__(
-            self,
-            classifier: CLASSIFIER_TYPE,
-            x_train: np.ndarray,
-            y_train: np.ndarray,
-            benign_indices: np.ndarray,
-            final_feature_layer_name: str,
-            misclassification_threshold: float,
-            reducer: UMAP | None = None,
-            clusterer: ClusterMixin | None = None,
+        self,
+        classifier: CLASSIFIER_TYPE,
+        x_train: np.ndarray,
+        y_train: np.ndarray,
+        benign_indices: np.ndarray,
+        final_feature_layer_name: str,
+        misclassification_threshold: float,
+        reducer: UMAP | None = None,
+        clusterer: ClusterMixin | None = None,
+        feature_extraction_batch_size: int = 32,
+        misclassification_batch_size: int = 32,
     ):
         """
         Creates a :class: `ClusteringCentroidAnalysis` object for the given classifier
@@ -110,19 +113,19 @@ def __init__(
         try:
             import tensorflow as tf_runtime
             from tensorflow.keras import Model, Sequential
+
             self._tf_runtime = tf_runtime
             self._KerasModel = Model
             self._KerasSequential = Sequential
             self._tf_runtime.get_logger().setLevel(logging.WARN)
         except ImportError as e:
-            raise ImportError(
-                "TensorFlow is required for ClusteringCentroidAnalysis. "
-            ) from e
+            raise ImportError("TensorFlow is required for ClusteringCentroidAnalysis. ") from e
 
         if clusterer is None:
             try:
                 from sklearn.base import ClusterMixin
                 from sklearn.cluster import DBSCAN
+
                 self.clusterer = DBSCAN(eps=0.8, min_samples=20)
             except ImportError as e:
                 raise ImportError(
@@ -132,6 +135,7 @@ def __init__(
         if reducer is None:
             try:
                 from umap import UMAP
+
                 self.reducer = UMAP(n_neighbors=5, min_dist=0)
             except ImportError as e:
                 raise ImportError(
@@ -142,13 +146,30 @@ def __init__(
         super().__init__(classifier, x_train, y_train)
         self.benign_indices = benign_indices
         (
-            self.y_train,
+            self.y_train_np,
             self.unique_classes,
             self.class_mapping,
             self.reverse_class_mapping,
         ) = _encode_labels(y_train)
 
-        self.x_benign, self.y_benign = self._get_benign_data()
+        # Data is loaded as NumPy arrays first
+        self.x_benign_np, self.y_benign_np = self._get_benign_data()
+
+        self.feature_extraction_batch_size = feature_extraction_batch_size
+        self.misclassification_batch_size = misclassification_batch_size
+
+        # Create x_train_dataset for feature extraction
+        # Ensure data type is float32 for TensorFlow
+        self.x_train_dataset = self._tf_runtime.data.Dataset.from_tensor_slices(
+            self.x_train.astype(self._tf_runtime.float32.as_numpy_dtype)
+        ).batch(self.feature_extraction_batch_size).prefetch(self._tf_runtime.data.AUTOTUNE)
+
+        # Create x_benign_dataset for misclassification rate calculation
+        # Ensure data type is float32 for TensorFlow
+        self.x_benign_dataset = self._tf_runtime.data.Dataset.from_tensor_slices(
+            self.x_benign_np.astype(self._tf_runtime.float32.as_numpy_dtype)
+        ).batch(self.misclassification_batch_size).prefetch(self._tf_runtime.data.AUTOTUNE)
+        # ---------------------------------------------
 
         self.feature_representation_model, self.classifying_submodel = self._extract_submodels(
             final_feature_layer_name
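
The hunk above is the core of the commit: the defence now builds two tf.data pipelines once in __init__ (one per batch size) instead of slicing NumPy arrays by hand later. A self-contained sketch of the same pattern, with an illustrative array shape and batch size:

import numpy as np
import tensorflow as tf

x_train = np.random.rand(1000, 32, 32, 3)  # illustrative training inputs

# Cast once to float32, then let tf.data handle batching and prefetching;
# prefetch(AUTOTUNE) overlaps host-side batch preparation with model execution.
dataset = (
    tf.data.Dataset.from_tensor_slices(x_train.astype(np.float32))
    .batch(32)
    .prefetch(tf.data.AUTOTUNE)
)

for batch in dataset.take(1):
    print(batch.shape)  # (32, 32, 32, 3)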
@@ -157,19 +178,16 @@ def __init__(
         self.misclassification_threshold = np.float64(misclassification_threshold)
 
         # Dynamic @tf.function wrapping
-        self._calculate_centroid_tf_original = self._calculate_centroid_tf
-        self._calculate_features_original = self._calculate_features
-
-        self._calculate_centroid_tf = self._tf_runtime.function(self._calculate_centroid_tf_original, reduce_retracing=True)
-        self._calculate_features = self._tf_runtime.function(self._calculate_features_original)
-
+        self._calculate_centroid_tf = self._tf_runtime.function(
+            self._calculate_centroid_tf_unwrapped, reduce_retracing=True
+        )
+        self._calculate_features = self._tf_runtime.function(self._calculate_features_unwrapped)
 
         logging.info("CCA object created successfully.")
 
-    def _calculate_centroid_tf(self, features):
+    def _calculate_centroid_tf_unwrapped(self, features):
         return self._tf_runtime.reduce_mean(features, axis=0)
 
-
     def _calculate_centroid(self, selected_indices: np.ndarray, features: np.ndarray) -> np.ndarray:
         """
         Returns the centroid of all data within a specific cluster that is classified as a specific class label
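
The rename to `_calculate_centroid_tf_unwrapped` removes the earlier self-shadowing dance (saving `*_original` aliases before overwriting the methods with their compiled versions): the plain Python method now keeps a distinct name, and the tf.function-compiled callable is assigned separately. A sketch of that wrapping pattern:

import tensorflow as tf

class Example:
    def __init__(self):
        # Compile the plain method once per instance; reduce_retracing=True
        # limits recompilation when input shapes vary between calls.
        self._mean = tf.function(self._mean_unwrapped, reduce_retracing=True)

    def _mean_unwrapped(self, features):
        return tf.reduce_mean(features, axis=0)

example = Example()
print(example._mean(tf.constant([[1.0, 2.0], [3.0, 4.0]])))  # [2. 3.]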
@@ -179,12 +197,15 @@ def _calculate_centroid(self, selected_indices: np.ndarray, features: np.ndarray
         :return: d-dimensional numpy array
         """
         selected_features = features[selected_indices]
-        features_tf = self._tf_runtime.convert_to_tensor(selected_features, dtype=self._tf_runtime.float32)
+        features_tf = self._tf_runtime.convert_to_tensor(
+            selected_features, dtype=self._tf_runtime.float32
+        )
         centroid = self._calculate_centroid_tf(features_tf)
         return centroid.numpy()
 
-
-    def _class_clustering(self, y: np.ndarray, features: np.ndarray, label: int | str, clusterer: ClusterMixin) -> tuple[np.ndarray, np.ndarray]:
+    def _class_clustering(
+        self, y: np.ndarray, features: np.ndarray, label: int | str, clusterer: ClusterMixin
+    ) -> tuple[np.ndarray, np.ndarray]:
         """
         Given a class label, it clusters all the feature representations that map to that class
 
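
For reference, the centroid computed in `_calculate_centroid` is just the per-dimension mean of the selected feature rows; a quick NumPy check with illustrative values:

import numpy as np

features = np.array([[0.0, 2.0], [4.0, 2.0], [8.0, 8.0]])
selected_indices = np.array([0, 1])
print(features[selected_indices].mean(axis=0))  # [2. 2.], matching tf.reduce_mean(..., axis=0)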
@@ -201,8 +222,9 @@ def _class_clustering(self, y: np.ndarray, features: np.ndarray, label: int | st
         cluster_labels = clusterer.fit_predict(selected_features)
         return cluster_labels, selected_indices
 
-
-    def _calculate_features(self, feature_representation_model: Model, x: np.ndarray) -> np.ndarray:
+    def _calculate_features_unwrapped(
+        self, feature_representation_model: Model, x: np.ndarray
+    ) -> np.ndarray:
         """
         Calculates the features using the first DNN slice
 
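
`_class_clustering` restricts the clusterer to the rows whose label matches, so each class is clustered independently in feature space. A sketch with scikit-learn's DBSCAN (the constructor's fallback is DBSCAN(eps=0.8, min_samples=20); the toy data and looser parameters below are illustrative):

import numpy as np
from sklearn.cluster import DBSCAN

y = np.array([0, 0, 1, 0, 1])
features = np.random.rand(5, 8).astype(np.float32)

label = 0
selected_indices = np.where(y == label)[0]       # rows belonging to this class
selected_features = features[selected_indices]
cluster_labels = DBSCAN(eps=0.8, min_samples=2).fit_predict(selected_features)
print(cluster_labels, selected_indices)          # -1 marks DBSCAN noise points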
@@ -212,44 +234,35 @@ def _calculate_features(self, feature_representation_model: Model, x: np.ndarray
         """
         return feature_representation_model(x, training=False)
 
-
-    def _feature_extraction(self, x_train: np.ndarray, feature_representation_model: Model) -> np.ndarray:
+    def _feature_extraction(
+        self, x_train: tf_types.data.Dataset, feature_representation_model: Model
+    ) -> np.ndarray:
         """
         Extract features from the model using the feature representation sub model.
 
-        :param x_train: numpy array d-dimensional features for n data entries. Features are extracted from here
+        :param x_train: Tensorflow dataset with features for n data entries. Features are extracted from here
         :param feature_representation_model: DNN submodel from input up to feature abstraction
         :return: features. numpy array of features
         """
         # Convert data to TensorFlow tensors if needed
-        data = x_train
-        if not isinstance(x_train, self._tf_runtime.Tensor):
-            data = self._tf_runtime.convert_to_tensor(x_train, dtype=self._tf_runtime.float32)
+        features_np_list: list[np.ndarray] = []
 
-        # Process in batches to avoid memory issues
-        batch_size = 256
-        num_batches = int(np.ceil(len(data) / batch_size))
-        features: list[tf_types.Tensor] = []
+        # Iterate directly over the provided input_dataset
+        for batch_data_tf in x_train:
+            batch_features = self._calculate_features(feature_representation_model, batch_data_tf)
+            features_np_list.append(batch_features.numpy())  # Convert to NumPy immediately
 
-        for i in range(num_batches):
-            start_idx = i * batch_size
-            end_idx = min((i + 1) * batch_size, len(data))
-            batch = data[start_idx:end_idx]
-            batch_features = self._calculate_features(feature_representation_model, batch)
-            features.append(batch_features)
-
-        # Concatenate all batches
-        final_features_tensor: tf_types.Tensor = self._tf_runtime.concat(features, axis=0)
-
-        return final_features_tensor.numpy()
+        # Concatenate all batches of numpy arrays on CPU (system RAM)
+        final_features_np: np.ndarray = np.concatenate(features_np_list, axis=0)
 
+        return final_features_np
 
     def _cluster_classes(
-            self,
-            y_train: np.ndarray,
-            unique_classes: set[int],
-            features: np.ndarray,
-            clusterer: ClusterMixin,
+        self,
+        y_train: np.ndarray,
+        unique_classes: set[int],
+        features: np.ndarray,
+        clusterer: ClusterMixin,
     ) -> tuple[np.ndarray, dict]:
         """
         Clusters all the classes in the given dataset into uniquely identifiable clusters.
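
The rewritten `_feature_extraction` above no longer does index arithmetic: it iterates over the dataset's batches and calls `.numpy()` on each result, so feature tensors never accumulate on the accelerator and the final concatenation happens in system RAM. A standalone sketch of the loop (the tiny Dense model and shapes are illustrative):

import numpy as np
import tensorflow as tf

feature_model = tf.keras.Sequential([tf.keras.layers.Dense(4)])  # stand-in feature extractor
dataset = tf.data.Dataset.from_tensor_slices(
    np.random.rand(100, 8).astype(np.float32)
).batch(32)

chunks: list[np.ndarray] = []
for batch in dataset:
    # .numpy() copies each batch of features to host memory immediately,
    # so only NumPy arrays (not device tensors) are held across iterations.
    chunks.append(feature_model(batch, training=False).numpy())

features = np.concatenate(chunks, axis=0)
print(features.shape)  # (100, 4)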
@@ -337,7 +350,9 @@ def _extract_submodels(self, final_feature_layer_name: str) -> tuple[Model, Mode
         classifier_submodel_layers = keras_model.layers[final_feature_layer_index + 1 :]
 
         # Create the classifier submodel
-        classifying_submodel = self._KerasSequential(classifier_submodel_layers, name="classifying_submodel")
+        classifying_submodel = self._KerasSequential(
+            classifier_submodel_layers, name="classifying_submodel"
+        )
 
         intermediate_shape = feature_representation_model.output_shape[1:]
         dummy_input = self._tf_runtime.zeros((1,) + intermediate_shape)
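
`_extract_submodels` splits the trained network at `final_feature_layer_name`: the layers up to it become the feature extractor, the remaining layers are re-wrapped in a Sequential head, and a dummy tensor of the intermediate shape builds that head. A sketch of the split on a toy model (layer names and sizes are illustrative, and depending on the Keras version, reusing layers like this requires the donor model to be built):

import tensorflow as tf

model = tf.keras.Sequential([
    tf.keras.layers.Dense(16, activation="relu", input_shape=(8,), name="hidden"),
    tf.keras.layers.Dense(8, activation="relu", name="features"),
    tf.keras.layers.Dense(3, activation="softmax", name="output"),
])

split_index = [layer.name for layer in model.layers].index("features")
feature_model = tf.keras.Model(model.input, model.get_layer("features").output)
classifier_head = tf.keras.Sequential(model.layers[split_index + 1:], name="classifying_submodel")

# Mirror the dummy_input trick from the diff: one zero tensor of the
# intermediate shape builds the head's weights and output shapes.
_ = classifier_head(tf.zeros((1,) + feature_model.output_shape[1:]))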
@@ -398,6 +413,7 @@ def _calculate_misclassification_rate(
         :param deviation: The deviation vector to apply
         :return: The misclassification rate (0.0 to 1.0)
         """
+
         def _predict_with_deviation_inner(features, deviation):
             # Add deviation to features and pass through ReLu to keep in latent space
             deviated_features = self._tf_runtime.nn.relu(features + deviation)
@@ -416,13 +432,13 @@ def _predict_with_deviation_inner(features, deviation):
         sample_features = self.feature_representation_model.predict(sample_data)
         feature_shape = sample_features.shape[1:]
 
-        # predict_with_deviation = self._tf_runtime.function(
-        #     _predict_with_deviation_inner,
-        #     input_signature=[
-        #         self._tf_runtime.TensorSpec(shape=[None, *feature_shape], dtype=self._tf_runtime.float32),
-        #         self._tf_runtime.TensorSpec(shape=deviation.shape, dtype=self._tf_runtime.float32),
-        #     ]
-        # )
+        predict_with_deviation = self._tf_runtime.function(
+            _predict_with_deviation_inner,
+            input_signature=[
+                self._tf_runtime.TensorSpec(shape=[None, *feature_shape], dtype=self._tf_runtime.float32),
+                self._tf_runtime.TensorSpec(shape=deviation.shape, dtype=self._tf_runtime.float32),
+            ]
+        )
 
         total_elements = 0
         misclassified_elements = 0
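
This hunk revives the previously commented-out wrapper and, together with the next hunk, routes predictions through the compiled `predict_with_deviation` instead of the raw inner function. Pinning an `input_signature` with a `None` leading dimension means one traced graph serves every batch size. A sketch (the latent width is illustrative, and the real inner function also feeds the deviated features through the classifying submodel):

import tensorflow as tf

feature_dim = 8  # illustrative latent width

def _predict_with_deviation_inner(features, deviation):
    # Shift the features and clamp with ReLU to stay in the latent space.
    return tf.nn.relu(features + deviation)

predict_with_deviation = tf.function(
    _predict_with_deviation_inner,
    input_signature=[
        tf.TensorSpec(shape=[None, feature_dim], dtype=tf.float32),  # any batch size
        tf.TensorSpec(shape=[feature_dim], dtype=tf.float32),
    ],
)

out = predict_with_deviation(tf.zeros((5, feature_dim)), tf.ones((feature_dim,)))
print(out.shape)  # (5, 8): the same traced graph handles every batch size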
@@ -456,14 +472,18 @@ def _predict_with_deviation_inner(features, deviation):
                 batch_data = other_class_data[start_idx:end_idx]
 
                 # Convert to tensor
-                batch_data_tf = self._tf_runtime.convert_to_tensor(batch_data, dtype=self._tf_runtime.float32)
+                batch_data_tf = self._tf_runtime.convert_to_tensor(
+                    batch_data, dtype=self._tf_runtime.float32
+                )
 
                 # Extract features
-                features = self._calculate_features(self.feature_representation_model, batch_data_tf)
+                features = self._calculate_features(
+                    self.feature_representation_model, batch_data_tf
+                )
                 all_features.append(features)
 
                 # Get predictions with deviation
-                predictions = _predict_with_deviation_inner(features, deviation_tf)
+                predictions = predict_with_deviation(features, deviation_tf)
 
                 # Convert predictions to class indices
                 pred_classes = self._tf_runtime.argmax(predictions, axis=1).numpy()
@@ -538,7 +558,9 @@ def detect_poison(self, **kwargs) -> tuple[dict, list[int]]:
             benign_class_indices = np.intersect1d(
                 self.benign_indices, np.where(self.y_train == class_label)[0]
            )
-            benign_centroids[class_label] = self._calculate_centroid(benign_class_indices, self.features)
+            benign_centroids[class_label] = self._calculate_centroid(
+                benign_class_indices, self.features
+            )
 
         logging.info("Calculating misclassification rates...")
         misclassification_rates = {}
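
The benign centroid for each class is taken over the intersection of the trusted benign indices with the indices of that class; `np.intersect1d` performs exactly that selection. A quick check with illustrative values:

import numpy as np

benign_indices = np.array([0, 2, 3, 7])       # trusted-clean rows
y_train = np.array([1, 0, 1, 1, 0, 1, 0, 1])

class_label = 1
benign_class_indices = np.intersect1d(benign_indices, np.where(y_train == class_label)[0])
print(benign_class_indices)  # [0 2 3 7]: benign rows that also carry class 1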
