@@ -109,11 +109,21 @@ def __init__(
109
109
:param benign_indices: array data points' indices known to be benign
110
110
:param final_feature_layer_name: the name of the final layer that builds feature representation. Must
111
111
be a ReLu layer
112
+ :param misclassification_threshold: maximum misclassification threshold to consider a cluster as clean
113
+ :param reducer: dimensionality reducer used to reduce feature space. UMAP is used as default
114
+ :param clusterer: clustering algorithm used to cluster the features. DBSCAN is used as default
115
+ :param feature_extraction_batch_size: batch size for feature extraction.
116
+ Use lower values in case of low GPU memory availability
117
+ :param misclassification_batch_size: batch size for misclassification.
118
+ Use lower values in case of low GPU memory
112
119
"""
120
+ # Tensorflow is imported dynamically to keep ART framework-agnostic
113
121
try :
114
122
import tensorflow as tf_runtime
115
123
from tensorflow .keras import Model , Sequential
116
124
125
+ # Keep a runtime and function attributes to prevent garbage collection on these during the
126
+ # algorithm lifespan
117
127
self ._tf_runtime = tf_runtime
118
128
self ._KerasModel = Model
119
129
self ._KerasSequential = Sequential
@@ -126,6 +136,7 @@ def __init__(
126
136
from sklearn .base import ClusterMixin
127
137
from sklearn .cluster import DBSCAN
128
138
139
+ # Default clusterer, recommended by original authors
129
140
self .clusterer = DBSCAN (eps = 0.8 , min_samples = 20 )
130
141
except ImportError as e :
131
142
raise ImportError (
@@ -136,6 +147,7 @@ def __init__(
136
147
try :
137
148
from umap import UMAP
138
149
150
+ # Default reducer, recommended by original authors
139
151
self .reducer = UMAP (n_neighbors = 5 , min_dist = 0 )
140
152
except ImportError as e :
141
153
raise ImportError (
@@ -158,53 +170,77 @@ def __init__(
158
170
self .feature_extraction_batch_size = feature_extraction_batch_size
159
171
self .misclassification_batch_size = misclassification_batch_size
160
172
173
+ # tf.data.Datasets are used as these are more efficient for TF-specific operations
161
174
# Create x_train_dataset for feature extraction
162
175
# Ensure data type is float32 for TensorFlow
163
- self .x_train_dataset = self ._tf_runtime .data .Dataset .from_tensor_slices (
164
- self .x_train .astype (self ._tf_runtime .float32 .as_numpy_dtype )
165
- ).batch (self .feature_extraction_batch_size ).prefetch (self ._tf_runtime .data .AUTOTUNE )
176
+ self .x_train_dataset = (
177
+ self ._tf_runtime .data .Dataset .from_tensor_slices (
178
+ self .x_train .astype (self ._tf_runtime .float32 .as_numpy_dtype )
179
+ )
180
+ .batch (self .feature_extraction_batch_size )
181
+ .prefetch (self ._tf_runtime .data .AUTOTUNE )
182
+ )
166
183
167
184
# Create x_benign_dataset for misclassification rate calculation
168
185
# Ensure data type is float32 for TensorFlow
169
- self .x_benign_dataset = self ._tf_runtime .data .Dataset .from_tensor_slices (
170
- self .x_benign_np .astype (self ._tf_runtime .float32 .as_numpy_dtype )
171
- ).batch (self .misclassification_batch_size ).prefetch (self ._tf_runtime .data .AUTOTUNE )
172
- # ---------------------------------------------
186
+ self .x_benign_dataset = (
187
+ self ._tf_runtime .data .Dataset .from_tensor_slices (
188
+ self .x_benign_np .astype (self ._tf_runtime .float32 .as_numpy_dtype )
189
+ )
190
+ .batch (self .misclassification_batch_size )
191
+ .prefetch (self ._tf_runtime .data .AUTOTUNE )
192
+ )
173
193
174
194
self .feature_representation_model , self .classifying_submodel = self ._extract_submodels (
175
195
final_feature_layer_name
176
196
)
177
197
178
198
self .misclassification_threshold = np .float64 (misclassification_threshold )
179
199
180
- # --- Get feature_shape and deviation_shape BEFORE wrapping functions ---
181
- if hasattr (self .feature_representation_model , 'output_shape' ) and self .feature_representation_model .output_shape [1 :] is not None :
200
+ # Get feature_shape and deviation_shape before wrapping functions
201
+ # It's fundamental to wrap functions with input signatures
202
+ if (
203
+ hasattr (self .feature_representation_model , "output_shape" )
204
+ and self .feature_representation_model .output_shape [1 :] is not None
205
+ ):
182
206
# Assumes output_shape is something like (None, 5) or (None, 28, 28, 3)
183
207
self .feature_output_shape = self .feature_representation_model .output_shape [1 :]
184
208
# Handle case where output_shape[1:] might be an empty tuple if model outputs a scalar.
185
209
# Unlikely for feature extractor, but good to be robust.
186
210
if not self .feature_output_shape :
187
- logging .warning ("feature_representation_model.output_shape[1:] is empty. Attempting dummy inference." )
211
+ logging .warning (
212
+ "feature_representation_model.output_shape[1:] is empty. Attempting dummy inference."
213
+ )
188
214
# Fallback to dummy inference if output_shape is unexpectedly empty
189
- dummy_input_shape = (1 , * self .x_train .shape [1 :]) # Use actual classifier input shape
190
- dummy_input = self ._tf_runtime .random .uniform (shape = dummy_input_shape , dtype = self ._tf_runtime .float32 )
215
+ dummy_input_shape = (
216
+ 1 ,
217
+ * self .x_train .shape [1 :],
218
+ ) # Use actual classifier input shape
219
+ dummy_input = self ._tf_runtime .random .uniform (
220
+ shape = dummy_input_shape , dtype = self ._tf_runtime .float32
221
+ )
191
222
dummy_features = self .feature_representation_model (dummy_input )
192
223
self .feature_output_shape = dummy_features .shape [1 :]
193
224
else :
194
225
# Fallback if output_shape attribute is not present or usable
195
- logging .warning ("feature_representation_model.output_shape not directly available. Performing dummy inference for shape." )
196
- dummy_input_shape = (1 , * self .x_train .shape [1 :]) # Use actual classifier input shape
197
- dummy_input = self ._tf_runtime .random .uniform (shape = dummy_input_shape , dtype = self ._tf_runtime .float32 )
226
+ logging .warning (
227
+ "feature_representation_model.output_shape not directly available. "
228
+ "Performing dummy inference for shape."
229
+ )
230
+ dummy_input_shape = (1 , * self .x_train .shape [1 :]) # Use actual classifier input shape
231
+ dummy_input = self ._tf_runtime .random .uniform (
232
+ shape = dummy_input_shape , dtype = self ._tf_runtime .float32
233
+ )
198
234
dummy_features = self .feature_representation_model (dummy_input )
199
235
self .feature_output_shape = dummy_features .shape [1 :]
200
236
201
237
# If after all attempts, feature_output_shape is still empty/invalid, raise an error
202
238
if not self .feature_output_shape or any (d is None for d in self .feature_output_shape ):
203
239
raise RuntimeError (
204
- f"Could not determine a valid feature_output_shape ({ self .feature_output_shape } ) for tf.function input_signature. "
240
+ f"Could not determine a valid feature_output_shape ({ self .feature_output_shape } ) for "
241
+ f"tf.function input_signature. "
205
242
"Ensure feature_representation_model is built and outputs a known shape."
206
243
)
207
- # --- END SHAPE DETERMINATION ---
208
244
209
245
# Dynamic @tf.function wrapping
210
246
self ._calculate_centroid_tf = self ._tf_runtime .function (
@@ -215,8 +251,12 @@ def __init__(
215
251
self ._predict_with_deviation = self ._tf_runtime .function (
216
252
self ._predict_with_deviation_unwrapped ,
217
253
input_signature = [
218
- self ._tf_runtime .TensorSpec (shape = [None , * self .feature_output_shape ], dtype = self ._tf_runtime .float32 ),
219
- self ._tf_runtime .TensorSpec (shape = self .feature_output_shape , dtype = self ._tf_runtime .float32 ),
254
+ self ._tf_runtime .TensorSpec (
255
+ shape = [None , * self .feature_output_shape ], dtype = self ._tf_runtime .float32
256
+ ),
257
+ self ._tf_runtime .TensorSpec (
258
+ shape = self .feature_output_shape , dtype = self ._tf_runtime .float32
259
+ ),
220
260
],
221
261
reduce_retracing = True ,
222
262
)
@@ -288,7 +328,7 @@ def _feature_extraction(
288
328
# Iterate directly over the provided input_dataset
289
329
for batch_data_tf in x_train :
290
330
batch_features = self ._calculate_features (feature_representation_model , batch_data_tf )
291
- features_np_list .append (batch_features .numpy ()) # Convert to NumPy immediately
331
+ features_np_list .append (batch_features .numpy ()) # Convert to NumPy immediately
292
332
293
333
# Concatenate all batches of numpy arrays on CPU (system RAM)
294
334
final_features_np : np .ndarray = np .concatenate (features_np_list , axis = 0 )
@@ -441,10 +481,10 @@ def evaluate_defence(self, is_clean: np.ndarray, **kwargs) -> str:
441
481
442
482
return confusion_matrix_json
443
483
444
- def _predict_with_deviation_unwrapped (self , features : tf_types .Tensor , deviation : tf_types .Tensor ) -> tf_types .Tensor :
445
- # Add deviation to features and pass through ReLu to keep in latent space
484
+ def _predict_with_deviation_unwrapped (
485
+ self , features : tf_types .Tensor , deviation : tf_types .Tensor
486
+ ) -> tf_types .Tensor :
446
487
deviated_features = self ._tf_runtime .nn .relu (features + deviation )
447
- # Get predictions from classifying submodel
448
488
return self .classifying_submodel (deviated_features , training = False )
449
489
450
490
def _calculate_misclassification_rate (
@@ -462,7 +502,6 @@ def _calculate_misclassification_rate(
462
502
463
503
total_elements = 0
464
504
misclassified_elements = 0
465
- all_features_np : list [np .ndarray ] = []
466
505
467
506
# Get all classes except the current one
468
507
other_classes = self .unique_classes - {class_label }
@@ -476,27 +515,27 @@ def _calculate_misclassification_rate(
476
515
if len (other_class_data ) == 0 :
477
516
continue
478
517
479
- class_x_dataset = self ._tf_runtime .data .Dataset .from_tensor_slices (
480
- other_class_data .astype (self ._tf_runtime .float32 .as_numpy_dtype )
481
- ).batch (self .misclassification_batch_size ).prefetch (self ._tf_runtime .data .AUTOTUNE )
518
+ class_x_dataset = (
519
+ self ._tf_runtime .data .Dataset .from_tensor_slices (
520
+ other_class_data .astype (self ._tf_runtime .float32 .as_numpy_dtype )
521
+ )
522
+ .batch (self .misclassification_batch_size )
523
+ .prefetch (self ._tf_runtime .data .AUTOTUNE )
524
+ )
482
525
483
526
total_elements += len (other_class_data )
484
527
class_misclassified = 0
485
528
529
+ # Batches of the class are processed with deviation to determine misclassification
486
530
for batch_data_tf in class_x_dataset :
487
- # Extract features
488
531
features_tf = self ._calculate_features (
489
532
self .feature_representation_model , batch_data_tf
490
533
)
491
- all_features_np .append (features_tf .numpy ())
492
534
493
- # Get predictions with deviation
494
535
predictions = self ._predict_with_deviation (features_tf , deviation_tf )
495
536
496
- # Convert predictions to class indices
497
537
pred_classes_np = self ._tf_runtime .argmax (predictions , axis = 1 ).numpy ()
498
538
499
- # Count misclassifications (predicted as class_label)
500
539
batch_misclassified = np .sum (pred_classes_np == class_label )
501
540
class_misclassified += batch_misclassified
502
541
@@ -506,11 +545,9 @@ def _calculate_misclassification_rate(
506
545
if total_elements == 0 :
507
546
return np .float64 (0.0 )
508
547
509
- all_f_vectors_np = np .concatenate (all_features_np , axis = 0 )
510
548
logging .debug (
511
- "MR --> %s , |f| = %s : %s / %s = %s" ,
549
+ "MR --> %s : %s / %s = %s" ,
512
550
class_label ,
513
- np .linalg .norm (np .mean (all_f_vectors_np , axis = 0 )),
514
551
misclassified_elements ,
515
552
total_elements ,
516
553
np .float64 (misclassified_elements ) / np .float64 (total_elements ),
@@ -519,14 +556,16 @@ def _calculate_misclassification_rate(
519
556
return np .float64 (misclassified_elements ) / np .float64 (total_elements )
520
557
521
558
def detect_poison (self , ** kwargs ) -> tuple [dict , list [int ]]:
522
- # saves important information about the algorithm execution for further analysis
559
+ # Saves important information about the algorithm execution for further analysis
523
560
report : dict [str , Any ] = {}
524
561
525
562
self .is_clean_np = np .ones (len (self .y_train ))
526
563
527
- self .features = self ._feature_extraction (self .x_train_dataset , self .feature_representation_model )
564
+ self .features = self ._feature_extraction (
565
+ self .x_train_dataset , self .feature_representation_model
566
+ )
528
567
529
- # FIXME: temporal fix to test other layers
568
+ # Small fix to add flexibility to use CCAUD in non-flattened scenarios. Not recommended, but available
530
569
if len (self .features .shape ) > 2 :
531
570
num_samples = self .features .shape [0 ]
532
571
self .features = self .features .reshape (num_samples , - 1 ) # Flattening
0 commit comments