 import logging
 import warnings
-from typing import TYPE_CHECKING, Any
+from typing import TYPE_CHECKING, Any, Callable

 import numpy as np

 tf_types = Any
 ClassifierType = Any

+
 def _encode_labels(y: np.ndarray) -> tuple[np.ndarray, set, np.ndarray, dict]:
     """
     Given the target column, it generates the label encoding and the reverse mapping to use in the
@@ -87,15 +88,17 @@ class ClusteringCentroidAnalysisTensorFlowV2(PoisonFilteringDefence):
     valid_reduce = ["UMAP"]

     def __init__(
-        self,
-        classifier: CLASSIFIER_TYPE,
-        x_train: np.ndarray,
-        y_train: np.ndarray,
-        benign_indices: np.ndarray,
-        final_feature_layer_name: str,
-        misclassification_threshold: float,
-        reducer: UMAP | None = None,
-        clusterer: ClusterMixin | None = None,
+        self,
+        classifier: CLASSIFIER_TYPE,
+        x_train: np.ndarray,
+        y_train: np.ndarray,
+        benign_indices: np.ndarray,
+        final_feature_layer_name: str,
+        misclassification_threshold: float,
+        reducer: UMAP | None = None,
+        clusterer: ClusterMixin | None = None,
+        feature_extraction_batch_size: int = 32,
+        misclassification_batch_size: int = 32,
     ):
         """
         Creates a :class:`ClusteringCentroidAnalysis` object for the given classifier
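
For context, a minimal usage sketch of the two new batch-size parameters; the classifier, arrays, and layer name below are placeholders, not values taken from this diff:

# Hypothetical call; all variable values here are illustrative only.
defence = ClusteringCentroidAnalysisTensorFlowV2(
    classifier=classifier,                # an ART TensorFlow v2 classifier (placeholder)
    x_train=x_train,
    y_train=y_train,
    benign_indices=benign_indices,        # indices known to be clean (placeholder)
    final_feature_layer_name="flatten",   # assumed layer name
    misclassification_threshold=0.05,
    feature_extraction_batch_size=64,     # new: batch size for feature extraction
    misclassification_batch_size=64,      # new: batch size for misclassification-rate passes
)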
@@ -110,19 +113,19 @@ def __init__(
         try:
             import tensorflow as tf_runtime
             from tensorflow.keras import Model, Sequential
+
             self._tf_runtime = tf_runtime
             self._KerasModel = Model
             self._KerasSequential = Sequential
             self._tf_runtime.get_logger().setLevel(logging.WARN)
         except ImportError as e:
-            raise ImportError(
-                "TensorFlow is required for ClusteringCentroidAnalysis. "
-            ) from e
+            raise ImportError("TensorFlow is required for ClusteringCentroidAnalysis. ") from e

         if clusterer is None:
             try:
                 from sklearn.base import ClusterMixin
                 from sklearn.cluster import DBSCAN
+
                 self.clusterer = DBSCAN(eps=0.8, min_samples=20)
             except ImportError as e:
                 raise ImportError(
@@ -132,6 +135,7 @@ def __init__(
         if reducer is None:
             try:
                 from umap import UMAP
+
                 self.reducer = UMAP(n_neighbors=5, min_dist=0)
             except ImportError as e:
                 raise ImportError(
@@ -142,13 +146,30 @@ def __init__(
         super().__init__(classifier, x_train, y_train)
         self.benign_indices = benign_indices
         (
-            self.y_train,
+            self.y_train_np,
             self.unique_classes,
             self.class_mapping,
             self.reverse_class_mapping,
         ) = _encode_labels(y_train)

-        self.x_benign, self.y_benign = self._get_benign_data()
+        # Data is loaded as NumPy arrays first
+        self.x_benign_np, self.y_benign_np = self._get_benign_data()
+
+        self.feature_extraction_batch_size = feature_extraction_batch_size
+        self.misclassification_batch_size = misclassification_batch_size
+
+        # Create x_train_dataset for feature extraction
+        # Ensure data type is float32 for TensorFlow
+        self.x_train_dataset = self._tf_runtime.data.Dataset.from_tensor_slices(
+            self.x_train.astype(self._tf_runtime.float32.as_numpy_dtype)
+        ).batch(self.feature_extraction_batch_size).prefetch(self._tf_runtime.data.AUTOTUNE)
+
+        # Create x_benign_dataset for misclassification rate calculation
+        # Ensure data type is float32 for TensorFlow
+        self.x_benign_dataset = self._tf_runtime.data.Dataset.from_tensor_slices(
+            self.x_benign_np.astype(self._tf_runtime.float32.as_numpy_dtype)
+        ).batch(self.misclassification_batch_size).prefetch(self._tf_runtime.data.AUTOTUNE)

         self.feature_representation_model, self.classifying_submodel = self._extract_submodels(
             final_feature_layer_name
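
For reference, a self-contained sketch of the from_tensor_slices → batch → prefetch pattern built above; the array shape and batch size are illustrative assumptions:

import numpy as np
import tensorflow as tf

x = np.random.rand(1_000, 32, 32, 3).astype(np.float32)  # placeholder data
dataset = (
    tf.data.Dataset.from_tensor_slices(x)  # one dataset element per training sample
    .batch(32)                             # group samples into fixed-size batches
    .prefetch(tf.data.AUTOTUNE)            # overlap host-side prep with device compute
)
for batch in dataset:                      # each batch: tf.Tensor of shape (<=32, 32, 32, 3)
    pass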
@@ -157,19 +178,16 @@ def __init__(
         self.misclassification_threshold = np.float64(misclassification_threshold)

         # Dynamic @tf.function wrapping
-        self._calculate_centroid_tf_original = self._calculate_centroid_tf
-        self._calculate_features_original = self._calculate_features
-
-        self._calculate_centroid_tf = self._tf_runtime.function(self._calculate_centroid_tf_original, reduce_retracing=True)
-        self._calculate_features = self._tf_runtime.function(self._calculate_features_original)
-
+        self._calculate_centroid_tf = self._tf_runtime.function(
+            self._calculate_centroid_tf_unwrapped, reduce_retracing=True
+        )
+        self._calculate_features = self._tf_runtime.function(self._calculate_features_unwrapped)

         logging.info("CCA object created successfully.")

-    def _calculate_centroid_tf(self, features):
+    def _calculate_centroid_tf_unwrapped(self, features):
         return self._tf_runtime.reduce_mean(features, axis=0)
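
A minimal sketch of the per-instance tf.function wrapping pattern adopted here; the class and method names are illustrative. Renaming the eager method to *_unwrapped avoids rebinding a name to its own wrapper, which the removed code did:

import tensorflow as tf

class Wrapper:
    def __init__(self):
        # Wrapping the bound method at construction time gives each instance its
        # own compiled function; reduce_retracing=True lets TF relax input
        # signatures instead of retracing for every new tensor shape.
        self.calc_centroid = tf.function(self._calc_centroid_unwrapped, reduce_retracing=True)

    def _calc_centroid_unwrapped(self, features):
        return tf.reduce_mean(features, axis=0)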

-
     def _calculate_centroid(self, selected_indices: np.ndarray, features: np.ndarray) -> np.ndarray:
         """
         Returns the centroid of all data within a specific cluster that is classified as a specific class label
@@ -179,12 +197,15 @@ def _calculate_centroid(self, selected_indices: np.ndarray, features: np.ndarray
         :return: d-dimensional numpy array
         """
         selected_features = features[selected_indices]
-        features_tf = self._tf_runtime.convert_to_tensor(selected_features, dtype=self._tf_runtime.float32)
+        features_tf = self._tf_runtime.convert_to_tensor(
+            selected_features, dtype=self._tf_runtime.float32
+        )
         centroid = self._calculate_centroid_tf(features_tf)
         return centroid.numpy()

-
-    def _class_clustering(self, y: np.ndarray, features: np.ndarray, label: int | str, clusterer: ClusterMixin) -> tuple[np.ndarray, np.ndarray]:
+    def _class_clustering(
+        self, y: np.ndarray, features: np.ndarray, label: int | str, clusterer: ClusterMixin
+    ) -> tuple[np.ndarray, np.ndarray]:
         """
         Given a class label, it clusters all the feature representations that map to that class
@@ -201,8 +222,9 @@ def _class_clustering(self, y: np.ndarray, features: np.ndarray, label: int | st
         cluster_labels = clusterer.fit_predict(selected_features)
         return cluster_labels, selected_indices
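
A standalone sketch of this per-class clustering step, assuming scikit-learn and the DBSCAN defaults set in __init__; y, features, and label are placeholders:

import numpy as np
from sklearn.cluster import DBSCAN

selected_indices = np.where(y == label)[0]        # rows belonging to this class
selected_features = features[selected_indices]
cluster_labels = DBSCAN(eps=0.8, min_samples=20).fit_predict(selected_features)
# cluster_labels[i] == -1 marks noise; other values are within-class cluster ids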

-
-    def _calculate_features(self, feature_representation_model: Model, x: np.ndarray) -> np.ndarray:
+    def _calculate_features_unwrapped(
+        self, feature_representation_model: Model, x: np.ndarray
+    ) -> np.ndarray:
         """
         Calculates the features using the first DNN slice
@@ -212,44 +234,35 @@ def _calculate_features(self, feature_representation_model: Model, x: np.ndarray
         """
         return feature_representation_model(x, training=False)

-    def _feature_extraction(self, x_train: np.ndarray, feature_representation_model: Model) -> np.ndarray:
+    def _feature_extraction(
+        self, x_train: tf_types.data.Dataset, feature_representation_model: Model
+    ) -> np.ndarray:
         """
         Extract features from the model using the feature representation sub model.

-        :param x_train: numpy array d-dimensional features for n data entries. Features are extracted from here
+        :param x_train: TensorFlow dataset with features for n data entries. Features are extracted from here
         :param feature_representation_model: DNN submodel from input up to feature abstraction
         :return: features. numpy array of features
         """
         # Convert data to TensorFlow tensors if needed
-        data = x_train
-        if not isinstance(x_train, self._tf_runtime.Tensor):
-            data = self._tf_runtime.convert_to_tensor(x_train, dtype=self._tf_runtime.float32)
+        features_np_list: list[np.ndarray] = []

-        # Process in batches to avoid memory issues
-        batch_size = 256
-        num_batches = int(np.ceil(len(data) / batch_size))
-        features: list[tf_types.Tensor] = []
+        # Iterate directly over the provided dataset
+        for batch_data_tf in x_train:
+            batch_features = self._calculate_features(feature_representation_model, batch_data_tf)
+            features_np_list.append(batch_features.numpy())  # Convert to NumPy immediately

-        for i in range(num_batches):
-            start_idx = i * batch_size
-            end_idx = min((i + 1) * batch_size, len(data))
-            batch = data[start_idx:end_idx]
-            batch_features = self._calculate_features(feature_representation_model, batch)
-            features.append(batch_features)
-
-        # Concatenate all batches
-        final_features_tensor: tf_types.Tensor = self._tf_runtime.concat(features, axis=0)
-
-        return final_features_tensor.numpy()
+        # Concatenate all batches of NumPy arrays on CPU (system RAM)
+        final_features_np: np.ndarray = np.concatenate(features_np_list, axis=0)

+        return final_features_np
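
The new loop trades accelerator memory for host memory: each batch of activations is copied to NumPy immediately, so at most one batch of features lives on the device at a time. A condensed sketch, where submodel and dataset are placeholders:

import numpy as np

feats: list[np.ndarray] = []
for batch in dataset:                                      # tf.data.Dataset of input batches
    feats.append(submodel(batch, training=False).numpy())  # .numpy() forces a host copy
features = np.concatenate(feats, axis=0)                   # single CPU-side concatenation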

     def _cluster_classes(
-        self,
-        y_train: np.ndarray,
-        unique_classes: set[int],
-        features: np.ndarray,
-        clusterer: ClusterMixin,
+        self,
+        y_train: np.ndarray,
+        unique_classes: set[int],
+        features: np.ndarray,
+        clusterer: ClusterMixin,
     ) -> tuple[np.ndarray, dict]:
         """
         Clusters all the classes in the given dataset into uniquely identifiable clusters.
@@ -337,7 +350,9 @@ def _extract_submodels(self, final_feature_layer_name: str) -> tuple[Model, Mode
         classifier_submodel_layers = keras_model.layers[final_feature_layer_index + 1:]

         # Create the classifier submodel
-        classifying_submodel = self._KerasSequential(classifier_submodel_layers, name="classifying_submodel")
+        classifying_submodel = self._KerasSequential(
+            classifier_submodel_layers, name="classifying_submodel"
+        )

         intermediate_shape = feature_representation_model.output_shape[1:]
         dummy_input = self._tf_runtime.zeros((1,) + intermediate_shape)
@@ -398,6 +413,7 @@ def _calculate_misclassification_rate(
         :param deviation: The deviation vector to apply
         :return: The misclassification rate (0.0 to 1.0)
         """
+
         def _predict_with_deviation_inner(features, deviation):
             # Add deviation to features and pass through ReLU to keep in latent space
             deviated_features = self._tf_runtime.nn.relu(features + deviation)
@@ -416,13 +432,13 @@ def _predict_with_deviation_inner(features, deviation):
         sample_features = self.feature_representation_model.predict(sample_data)
         feature_shape = sample_features.shape[1:]

-        # predict_with_deviation = self._tf_runtime.function(
-        #     _predict_with_deviation_inner,
-        #     input_signature=[
-        #         self._tf_runtime.TensorSpec(shape=[None, *feature_shape], dtype=self._tf_runtime.float32),
-        #         self._tf_runtime.TensorSpec(shape=deviation.shape, dtype=self._tf_runtime.float32),
-        #     ]
-        # )
+        predict_with_deviation = self._tf_runtime.function(
+            _predict_with_deviation_inner,
+            input_signature=[
+                self._tf_runtime.TensorSpec(shape=[None, *feature_shape], dtype=self._tf_runtime.float32),
+                self._tf_runtime.TensorSpec(shape=deviation.shape, dtype=self._tf_runtime.float32),
+            ],
+        )

         total_elements = 0
         misclassified_elements = 0
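
A minimal sketch of the input_signature pinning re-enabled above; fixing shape and dtype yields a single trace that is reused for every batch (the 128-dimensional feature size is an assumption):

import tensorflow as tf

predict = tf.function(
    lambda features, deviation: tf.nn.relu(features + deviation),
    input_signature=[
        tf.TensorSpec(shape=[None, 128], dtype=tf.float32),  # variable batch dimension
        tf.TensorSpec(shape=[128], dtype=tf.float32),        # fixed deviation vector
    ],
)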
@@ -456,14 +472,18 @@ def _predict_with_deviation_inner(features, deviation):
             batch_data = other_class_data[start_idx:end_idx]

             # Convert to tensor
-            batch_data_tf = self._tf_runtime.convert_to_tensor(batch_data, dtype=self._tf_runtime.float32)
+            batch_data_tf = self._tf_runtime.convert_to_tensor(
+                batch_data, dtype=self._tf_runtime.float32
+            )

             # Extract features
-            features = self._calculate_features(self.feature_representation_model, batch_data_tf)
+            features = self._calculate_features(
+                self.feature_representation_model, batch_data_tf
+            )
             all_features.append(features)

             # Get predictions with deviation
-            predictions = _predict_with_deviation_inner(features, deviation_tf)
+            predictions = predict_with_deviation(features, deviation_tf)

             # Convert predictions to class indices
             pred_classes = self._tf_runtime.argmax(predictions, axis=1).numpy()
@@ -538,7 +558,9 @@ def detect_poison(self, **kwargs) -> tuple[dict, list[int]]:
         benign_class_indices = np.intersect1d(
             self.benign_indices, np.where(self.y_train == class_label)[0]
         )
-        benign_centroids[class_label] = self._calculate_centroid(benign_class_indices, self.features)
+        benign_centroids[class_label] = self._calculate_centroid(
+            benign_class_indices, self.features
+        )

         logging.info("Calculating misclassification rates...")
         misclassification_rates = {}