Skip to content

Commit 7bf76df

Browse files
committed
TestCalculateMisclassificationRate fixes
Signed-off-by: Álvaro Bacca Peña <[email protected]>
1 parent 8a4f9e1 commit 7bf76df

File tree

1 file changed

+50
-38
lines changed

1 file changed

+50
-38
lines changed

tests/defences/detector/poison/test_clustering_centroid_analysis.py

Lines changed: 50 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -662,18 +662,26 @@ def setUp(self):
662662
y_train_constructor_dummy = np.array(["A"] * 5 + ["B"] * 5)
663663
benign_indices_dummy = np.arange(10)
664664

665-
with patch(
666-
"art.defences.detector.poison.clustering_centroid_analysis.ClusteringCentroidAnalysisTensorFlowV2._extract_submodels",
667-
return_value=(MagicMock(), MagicMock()),
668-
):
669-
self.defence = ClusteringCentroidAnalysisTensorFlowV2(
670-
classifier=MagicMock(),
671-
x_train=x_train_dummy,
672-
y_train=y_train_constructor_dummy,
673-
benign_indices=benign_indices_dummy,
674-
final_feature_layer_name="dummy_layer",
675-
misclassification_threshold=0.1,
676-
)
665+
self.extract_submodels_patcher = patch(
666+
"art.defences.detector.poison.clustering_centroid_analysis.ClusteringCentroidAnalysisTensorFlowV2._extract_submodels"
667+
)
668+
self.mock_extract_submodels = self.extract_submodels_patcher.start()
669+
670+
# Define the return value for mock_extract_submodels: two MagicMocks.
671+
# These will become self.feature_representation_model and self.classifying_submodel
672+
# on the self.defence instance.
673+
mock_feature_model = MagicMock(spec=tf.keras.Model) # Use spec for type safety
674+
mock_classifying_model = MagicMock(spec=tf.keras.Model)
675+
self.mock_extract_submodels.return_value = (mock_feature_model, mock_classifying_model)
676+
677+
self.defence = ClusteringCentroidAnalysisTensorFlowV2(
678+
classifier=MagicMock(),
679+
x_train=x_train_dummy,
680+
y_train=y_train_constructor_dummy,
681+
benign_indices=benign_indices_dummy,
682+
final_feature_layer_name="dummy_layer",
683+
misclassification_threshold=0.1,
684+
)
677685

678686
self.feature_dim = 5
679687
self.num_benign_samples_class_0 = 3
@@ -693,21 +701,32 @@ def setUp(self):
693701
)
694702
self.defence.unique_classes = {0, 1, 2}
695703

696-
self.defence.feature_representation_model = MagicMock(spec=tf.keras.Model)
704+
# Mock the instance's _calculate_features attribute AFTER it has been set up in __init__
705+
# It's now a tf.function, but MagicMock can replace tf.function objects too.
706+
def mock_calc_features_side_effect(model_arg, data_tensor): # model_arg is feature_representation_model
707+
# Ensure it's a tensor for tf.shape, or convert if needed
708+
if not tf.is_tensor(data_tensor):
709+
data_tensor = tf.convert_to_tensor(data_tensor, dtype=tf.float32)
710+
711+
num_samples = tf.shape(data_tensor)[0].numpy()
712+
713+
if num_samples == 0:
714+
print("Debug: mock_calculate_features received EMPTY data_tensor. Returning empty features.")
715+
return tf.constant([], shape=(0, self.feature_dim), dtype=tf.float32)
716+
717+
print(f"Debug: mock_calculate_features received data_tensor shape: {data_tensor.shape}. Returning features shape: ({num_samples}, {self.feature_dim})")
718+
return tf.random.uniform((num_samples, self.feature_dim), dtype=tf.float32)
719+
720+
self.mock_calculate_features_instance = MagicMock(side_effect=mock_calc_features_side_effect)
721+
self.defence._calculate_features = self.mock_calculate_features_instance
722+
723+
# Mock feature_representation_model.predict as well, as it's used once for feature_shape
697724
self.defence.feature_representation_model.predict.return_value = np.random.rand(
698725
1, self.feature_dim
699-
)
700-
self.defence.classifying_submodel = MagicMock(spec=tf.keras.Sequential)
701-
702-
self.calculate_features_patcher = patch(
703-
"art.defences.detector.poison.clustering_centroid_analysis.ClusteringCentroidAnalysisTensorFlowV2._calculate_features"
704-
)
705-
self.mock_calculate_features = self.calculate_features_patcher.start()
726+
).astype(np.float32) # Ensure correct dtype and shape
706727

707728
def tearDown(self):
708-
self.calculate_features_patcher.stop()
709729
tf.config.run_functions_eagerly(self.original_eager_value) # Restore original eager mode
710-
self.calculate_features_patcher.stop()
711730

712731
def test_zero_misclassification(self):
713732
"""Test when no samples are misclassified."""
@@ -716,7 +735,7 @@ def test_zero_misclassification(self):
716735

717736
mock_features_class1 = np.random.rand(self.num_benign_samples_class_1, self.feature_dim)
718737
mock_features_class2 = np.random.rand(self.num_benign_samples_class_2, self.feature_dim)
719-
self.mock_calculate_features.side_effect = [
738+
self.mock_calculate_features_instance.side_effect = [
720739
mock_features_class1,
721740
mock_features_class2,
722741
]
@@ -748,7 +767,7 @@ def mock_classifier_predict_side_effect(deviated_features, training=False):
748767

749768
rate = self.defence._calculate_misclassification_rate(target_class_label, deviation_vector)
750769
self.assertEqual(rate, 0.0)
751-
self.assertEqual(self.mock_calculate_features.call_count, 2)
770+
self.assertEqual(self.mock_calculate_features_instance.call_count, 2)
752771
self.assertEqual(self.defence.classifying_submodel.call_count, 2)
753772

754773
def test_full_misclassification(self):
@@ -758,7 +777,7 @@ def test_full_misclassification(self):
758777

759778
mock_features_class0 = np.random.rand(self.num_benign_samples_class_0, self.feature_dim)
760779
mock_features_class2 = np.random.rand(self.num_benign_samples_class_2, self.feature_dim)
761-
self.mock_calculate_features.side_effect = [
780+
self.mock_calculate_features_instance.side_effect = [
762781
mock_features_class0,
763782
mock_features_class2,
764783
]
@@ -784,21 +803,14 @@ def mock_classifier_predict_side_effect(deviated_features, training=False):
784803

785804
rate = self.defence._calculate_misclassification_rate(target_class_label, deviation_vector)
786805
self.assertEqual(rate, 1.0)
787-
self.assertEqual(self.mock_calculate_features.call_count, 2)
806+
self.assertEqual(self.mock_calculate_features_instance.call_count, 2)
788807
self.assertEqual(self.defence.classifying_submodel.call_count, 2)
789808

790809
def test_partial_misclassification(self):
791810
"""Test with a mix of misclassifications."""
792811
target_class_label = 2
793812
deviation_vector = np.random.rand(self.feature_dim)
794813

795-
mock_features_class0 = np.random.rand(self.num_benign_samples_class_0, self.feature_dim)
796-
mock_features_class1 = np.random.rand(self.num_benign_samples_class_1, self.feature_dim)
797-
self.mock_calculate_features.side_effect = [
798-
mock_features_class0,
799-
mock_features_class1,
800-
]
801-
802814
def mock_classifier_predict_side_effect(deviated_features, training=False):
803815
num_unique_classes = len(self.defence.unique_classes)
804816
num_samples_concrete = tf.compat.v1.dimension_value(
@@ -839,7 +851,7 @@ def mock_classifier_predict_side_effect(deviated_features, training=False):
839851
2.0 / (self.num_benign_samples_class_0 + self.num_benign_samples_class_1),
840852
places=6,
841853
)
842-
self.assertEqual(self.mock_calculate_features.call_count, 2)
854+
self.assertEqual(self.mock_calculate_features_instance.call_count, 2)
843855
self.assertEqual(self.defence.classifying_submodel.call_count, 2)
844856

845857
def test_no_other_classes_exist(self):
@@ -850,7 +862,7 @@ def test_no_other_classes_exist(self):
850862

851863
rate = self.defence._calculate_misclassification_rate(target_class_label, deviation_vector)
852864
self.assertEqual(rate, 0.0)
853-
self.mock_calculate_features.assert_not_called()
865+
self.mock_calculate_features_instance.assert_not_called()
854866
self.defence.classifying_submodel.assert_not_called()
855867

856868
def test_other_classes_exist_but_no_benign_samples(self):
@@ -863,7 +875,7 @@ def test_other_classes_exist_but_no_benign_samples(self):
863875

864876
rate = self.defence._calculate_misclassification_rate(target_class_label, deviation_vector)
865877
self.assertEqual(rate, 0.0)
866-
self.mock_calculate_features.assert_not_called()
878+
self.mock_calculate_features_instance.assert_not_called()
867879
self.defence.classifying_submodel.assert_not_called()
868880

869881
def test_batching_multiple_batches_for_one_class(self):
@@ -890,7 +902,7 @@ def test_batching_multiple_batches_for_one_class(self):
890902
features_batch1_c1 = np.random.rand(128, self.feature_dim)
891903
features_batch2_c1 = np.random.rand(num_samples_class1_large - 128, self.feature_dim)
892904
features_batch1_c2 = np.random.rand(num_samples_class2_small, self.feature_dim)
893-
self.mock_calculate_features.side_effect = [
905+
self.mock_calculate_features_instance.side_effect = [
894906
features_batch1_c1,
895907
features_batch2_c1,
896908
features_batch1_c2,
@@ -921,7 +933,7 @@ def mock_classifier_predict_side_effect_for_batching_test(
921933

922934
rate = self.defence._calculate_misclassification_rate(target_class_label, deviation_vector)
923935
self.assertEqual(rate, 1.0)
924-
self.assertEqual(self.mock_calculate_features.call_count, 3)
936+
self.assertEqual(self.mock_calculate_features_instance.call_count, 3)
925937
self.assertEqual(self.defence.classifying_submodel.call_count, 3)
926938

927939
self.num_benign_samples_class_1 = original_num_benign_samples_class_1

0 commit comments

Comments
 (0)