
Commit e56b47a

Batch processing + notebook corrections
Signed-off-by: Álvaro Bacca Peña <[email protected]>
1 parent c1e3f5d

3 files changed: +2196 −2809 lines

art/defences/detector/poison/clustering_centroid_analysis.py

Lines changed: 60 additions & 52 deletions
@@ -25,7 +25,7 @@
 
 import logging
 import warnings
-from typing import TYPE_CHECKING, Any, Callable
+from typing import TYPE_CHECKING, Any
 
 import numpy as np
 
@@ -177,12 +177,50 @@ def __init__(
 
         self.misclassification_threshold = np.float64(misclassification_threshold)
 
+        # --- Get feature_shape and deviation_shape BEFORE wrapping functions ---
+        if hasattr(self.feature_representation_model, 'output_shape') and self.feature_representation_model.output_shape[1:] is not None:
+            # Assumes output_shape is something like (None, 5) or (None, 28, 28, 3)
+            self.feature_output_shape = self.feature_representation_model.output_shape[1:]
+            # Handle case where output_shape[1:] might be an empty tuple if model outputs a scalar.
+            # Unlikely for feature extractor, but good to be robust.
+            if not self.feature_output_shape:
+                logging.warning("feature_representation_model.output_shape[1:] is empty. Attempting dummy inference.")
+                # Fallback to dummy inference if output_shape is unexpectedly empty
+                dummy_input_shape = (1, *self.x_train.shape[1:])  # Use actual classifier input shape
+                dummy_input = self._tf_runtime.random.uniform(shape=dummy_input_shape, dtype=self._tf_runtime.float32)
+                dummy_features = self.feature_representation_model(dummy_input)
+                self.feature_output_shape = dummy_features.shape[1:]
+        else:
+            # Fallback if output_shape attribute is not present or usable
+            logging.warning("feature_representation_model.output_shape not directly available. Performing dummy inference for shape.")
+            dummy_input_shape = (1, *self.x_train.shape[1:])  # Use actual classifier input shape
+            dummy_input = self._tf_runtime.random.uniform(shape=dummy_input_shape, dtype=self._tf_runtime.float32)
+            dummy_features = self.feature_representation_model(dummy_input)
+            self.feature_output_shape = dummy_features.shape[1:]
+
+        # If after all attempts, feature_output_shape is still empty/invalid, raise an error
+        if not self.feature_output_shape or any(d is None for d in self.feature_output_shape):
+            raise RuntimeError(
+                f"Could not determine a valid feature_output_shape ({self.feature_output_shape}) for tf.function input_signature. "
+                "Ensure feature_representation_model is built and outputs a known shape."
+            )
+        # --- END SHAPE DETERMINATION ---
+
         # Dynamic @tf.function wrapping
         self._calculate_centroid_tf = self._tf_runtime.function(
             self._calculate_centroid_tf_unwrapped, reduce_retracing=True
         )
         self._calculate_features = self._tf_runtime.function(self._calculate_features_unwrapped)
 
+        self._predict_with_deviation = self._tf_runtime.function(
+            self._predict_with_deviation_unwrapped,
+            input_signature=[
+                self._tf_runtime.TensorSpec(shape=[None, *self.feature_output_shape], dtype=self._tf_runtime.float32),
+                self._tf_runtime.TensorSpec(shape=self.feature_output_shape, dtype=self._tf_runtime.float32),
+            ],
+            reduce_retracing=True,
+        )
+
         logging.info("CCA object created successfully.")
 
     def _calculate_centroid_tf_unwrapped(self, features):
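
Note: the hunk above resolves feature_output_shape before any wrapping so that _predict_with_deviation can be traced once against a fixed input_signature with a dynamic batch dimension. A minimal standalone sketch of that idea (illustrative names and a made-up feature dimension; this is not code from the commit):

import tensorflow as tf

feature_dim = 5  # assume the feature extractor is known to output (None, 5)

@tf.function(input_signature=[
    tf.TensorSpec(shape=[None, feature_dim], dtype=tf.float32),  # batch dimension left dynamic
    tf.TensorSpec(shape=[feature_dim], dtype=tf.float32),        # a single deviation vector
])
def predict_with_deviation(features, deviation):
    # ReLU keeps the deviated vectors inside the non-negative latent space
    return tf.nn.relu(features + deviation)

# Batches of different sizes reuse the same traced graph:
print(predict_with_deviation(tf.zeros([3, feature_dim]), tf.ones([feature_dim])).shape)
print(predict_with_deviation(tf.zeros([128, feature_dim]), tf.ones([feature_dim])).shape)

Because the signature pins the feature rank and dtype while leaving the batch dimension as None, batches of any size reuse one traced graph instead of triggering a retrace per shape, which is why the shape must be known before the constructor wraps the function.
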
@@ -403,6 +441,12 @@ def evaluate_defence(self, is_clean: np.ndarray, **kwargs) -> str:
 
         return confusion_matrix_json
 
+    def _predict_with_deviation_unwrapped(self, features: tf_types.Tensor, deviation: tf_types.Tensor) -> tf_types.Tensor:
+        # Add deviation to features and pass through ReLu to keep in latent space
+        deviated_features = self._tf_runtime.nn.relu(features + deviation)
+        # Get predictions from classifying submodel
+        return self.classifying_submodel(deviated_features, training=False)
+
     def _calculate_misclassification_rate(
         self, class_label: int, deviation: np.ndarray
     ) -> np.float64:
@@ -413,83 +457,47 @@ def _calculate_misclassification_rate(
         :param deviation: The deviation vector to apply
         :return: The misclassification rate (0.0 to 1.0)
         """
-
-        def _predict_with_deviation_inner(features, deviation):
-            # Add deviation to features and pass through ReLu to keep in latent space
-            deviated_features = self._tf_runtime.nn.relu(features + deviation)
-            # Get predictions from classifying submodel
-            predictions = self.classifying_submodel(deviated_features, training=False)
-            return predictions
-
         # Convert deviation to a tensor once
         deviation_tf = self._tf_runtime.convert_to_tensor(deviation, dtype=self._tf_runtime.float32)
 
-        # Get a sample to determine the input shape
-        sample_data = self.x_benign[0:1]
-
-        # The feature shape depends on the feature_representation_model output
-        # We need to run once to get the output shape
-        sample_features = self.feature_representation_model.predict(sample_data)
-        feature_shape = sample_features.shape[1:]
-
-        predict_with_deviation = self._tf_runtime.function(
-            _predict_with_deviation_inner,
-            input_signature=[
-                self._tf_runtime.TensorSpec(shape=[None, *feature_shape], dtype=self._tf_runtime.float32),
-                self._tf_runtime.TensorSpec(shape=deviation.shape, dtype=self._tf_runtime.float32),
-            ]
-        )
-
         total_elements = 0
         misclassified_elements = 0
+        all_features_np: list[np.ndarray] = []
 
         # Get all classes except the current one
         other_classes = self.unique_classes - {class_label}
 
-        all_features = []
-
         # Process each class
         for other_class_label in other_classes:
             # Get data for this class
-            other_class_mask = self.y_benign == other_class_label
-            other_class_data = self.x_benign[other_class_mask]
+            other_class_mask = self.y_benign_np == other_class_label
+            other_class_data = self.x_benign_np[other_class_mask]
 
             if len(other_class_data) == 0:
                 continue
 
-            total_elements += len(other_class_data)
-
-            # Process in batches to avoid memory issues
-            batch_size = 128  # Adjust based on your GPU memory
-            num_samples = len(other_class_data)
-            num_batches = int(np.ceil(num_samples / batch_size))
+            class_x_dataset = self._tf_runtime.data.Dataset.from_tensor_slices(
+                other_class_data.astype(self._tf_runtime.float32.as_numpy_dtype)
+            ).batch(self.misclassification_batch_size).prefetch(self._tf_runtime.data.AUTOTUNE)
 
+            total_elements += len(other_class_data)
             class_misclassified = 0
 
-            for i in range(num_batches):
-                start_idx = i * batch_size
-                end_idx = min((i + 1) * batch_size, num_samples)
-                batch_data = other_class_data[start_idx:end_idx]
-
-                # Convert to tensor
-                batch_data_tf = self._tf_runtime.convert_to_tensor(
-                    batch_data, dtype=self._tf_runtime.float32
-                )
-
+            for batch_data_tf in class_x_dataset:
                 # Extract features
-                features = self._calculate_features(
+                features_tf = self._calculate_features(
                     self.feature_representation_model, batch_data_tf
                 )
-                all_features.append(features)
+                all_features_np.append(features_tf.numpy())
 
                 # Get predictions with deviation
-                predictions = predict_with_deviation(features, deviation_tf)
+                predictions = self._predict_with_deviation(features_tf, deviation_tf)
 
                 # Convert predictions to class indices
-                pred_classes = self._tf_runtime.argmax(predictions, axis=1).numpy()
+                pred_classes_np = self._tf_runtime.argmax(predictions, axis=1).numpy()
 
                 # Count misclassifications (predicted as class_label)
-                batch_misclassified = np.sum(pred_classes == class_label)
+                batch_misclassified = np.sum(pred_classes_np == class_label)
                 class_misclassified += batch_misclassified
 
             misclassified_elements += class_misclassified
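
Note: the rewritten loop above swaps manual index slicing for a tf.data pipeline and reuses the tf.function wrapped in the constructor; the rate itself is still misclassified_elements / total_elements over all samples of the other classes. A standalone sketch of the batching pattern (illustrative shapes; batch_size stands in for self.misclassification_batch_size, which is configured outside this hunk):

import numpy as np
import tensorflow as tf

data = np.random.rand(1000, 28, 28, 1).astype(np.float32)
batch_size = 128

dataset = (
    tf.data.Dataset.from_tensor_slices(data)  # one element per sample
    .batch(batch_size)                        # group into fixed-size batches (last may be smaller)
    .prefetch(tf.data.AUTOTUNE)               # prepare the next batch while the current one runs
)

for batch in dataset:
    # each batch is a tf.Tensor of shape (<=128, 28, 28, 1)
    _ = tf.reduce_mean(batch)

The prefetch step is what delivers the "batch processing" benefit named in the commit message: host-side data preparation overlaps with device-side compute instead of alternating with it.
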
@@ -498,7 +506,7 @@ def _predict_with_deviation_inner(features, deviation):
         if total_elements == 0:
             return np.float64(0.0)
 
-        all_f_vectors_np = np.concatenate(all_features, axis=0)
+        all_f_vectors_np = np.concatenate(all_features_np, axis=0)
         logging.debug(
             "MR --> %s , |f| = %s: %s / %s = %s",
             class_label,
@@ -516,7 +524,7 @@ def detect_poison(self, **kwargs) -> tuple[dict, list[int]]:
 
         self.is_clean_np = np.ones(len(self.y_train))
 
-        self.features = self._feature_extraction(self.x_train, self.feature_representation_model)
+        self.features = self._feature_extraction(self.x_train_dataset, self.feature_representation_model)
 
         # FIXME: temporal fix to test other layers
         if len(self.features.shape) > 2:
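
Note: detect_poison now hands self.x_train_dataset (rather than the raw x_train array) to _feature_extraction, suggesting that feature extraction also walks a pre-batched tf.data.Dataset. Neither helper appears in this diff; the sketch below is a hypothetical illustration of that pattern, with assumed names and shapes:

import numpy as np
import tensorflow as tf

def extract_features_batched(dataset: tf.data.Dataset, feature_model: tf.keras.Model) -> np.ndarray:
    # Run the feature extractor batch by batch and stitch the results together,
    # so the full training set never has to fit on the device at once.
    chunks = [feature_model(batch, training=False).numpy() for batch in dataset]
    return np.concatenate(chunks, axis=0)

x_train = np.random.rand(256, 28, 28, 1).astype(np.float32)
x_train_dataset = tf.data.Dataset.from_tensor_slices(x_train).batch(64).prefetch(tf.data.AUTOTUNE)

feature_model = tf.keras.Sequential([
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(5, activation="relu"),
])

features = extract_features_batched(x_train_dataset, feature_model)
print(features.shape)  # (256, 5)

Processing the training set in batches keeps peak memory bounded by the batch size rather than by the dataset size.
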
