@@ -402,6 +402,9 @@ def predict_with_deviation(features, deviation):
402
402
403
403
def detect_poison (self , ** kwargs ) -> (dict , list [int ]):
404
404
405
+ # saves important information about the algorithm execution for further analysis
406
+ report = dict ()
407
+
405
408
self .is_clean = np .ones (len (self .y_train ))
406
409
407
410
self .features = _feature_extraction (self .x_train , self .feature_representation_model )
@@ -422,13 +425,20 @@ def detect_poison(self, **kwargs) -> (dict, list[int]):
422
425
outlier_indices = np .where (self .class_cluster_labels == - 1 )[0 ]
423
426
self .is_clean [outlier_indices ] = 0
424
427
428
+ # cluster labels are saved in the report
429
+ report ["cluster_labels" ] = self .get_clusters ()
430
+ report ["cluster_data" ] = dict ()
431
+
425
432
logging .info ("Calculating real centroids..." )
426
433
real_centroids = dict ()
427
434
428
435
# for each cluster found for each target class
429
436
for label in np .unique (self .class_cluster_labels [self .class_cluster_labels != - 1 ]):
430
- real_centroids [label ] = _calculate_centroid (np .where (self .class_cluster_labels == label )[0 ],
431
- self .features )
437
+ selected_elements = np .where (self .class_cluster_labels == label )[0 ]
438
+ real_centroids [label ] = _calculate_centroid (selected_elements , self .features )
439
+
440
+ report ["cluster_data" ][label ] = dict ()
441
+ report ["cluster_data" ][label ]["size" ] = len (selected_elements )
432
442
433
443
logging .info ("Calculating benign centroids..." )
434
444
benign_centroids = dict ()
@@ -451,10 +461,14 @@ def detect_poison(self, **kwargs) -> (dict, list[int]):
451
461
452
462
# MR^k_i
453
463
# with unique cluster labels for each cluster in each clustering run, the label already maps to a target class
454
- logging .info (f"MR (k={ cluster_label } , i={ class_label } , |d|={ np .linalg .norm (deviation )} )..." )
455
464
misclassification_rates [cluster_label ] = self ._calculate_misclassification_rate (class_label , deviation )
456
465
logging .info (f"MR (k={ cluster_label } , i={ class_label } , |d|={ np .linalg .norm (deviation )} ) = { misclassification_rates [cluster_label ]} " )
457
466
467
+ report ["cluster_data" ][cluster_label ]["centroid_l2" ] = np .linalg .norm (real_centroids [cluster_label ])
468
+ report ["cluster_data" ][cluster_label ]["deviation_l2" ] = np .linalg .norm (deviation )
469
+ report ["cluster_data" ][cluster_label ]["class" ] = class_label
470
+ report ["cluster_data" ][cluster_label ]["misclassification_rate" ] = misclassification_rates [cluster_label ]
471
+
458
472
459
473
logging .info ("Evaluating cluster misclassification..." )
460
474
for cluster_label , mr in misclassification_rates .items ():
@@ -464,7 +478,7 @@ def detect_poison(self, **kwargs) -> (dict, list[int]):
464
478
self .is_clean [cluster_indices ] = 0
465
479
logging .info (f"Cluster k={ cluster_label } i={ self .cluster_class_mapping [cluster_label ]} considered poison ({ misclassification_rates [cluster_label ]} >= { 1 - self .misclassification_threshold } )" )
466
480
467
- return dict () , self .is_clean .copy ()
481
+ return report , self .is_clean .copy ()
468
482
469
483
470
484
def get_reducer (reduce : ReducerType , nb_dims : int ):
0 commit comments