25
25
26
26
import logging
27
27
import warnings
28
- from typing import TYPE_CHECKING , Any , Callable
28
+ from typing import TYPE_CHECKING , Any
29
29
30
30
import numpy as np
31
31
@@ -177,12 +177,50 @@ def __init__(
177
177
178
178
self .misclassification_threshold = np .float64 (misclassification_threshold )
179
179
180
+ # --- Determine feature_output_shape BEFORE wrapping functions (it is also used as the deviation shape in the input_signature below) ---
181
+ if hasattr (self .feature_representation_model , 'output_shape' ) and self .feature_representation_model .output_shape [1 :] is not None :
182
+ # Assumes output_shape is something like (None, 5) or (None, 28, 28, 3)
183
+ self .feature_output_shape = self .feature_representation_model .output_shape [1 :]
184
+ # Handle case where output_shape[1:] might be an empty tuple if model outputs a scalar.
185
+ # Unlikely for feature extractor, but good to be robust.
186
+ if not self .feature_output_shape :
187
+ logging .warning ("feature_representation_model.output_shape[1:] is empty. Attempting dummy inference." )
188
+ # Fallback to dummy inference if output_shape is unexpectedly empty
189
+ dummy_input_shape = (1 , * self .x_train .shape [1 :]) # Use actual classifier input shape
190
+ dummy_input = self ._tf_runtime .random .uniform (shape = dummy_input_shape , dtype = self ._tf_runtime .float32 )
191
+ dummy_features = self .feature_representation_model (dummy_input )
192
+ self .feature_output_shape = dummy_features .shape [1 :]
193
+ else :
194
+ # Fallback if output_shape attribute is not present or usable
195
+ logging .warning ("feature_representation_model.output_shape not directly available. Performing dummy inference for shape." )
196
+ dummy_input_shape = (1 , * self .x_train .shape [1 :]) # Use actual classifier input shape
197
+ dummy_input = self ._tf_runtime .random .uniform (shape = dummy_input_shape , dtype = self ._tf_runtime .float32 )
198
+ dummy_features = self .feature_representation_model (dummy_input )
199
+ self .feature_output_shape = dummy_features .shape [1 :]
200
+
201
+ # If after all attempts, feature_output_shape is still empty/invalid, raise an error
202
+ if not self .feature_output_shape or any (d is None for d in self .feature_output_shape ):
203
+ raise RuntimeError (
204
+ f"Could not determine a valid feature_output_shape ({ self .feature_output_shape } ) for tf.function input_signature. "
205
+ "Ensure feature_representation_model is built and outputs a known shape."
206
+ )
207
+ # --- END SHAPE DETERMINATION ---
208
+
180
209
# Dynamic @tf.function wrapping
181
210
self ._calculate_centroid_tf = self ._tf_runtime .function (
182
211
self ._calculate_centroid_tf_unwrapped , reduce_retracing = True
183
212
)
184
213
self ._calculate_features = self ._tf_runtime .function (self ._calculate_features_unwrapped )
185
214
215
+ self ._predict_with_deviation = self ._tf_runtime .function (
216
+ self ._predict_with_deviation_unwrapped ,
217
+ input_signature = [
218
+ self ._tf_runtime .TensorSpec (shape = [None , * self .feature_output_shape ], dtype = self ._tf_runtime .float32 ),
219
+ self ._tf_runtime .TensorSpec (shape = self .feature_output_shape , dtype = self ._tf_runtime .float32 ),
220
+ ],
221
+ reduce_retracing = True ,
222
+ )
223
+
186
224
logging .info ("CCA object created successfully." )
187
225
188
226
def _calculate_centroid_tf_unwrapped (self , features ):
@@ -403,6 +441,12 @@ def evaluate_defence(self, is_clean: np.ndarray, **kwargs) -> str:
403
441
404
442
return confusion_matrix_json
405
443
444
def _predict_with_deviation_unwrapped(self, features: tf_types.Tensor, deviation: tf_types.Tensor) -> tf_types.Tensor:
    """
    Run the classifying submodel on features shifted by a deviation.

    The deviation is added to the features, the sum is passed through ReLU,
    and the rectified result is handed to ``self.classifying_submodel`` in
    inference mode (``training=False``).

    :param features: Feature tensor to perturb
    :param deviation: Deviation tensor added to ``features``
    :return: The output of ``self.classifying_submodel`` for the deviated features
    """
    shifted = features + deviation
    # ReLU zeroes out negative entries; the intent is to keep the shifted
    # features inside the latent space the classifying submodel expects.
    rectified = self._tf_runtime.nn.relu(shifted)
    return self.classifying_submodel(rectified, training=False)
449
+
406
450
def _calculate_misclassification_rate (
407
451
self , class_label : int , deviation : np .ndarray
408
452
) -> np .float64 :
@@ -413,83 +457,47 @@ def _calculate_misclassification_rate(
413
457
:param deviation: The deviation vector to apply
414
458
:return: The misclassification rate (0.0 to 1.0)
415
459
"""
416
-
417
- def _predict_with_deviation_inner (features , deviation ):
418
- # Add deviation to features and pass through ReLu to keep in latent space
419
- deviated_features = self ._tf_runtime .nn .relu (features + deviation )
420
- # Get predictions from classifying submodel
421
- predictions = self .classifying_submodel (deviated_features , training = False )
422
- return predictions
423
-
424
460
# Convert deviation to a tensor once
425
461
deviation_tf = self ._tf_runtime .convert_to_tensor (deviation , dtype = self ._tf_runtime .float32 )
426
462
427
- # Get a sample to determine the input shape
428
- sample_data = self .x_benign [0 :1 ]
429
-
430
- # The feature shape depends on the feature_representation_model output
431
- # We need to run once to get the output shape
432
- sample_features = self .feature_representation_model .predict (sample_data )
433
- feature_shape = sample_features .shape [1 :]
434
-
435
- predict_with_deviation = self ._tf_runtime .function (
436
- _predict_with_deviation_inner ,
437
- input_signature = [
438
- self ._tf_runtime .TensorSpec (shape = [None , * feature_shape ], dtype = self ._tf_runtime .float32 ),
439
- self ._tf_runtime .TensorSpec (shape = deviation .shape , dtype = self ._tf_runtime .float32 ),
440
- ]
441
- )
442
-
443
463
total_elements = 0
444
464
misclassified_elements = 0
465
+ all_features_np : list [np .ndarray ] = []
445
466
446
467
# Get all classes except the current one
447
468
other_classes = self .unique_classes - {class_label }
448
469
449
- all_features = []
450
-
451
470
# Process each class
452
471
for other_class_label in other_classes :
453
472
# Get data for this class
454
- other_class_mask = self .y_benign == other_class_label
455
- other_class_data = self .x_benign [other_class_mask ]
473
+ other_class_mask = self .y_benign_np == other_class_label
474
+ other_class_data = self .x_benign_np [other_class_mask ]
456
475
457
476
if len (other_class_data ) == 0 :
458
477
continue
459
478
460
- total_elements += len (other_class_data )
461
-
462
- # Process in batches to avoid memory issues
463
- batch_size = 128 # Adjust based on your GPU memory
464
- num_samples = len (other_class_data )
465
- num_batches = int (np .ceil (num_samples / batch_size ))
479
+ class_x_dataset = self ._tf_runtime .data .Dataset .from_tensor_slices (
480
+ other_class_data .astype (self ._tf_runtime .float32 .as_numpy_dtype )
481
+ ).batch (self .misclassification_batch_size ).prefetch (self ._tf_runtime .data .AUTOTUNE )
466
482
483
+ total_elements += len (other_class_data )
467
484
class_misclassified = 0
468
485
469
- for i in range (num_batches ):
470
- start_idx = i * batch_size
471
- end_idx = min ((i + 1 ) * batch_size , num_samples )
472
- batch_data = other_class_data [start_idx :end_idx ]
473
-
474
- # Convert to tensor
475
- batch_data_tf = self ._tf_runtime .convert_to_tensor (
476
- batch_data , dtype = self ._tf_runtime .float32
477
- )
478
-
486
+ for batch_data_tf in class_x_dataset :
479
487
# Extract features
480
- features = self ._calculate_features (
488
+ features_tf = self ._calculate_features (
481
489
self .feature_representation_model , batch_data_tf
482
490
)
483
- all_features .append (features )
491
+ all_features_np .append (features_tf . numpy () )
484
492
485
493
# Get predictions with deviation
486
- predictions = predict_with_deviation ( features , deviation_tf )
494
+ predictions = self . _predict_with_deviation ( features_tf , deviation_tf )
487
495
488
496
# Convert predictions to class indices
489
- pred_classes = self ._tf_runtime .argmax (predictions , axis = 1 ).numpy ()
497
+ pred_classes_np = self ._tf_runtime .argmax (predictions , axis = 1 ).numpy ()
490
498
491
499
# Count misclassifications (predicted as class_label)
492
- batch_misclassified = np .sum (pred_classes == class_label )
500
+ batch_misclassified = np .sum (pred_classes_np == class_label )
493
501
class_misclassified += batch_misclassified
494
502
495
503
misclassified_elements += class_misclassified
@@ -498,7 +506,7 @@ def _predict_with_deviation_inner(features, deviation):
498
506
if total_elements == 0 :
499
507
return np .float64 (0.0 )
500
508
501
- all_f_vectors_np = np .concatenate (all_features , axis = 0 )
509
+ all_f_vectors_np = np .concatenate (all_features_np , axis = 0 )
502
510
logging .debug (
503
511
"MR --> %s , |f| = %s: %s / %s = %s" ,
504
512
class_label ,
@@ -516,7 +524,7 @@ def detect_poison(self, **kwargs) -> tuple[dict, list[int]]:
516
524
517
525
self .is_clean_np = np .ones (len (self .y_train ))
518
526
519
- self .features = self ._feature_extraction (self .x_train , self .feature_representation_model )
527
+ self .features = self ._feature_extraction (self .x_train_dataset , self .feature_representation_model )
520
528
521
529
# FIXME: temporal fix to test other layers
522
530
if len (self .features .shape ) > 2 :
0 commit comments