Skip to content

Commit fbd0c48

Browse files
committed
[WIP] evaluate_defence unit tests
Signed-off-by: alvaro <[email protected]>
1 parent 8508a83 commit fbd0c48

File tree

2 files changed

+202
-8
lines changed

2 files changed

+202
-8
lines changed

art/defences/detector/poison/clustering_centroid_analysis.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -21,22 +21,18 @@
2121
import warnings
2222
from typing import TYPE_CHECKING
2323

24-
import tensorflow as tf
2524
import numpy as np
26-
from tensorflow.keras import Model, Sequential
27-
28-
from art.defences.detector.poison.ground_truth_evaluator import GroundTruthEvaluator
25+
import tensorflow as tf
2926
from sklearn.base import ClusterMixin
3027
from sklearn.cluster import DBSCAN
31-
from sklearn.decomposition import FastICA, PCA
28+
from tensorflow.keras import Model, Sequential
3229
from umap import UMAP
3330

34-
from art.defences.detector.poison.clustering_analyzer import ClusterAnalysisType
31+
from art.defences.detector.poison.ground_truth_evaluator import GroundTruthEvaluator
3532
from art.defences.detector.poison.poison_filtering_defence import PoisonFilteringDefence
36-
from art.defences.detector.poison.utils import ReducerType, ClustererType
3733

3834
if TYPE_CHECKING:
39-
from art.utils import CLASSIFIER_NEURALNETWORK_TYPE, CLASSIFIER_TYPE
35+
from art.utils import CLASSIFIER_TYPE
4036

4137
logger = logging.getLogger(__name__)
4238
tf.get_logger().setLevel(logging.WARN)

tests/defences/detector/poison/test_clustering_centroid_analysis.py

Lines changed: 198 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
# SOFTWARE.
1818
from __future__ import absolute_import, division, print_function, unicode_literals
1919

20+
import json
2021
from unittest.mock import MagicMock, patch
2122

2223
import tensorflow as tf
@@ -649,6 +650,203 @@ def test_integration_with_real_model(self):
649650
self.assertEqual((100, 5), result.shape)
650651
self.assertIsInstance(result, np.ndarray)
651652

653+
class TestEvaluateDefence(unittest.TestCase):
654+
"""
655+
Unit tests for the evaluate_defence method of the ClusteringCentroidAnalysis class.
656+
"""
657+
658+
def setUp(self):
659+
"""
660+
Set up a mock ClusteringCentroidAnalysis object and its necessary attributes.
661+
"""
662+
self.mock_classifier = MagicMock()
663+
self.mock_classifier.model = MagicMock()
664+
665+
# Dummy data for constructor - these values might not be directly used by
666+
# evaluate_defence but are needed for instantiation.
667+
x_train_dummy = np.array([[1, 2], [3, 4], [5, 6]])
668+
y_train_constructor_dummy = np.array(['A', 'B', 'A'])
669+
benign_indices_dummy = np.array([0, 2])
670+
final_feature_layer_name_dummy = "mock_feature_layer"
671+
misclassification_threshold_dummy = 0.1
672+
673+
# Patch _extract_submodels to avoid complex model setup if it's problematic
674+
# and not relevant to evaluate_defence
675+
with patch('art.defences.detector.poison.clustering_centroid_analysis.ClusteringCentroidAnalysis._extract_submodels',
676+
return_value=(MagicMock(), MagicMock())) as _:
677+
self.defence = ClusteringCentroidAnalysis(
678+
classifier=self.mock_classifier,
679+
x_train=x_train_dummy,
680+
y_train=y_train_constructor_dummy, # Used by _encode_labels in __init__
681+
benign_indices=benign_indices_dummy,
682+
final_feature_layer_name=final_feature_layer_name_dummy,
683+
misclassification_threshold=misclassification_threshold_dummy
684+
)
685+
686+
# The following attributes are set after instantiation to control the test
687+
# environment precisely
688+
689+
self.defence.unique_classes = {0, 1} # e.g., 'A' -> 0, 'B' -> 1
690+
self.defence.y_train = np.array([0, 0, 1, 1, 0]) # Total 5 samples
691+
self.defence.is_clean = np.array([1, 0, 1, 0, 1]) # Predictions by the defence
692+
693+
@patch('art.defences.detector.poison.clustering_centroid_analysis.GroundTruthEvaluator')
694+
def test_evaluate_defence_basic_case(self, MockGroundTruthEvaluator):
695+
"""
696+
Test evaluate_defence with a basic scenario of ground truth and predictions.
697+
"""
698+
# Mock setup
699+
mock_evaluator_instance = MockGroundTruthEvaluator.return_value
700+
expected_json_report = json.dumps({"accuracy": 0.6, "class_0_fp": 1, "class_1_fn": 0})
701+
mock_evaluator_instance.analyze_correctness.return_value = (
702+
{"some_errors": []}, # errors_by_class (not directly used by evaluate_defence's return)
703+
expected_json_report # confusion_matrix_json
704+
)
705+
706+
# Ground truth setup
707+
# This is the `is_clean` array passed as an argument to evaluate_defence
708+
ground_truth_is_clean = np.array([1, 1, 1, 0, 0]) # Ground truth for the 5 samples
709+
710+
returned_json_report = self.defence.evaluate_defence(is_clean=ground_truth_is_clean)
711+
712+
self.assertEqual(returned_json_report, expected_json_report)
713+
714+
# Verify how analyze_correctness was called
715+
mock_evaluator_instance.analyze_correctness.assert_called_once()
716+
call_args = mock_evaluator_instance.analyze_correctness.call_args[1] # Get kwargs
717+
718+
# Expected segmentation based on self.defence.y_train, self.defence.is_clean, and ground_truth_is_clean
719+
# self.defence.y_train = np.array([0, 0, 1, 1, 0])
720+
# self.defence.is_clean (predictions) = np.array([1, 0, 1, 0, 1])
721+
# ground_truth_is_clean (truth) = np.array([1, 1, 1, 0, 0])
722+
723+
# Class 0 indices: 0, 1, 4
724+
# Class 1 indices: 2, 3
725+
726+
expected_assigned_clean_by_class = [
727+
self.defence.is_clean[[0, 1, 4]], # Predictions for class 0: [1, 0, 1]
728+
self.defence.is_clean[[2, 3]] # Predictions for class 1: [1, 0]
729+
]
730+
expected_is_clean_by_class = [
731+
ground_truth_is_clean[[0, 1, 4]], # Ground truth for class 0: [1, 1, 0]
732+
ground_truth_is_clean[[2, 3]] # Ground truth for class 1: [1, 0]
733+
]
734+
735+
# np.testing.assert_equal doesn't work well for lists of arrays directly in assert_called_with
736+
# So we compare element by element
737+
self.assertEqual(len(call_args['assigned_clean_by_class']), len(expected_assigned_clean_by_class))
738+
for i, arr in enumerate(call_args['assigned_clean_by_class']):
739+
np.testing.assert_array_equal(arr, expected_assigned_clean_by_class[i])
740+
741+
self.assertEqual(len(call_args['is_clean_by_class']), len(expected_is_clean_by_class))
742+
for i, arr in enumerate(call_args['is_clean_by_class']):
743+
np.testing.assert_array_equal(arr, expected_is_clean_by_class[i])
744+
745+
746+
@patch('art.defences.detector.poison.clustering_centroid_analysis.GroundTruthEvaluator')
747+
def test_evaluate_defence_all_predicted_clean_all_truth_clean(self, MockGroundTruthEvaluator):
748+
"""
749+
Test case: All samples predicted as clean by defence, and all are truly clean.
750+
"""
751+
mock_evaluator_instance = MockGroundTruthEvaluator.return_value
752+
expected_json_report = json.dumps({"accuracy": 1.0})
753+
mock_evaluator_instance.analyze_correctness.return_value = ({}, expected_json_report)
754+
755+
self.defence.is_clean = np.ones_like(self.defence.y_train) # All predicted clean
756+
ground_truth_is_clean = np.ones_like(self.defence.y_train) # All truly clean
757+
758+
returned_json_report = self.defence.evaluate_defence(is_clean=ground_truth_is_clean)
759+
self.assertEqual(returned_json_report, expected_json_report)
760+
761+
call_args = mock_evaluator_instance.analyze_correctness.call_args[1]
762+
763+
# self.defence.y_train = np.array([0, 0, 1, 1, 0])
764+
# Class 0 indices: 0, 1, 4
765+
# Class 1 indices: 2, 3
766+
expected_assigned_clean_by_class = [np.array([1,1,1]), np.array([1,1])]
767+
expected_is_clean_by_class = [np.array([1,1,1]), np.array([1,1])]
768+
769+
self.assertEqual(len(call_args['assigned_clean_by_class']), len(expected_assigned_clean_by_class))
770+
for i, arr in enumerate(call_args['assigned_clean_by_class']):
771+
np.testing.assert_array_equal(arr, expected_assigned_clean_by_class[i])
772+
773+
self.assertEqual(len(call_args['is_clean_by_class']), len(expected_is_clean_by_class))
774+
for i, arr in enumerate(call_args['is_clean_by_class']):
775+
np.testing.assert_array_equal(arr, expected_is_clean_by_class[i])
776+
777+
778+
@patch('art.defences.detector.poison.clustering_centroid_analysis.GroundTruthEvaluator')
779+
def test_evaluate_defence_all_predicted_poisoned_all_truth_poisoned(self, MockGroundTruthEvaluator):
780+
"""
781+
Test case: All samples predicted as poisoned, and all are truly poisoned.
782+
"""
783+
mock_evaluator_instance = MockGroundTruthEvaluator.return_value
784+
expected_json_report = json.dumps({"accuracy": 1.0, "tn_perfect": True}) # Example detail
785+
mock_evaluator_instance.analyze_correctness.return_value = ({}, expected_json_report)
786+
787+
self.defence.is_clean = np.zeros_like(self.defence.y_train) # All predicted poisoned
788+
ground_truth_is_clean = np.zeros_like(self.defence.y_train) # All truly poisoned
789+
790+
returned_json_report = self.defence.evaluate_defence(is_clean=ground_truth_is_clean)
791+
self.assertEqual(returned_json_report, expected_json_report)
792+
793+
call_args = mock_evaluator_instance.analyze_correctness.call_args[1]
794+
expected_assigned_clean_by_class = [np.array([0,0,0]), np.array([0,0])]
795+
expected_is_clean_by_class = [np.array([0,0,0]), np.array([0,0])]
796+
797+
self.assertEqual(len(call_args['assigned_clean_by_class']), len(expected_assigned_clean_by_class))
798+
for i, arr in enumerate(call_args['assigned_clean_by_class']):
799+
np.testing.assert_array_equal(arr, expected_assigned_clean_by_class[i])
800+
801+
self.assertEqual(len(call_args['is_clean_by_class']), len(expected_is_clean_by_class))
802+
for i, arr in enumerate(call_args['is_clean_by_class']):
803+
np.testing.assert_array_equal(arr, expected_is_clean_by_class[i])
804+
805+
@patch('art.defences.detector.poison.clustering_centroid_analysis.GroundTruthEvaluator')
806+
def test_evaluate_defence_no_samples_for_a_class_in_unique_classes(self, MockGroundTruthEvaluator):
807+
"""
808+
Test case: A class in unique_classes has no samples in y_train (edge case).
809+
This shouldn't happen if unique_classes is derived from y_train correctly,
810+
but tests robustness.
811+
"""
812+
mock_evaluator_instance = MockGroundTruthEvaluator.return_value
813+
expected_json_report = json.dumps({"note": "class 2 had no samples"})
814+
mock_evaluator_instance.analyze_correctness.return_value = ({}, expected_json_report)
815+
816+
self.defence.unique_classes = {0, 1, 2} # Add class 2
817+
# self.defence.y_train remains [0, 0, 1, 1, 0] (no samples for class 2)
818+
self.defence.is_clean = np.array([1, 0, 1, 0, 1])
819+
ground_truth_is_clean = np.array([1, 1, 1, 0, 0])
820+
821+
returned_json_report = self.defence.evaluate_defence(is_clean=ground_truth_is_clean)
822+
self.assertEqual(returned_json_report, expected_json_report)
823+
824+
call_args = mock_evaluator_instance.analyze_correctness.call_args[1]
825+
826+
# Class 0 indices: 0, 1, 4
827+
# Class 1 indices: 2, 3
828+
# Class 2 indices: []
829+
expected_assigned_clean_by_class = [
830+
self.defence.is_clean[[0, 1, 4]],
831+
self.defence.is_clean[[2, 3]],
832+
np.array([]) # Empty for class 2
833+
]
834+
expected_is_clean_by_class = [
835+
ground_truth_is_clean[[0, 1, 4]],
836+
ground_truth_is_clean[[2, 3]],
837+
np.array([]) # Empty for class 2
838+
]
839+
840+
self.assertEqual(len(call_args['assigned_clean_by_class']), len(expected_assigned_clean_by_class))
841+
for i, arr in enumerate(call_args['assigned_clean_by_class']):
842+
np.testing.assert_array_equal(arr, expected_assigned_clean_by_class[i],
843+
err_msg=f"Mismatch in assigned_clean_by_class at index {i}")
844+
845+
self.assertEqual(len(call_args['is_clean_by_class']), len(expected_is_clean_by_class))
846+
for i, arr in enumerate(call_args['is_clean_by_class']):
847+
np.testing.assert_array_equal(arr, expected_is_clean_by_class[i],
848+
err_msg=f"Mismatch in is_clean_by_class at index {i}")
849+
652850

653851
class TestDetectPoison(unittest.TestCase):
654852
"""

0 commit comments

Comments
 (0)