Skip to content

Commit fb926b2

Browse files
committed
Black reformatting
Signed-off-by: Álvaro Bacca Peña <[email protected]>
1 parent 0ca4f16 commit fb926b2

File tree

2 files changed

+41
-123
lines changed

2 files changed

+41
-123
lines changed

art/defences/detector/poison/clustering_centroid_analysis.py

Lines changed: 22 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -150,9 +150,7 @@ def __init__(
150150
# Default reducer, recommended by original authors
151151
self.reducer = UMAP(n_neighbors=5, min_dist=0)
152152
except ImportError as e:
153-
raise ImportError(
154-
"UMAP is required for default reducer in ClusteringCentroidAnalysis. "
155-
) from e
153+
raise ImportError("UMAP is required for default reducer in ClusteringCentroidAnalysis. ") from e
156154

157155
logging.info("Loading variables into CCA...")
158156
super().__init__(classifier, x_train, y_train)
@@ -191,9 +189,7 @@ def __init__(
191189
.prefetch(self._tf_runtime.data.AUTOTUNE)
192190
)
193191

194-
self.feature_representation_model, self.classifying_submodel = self._extract_submodels(
195-
final_feature_layer_name
196-
)
192+
self.feature_representation_model, self.classifying_submodel = self._extract_submodels(final_feature_layer_name)
197193

198194
self.misclassification_threshold = np.float64(misclassification_threshold)
199195

@@ -208,17 +204,13 @@ def __init__(
208204
# Handle case where output_shape[1:] might be an empty tuple if model outputs a scalar.
209205
# Unlikely for feature extractor, but good to be robust.
210206
if not self.feature_output_shape:
211-
logging.warning(
212-
"feature_representation_model.output_shape[1:] is empty. Attempting dummy inference."
213-
)
207+
logging.warning("feature_representation_model.output_shape[1:] is empty. Attempting dummy inference.")
214208
# Fallback to dummy inference if output_shape is unexpectedly empty
215209
dummy_input_shape = (
216210
1,
217211
*self.x_train.shape[1:],
218212
) # Use actual classifier input shape
219-
dummy_input = self._tf_runtime.random.uniform(
220-
shape=dummy_input_shape, dtype=self._tf_runtime.float32
221-
)
213+
dummy_input = self._tf_runtime.random.uniform(shape=dummy_input_shape, dtype=self._tf_runtime.float32)
222214
dummy_features = self.feature_representation_model(dummy_input)
223215
self.feature_output_shape = dummy_features.shape[1:]
224216
else:
@@ -228,9 +220,7 @@ def __init__(
228220
"Performing dummy inference for shape."
229221
)
230222
dummy_input_shape = (1, *self.x_train.shape[1:]) # Use actual classifier input shape
231-
dummy_input = self._tf_runtime.random.uniform(
232-
shape=dummy_input_shape, dtype=self._tf_runtime.float32
233-
)
223+
dummy_input = self._tf_runtime.random.uniform(shape=dummy_input_shape, dtype=self._tf_runtime.float32)
234224
dummy_features = self.feature_representation_model(dummy_input)
235225
self.feature_output_shape = dummy_features.shape[1:]
236226

@@ -251,12 +241,8 @@ def __init__(
251241
self._predict_with_deviation = self._tf_runtime.function(
252242
self._predict_with_deviation_unwrapped,
253243
input_signature=[
254-
self._tf_runtime.TensorSpec(
255-
shape=[None, *self.feature_output_shape], dtype=self._tf_runtime.float32
256-
),
257-
self._tf_runtime.TensorSpec(
258-
shape=self.feature_output_shape, dtype=self._tf_runtime.float32
259-
),
244+
self._tf_runtime.TensorSpec(shape=[None, *self.feature_output_shape], dtype=self._tf_runtime.float32),
245+
self._tf_runtime.TensorSpec(shape=self.feature_output_shape, dtype=self._tf_runtime.float32),
260246
],
261247
reduce_retracing=True,
262248
)
@@ -275,9 +261,7 @@ def _calculate_centroid(self, selected_indices: np.ndarray, features: np.ndarray
275261
:return: d-dimensional numpy array
276262
"""
277263
selected_features = features[selected_indices]
278-
features_tf = self._tf_runtime.convert_to_tensor(
279-
selected_features, dtype=self._tf_runtime.float32
280-
)
264+
features_tf = self._tf_runtime.convert_to_tensor(selected_features, dtype=self._tf_runtime.float32)
281265
centroid = self._calculate_centroid_tf(features_tf)
282266
return centroid.numpy()
283267

@@ -300,9 +284,7 @@ def _class_clustering(
300284
cluster_labels = clusterer.fit_predict(selected_features)
301285
return cluster_labels, selected_indices
302286

303-
def _calculate_features_unwrapped(
304-
self, feature_representation_model: Model, x: np.ndarray
305-
) -> np.ndarray:
287+
def _calculate_features_unwrapped(self, feature_representation_model: Model, x: np.ndarray) -> np.ndarray:
306288
"""
307289
Calculates the features using the first DNN slice
308290
@@ -312,9 +294,7 @@ def _calculate_features_unwrapped(
312294
"""
313295
return feature_representation_model(x, training=False)
314296

315-
def _feature_extraction(
316-
self, x_train: tf_types.data.Dataset, feature_representation_model: Model
317-
) -> np.ndarray:
297+
def _feature_extraction(self, x_train: tf_types.data.Dataset, feature_representation_model: Model) -> np.ndarray:
318298
"""
319299
Extract features from the model using the feature representation sub model.
320300
@@ -361,9 +341,7 @@ def _cluster_classes(
361341
logging.debug("Unique classes are: %s", unique_classes)
362342

363343
for class_label in unique_classes:
364-
cluster_labels, selected_indices = self._class_clustering(
365-
y_train, features, class_label, clusterer
366-
)
344+
cluster_labels, selected_indices = self._class_clustering(y_train, features, class_label, clusterer)
367345
# label values are adjusted to account for labels of previous clustering tasks
368346
cluster_labels[cluster_labels != -1] += used_cluster_labels
369347
used_cluster_labels += len(np.unique(cluster_labels[cluster_labels != -1]))
@@ -383,9 +361,7 @@ def _get_benign_data(self) -> tuple[np.ndarray, np.ndarray]:
383361
:return: (x_benign, y_benign) ndarrays with the benign data.
384362
"""
385363
if len(self.benign_indices) == 0:
386-
raise ValueError(
387-
f"Benign indices passed ({len(self.benign_indices)}) are not enough to run the algorithm"
388-
)
364+
raise ValueError(f"Benign indices passed ({len(self.benign_indices)}) are not enough to run the algorithm")
389365

390366
return self.x_train[self.benign_indices], self.y_train[self.benign_indices]
391367

@@ -404,9 +380,7 @@ def _extract_submodels(self, final_feature_layer_name: str) -> tuple[Model, Mode
404380
try:
405381
final_feature_layer = keras_model.get_layer(name=final_feature_layer_name)
406382
except ValueError as exc:
407-
raise ValueError(
408-
f"Layer with name '{final_feature_layer_name}' not found in the model."
409-
) from exc
383+
raise ValueError(f"Layer with name '{final_feature_layer_name}' not found in the model.") from exc
410384

411385
if (
412386
not hasattr(final_feature_layer, "activation")
@@ -428,9 +402,7 @@ def _extract_submodels(self, final_feature_layer_name: str) -> tuple[Model, Mode
428402
classifier_submodel_layers = keras_model.layers[final_feature_layer_index + 1 :]
429403

430404
# Create the classifier submodel
431-
classifying_submodel = self._KerasSequential(
432-
classifier_submodel_layers, name="classifying_submodel"
433-
)
405+
classifying_submodel = self._KerasSequential(classifier_submodel_layers, name="classifying_submodel")
434406

435407
intermediate_shape = feature_representation_model.output_shape[1:]
436408
dummy_input = self._tf_runtime.zeros((1,) + intermediate_shape)
@@ -487,9 +459,7 @@ def _predict_with_deviation_unwrapped(
487459
deviated_features = self._tf_runtime.nn.relu(features + deviation)
488460
return self.classifying_submodel(deviated_features, training=False)
489461

490-
def _calculate_misclassification_rate(
491-
self, class_label: int, deviation: np.ndarray
492-
) -> np.float64:
462+
def _calculate_misclassification_rate(self, class_label: int, deviation: np.ndarray) -> np.float64:
493463
"""
494464
Calculate the misclassification rate when applying a deviation to other classes.
495465
@@ -528,9 +498,7 @@ def _calculate_misclassification_rate(
528498

529499
# Batches of the class are processed with deviation to determine misclassification
530500
for batch_data_tf in class_x_dataset:
531-
features_tf = self._calculate_features(
532-
self.feature_representation_model, batch_data_tf
533-
)
501+
features_tf = self._calculate_features(self.feature_representation_model, batch_data_tf)
534502

535503
predictions = self._predict_with_deviation(features_tf, deviation_tf)
536504

@@ -561,9 +529,7 @@ def detect_poison(self, **kwargs) -> tuple[dict, list[int]]:
561529

562530
self.is_clean_np = np.ones(len(self.y_train))
563531

564-
self.features = self._feature_extraction(
565-
self.x_train_dataset, self.feature_representation_model
566-
)
532+
self.features = self._feature_extraction(self.x_train_dataset, self.feature_representation_model)
567533

568534
# Small fix to add flexibility to use CCAUD in non-flattened scenarios. Not recommended, but available
569535
if len(self.features.shape) > 2:
@@ -602,12 +568,8 @@ def detect_poison(self, **kwargs) -> tuple[dict, list[int]]:
602568

603569
# for each target class
604570
for class_label in self.unique_classes:
605-
benign_class_indices = np.intersect1d(
606-
self.benign_indices, np.where(self.y_train == class_label)[0]
607-
)
608-
benign_centroids[class_label] = self._calculate_centroid(
609-
benign_class_indices, self.features
610-
)
571+
benign_class_indices = np.intersect1d(self.benign_indices, np.where(self.y_train == class_label)[0])
572+
benign_centroids[class_label] = self._calculate_centroid(benign_class_indices, self.features)
611573

612574
logging.info("Calculating misclassification rates...")
613575
misclassification_rates = {}
@@ -620,9 +582,7 @@ def detect_poison(self, **kwargs) -> tuple[dict, list[int]]:
620582
# MR^k_i
621583
# with unique cluster labels for each cluster in each clustering run, the label
622584
# already maps to a target class
623-
misclassification_rates[cluster_label] = self._calculate_misclassification_rate(
624-
class_label, deviation
625-
)
585+
misclassification_rates[cluster_label] = self._calculate_misclassification_rate(class_label, deviation)
626586
logging.info(
627587
"MR (k=%s, i=%s, |d|=%s) = %s",
628588
cluster_label,
@@ -631,14 +591,10 @@ def detect_poison(self, **kwargs) -> tuple[dict, list[int]]:
631591
misclassification_rates[cluster_label],
632592
)
633593

634-
report["cluster_data"][cluster_label]["centroid_l2"] = np.linalg.norm(
635-
real_centroids[cluster_label]
636-
)
594+
report["cluster_data"][cluster_label]["centroid_l2"] = np.linalg.norm(real_centroids[cluster_label])
637595
report["cluster_data"][cluster_label]["deviation_l2"] = np.linalg.norm(deviation)
638596
report["cluster_data"][cluster_label]["class"] = class_label
639-
report["cluster_data"][cluster_label]["misclassification_rate"] = (
640-
misclassification_rates[cluster_label]
641-
)
597+
report["cluster_data"][cluster_label]["misclassification_rate"] = misclassification_rates[cluster_label]
642598

643599
logging.info("Evaluating cluster misclassification...")
644600
for cluster_label, mr in misclassification_rates.items():

0 commit comments

Comments
 (0)