|
36 | 36 | from sklearn.cluster import KMeans, MiniBatchKMeans
|
37 | 37 |
|
38 | 38 | from art.data_generators import DataGenerator
|
39 |
| -from art.defences.detector.poison.clustering_analyzer import ClusteringAnalyzer |
40 | 39 | from art.defences.detector.poison.ground_truth_evaluator import GroundTruthEvaluator
|
41 | 40 | from art.defences.detector.poison.poison_filtering_defence import PoisonFilteringDefence
|
42 | 41 | from art.utils import segment_by_class
|
43 | 42 | from art.visualization import create_sprite, save_image, plot_3d
|
44 | 43 |
|
| 44 | +from art.defences.detector.poison.clustering_analyzer import ClusterAnalysisType, get_cluster_analyzer |
| 45 | + |
45 | 46 | if TYPE_CHECKING:
|
46 | 47 | from art.utils import CLASSIFIER_NEURALNETWORK_TYPE
|
47 | 48 |
|
@@ -311,44 +312,27 @@ def analyze_clusters(self, **kwargs) -> tuple[dict[str, Any], np.ndarray]:
|
311 | 312 | :return: (report, assigned_clean_by_class), where the report is a dict object and assigned_clean_by_class
|
312 | 313 | is a list of arrays that contains what data points where classified as clean.
|
313 | 314 | """
|
| 315 | + # default argument setting |
314 | 316 | self.set_params(**kwargs)
|
315 | 317 |
|
316 | 318 | if not self.clusters_by_class:
|
317 | 319 | self.cluster_activations()
|
318 | 320 |
|
319 |
| - analyzer = ClusteringAnalyzer() |
320 |
| - if self.cluster_analysis == "smaller": |
321 |
| - ( |
322 |
| - self.assigned_clean_by_class, |
323 |
| - self.poisonous_clusters, |
324 |
| - report, |
325 |
| - ) = analyzer.analyze_by_size(self.clusters_by_class) |
326 |
| - elif self.cluster_analysis == "relative-size": |
327 |
| - ( |
328 |
| - self.assigned_clean_by_class, |
329 |
| - self.poisonous_clusters, |
330 |
| - report, |
331 |
| - ) = analyzer.analyze_by_relative_size(self.clusters_by_class) |
332 |
| - elif self.cluster_analysis == "distance": |
333 |
| - ( |
334 |
| - self.assigned_clean_by_class, |
335 |
| - self.poisonous_clusters, |
336 |
| - report, |
337 |
| - ) = analyzer.analyze_by_distance( |
338 |
| - self.clusters_by_class, |
339 |
| - separated_activations=self.red_activations_by_class, |
| 321 | + analysis_type = ClusterAnalysisType(self.cluster_analysis) |
| 322 | + analyzer = get_cluster_analyzer(analysis_type) |
| 323 | + |
| 324 | + if analysis_type in [ClusterAnalysisType.SMALLER, ClusterAnalysisType.RELATIVE_SIZE]: |
| 325 | + self.assigned_clean_by_class, self.poisonous_clusters, report = analyzer(self.clusters_by_class) |
| 326 | + elif analysis_type == ClusterAnalysisType.DISTANCE: |
| 327 | + self.assigned_clean_by_class, self.poisonous_clusters, report = analyzer( |
| 328 | + self.clusters_by_class, separated_activations=self.red_activations_by_class |
340 | 329 | )
|
341 |
| - elif self.cluster_analysis == "silhouette-scores": |
342 |
| - ( |
343 |
| - self.assigned_clean_by_class, |
344 |
| - self.poisonous_clusters, |
345 |
| - report, |
346 |
| - ) = analyzer.analyze_by_silhouette_score( |
347 |
| - self.clusters_by_class, |
348 |
| - reduced_activations_by_class=self.red_activations_by_class, |
| 330 | + elif analysis_type == ClusterAnalysisType.SILHOUETTE_SCORES: |
| 331 | + self.assigned_clean_by_class, self.poisonous_clusters, report = analyzer( |
| 332 | + self.clusters_by_class, reduced_activations_by_class=self.red_activations_by_class |
349 | 333 | )
|
350 | 334 | else:
|
351 |
| - raise ValueError("Unsupported cluster analysis technique " + self.cluster_analysis) |
| 335 | + raise ValueError("Unsupported cluster analysis technique " + analysis_type.value) |
352 | 336 |
|
353 | 337 | # Add to the report current parameters used to run the defence and the analysis summary
|
354 | 338 | report = dict(list(report.items()) + list(self.get_params().items()))
|
|
0 commit comments