@@ -662,18 +662,26 @@ def setUp(self):
662
662
y_train_constructor_dummy = np .array (["A" ] * 5 + ["B" ] * 5 )
663
663
benign_indices_dummy = np .arange (10 )
664
664
665
- with patch (
666
- "art.defences.detector.poison.clustering_centroid_analysis.ClusteringCentroidAnalysisTensorFlowV2._extract_submodels" ,
667
- return_value = (MagicMock (), MagicMock ()),
668
- ):
669
- self .defence = ClusteringCentroidAnalysisTensorFlowV2 (
670
- classifier = MagicMock (),
671
- x_train = x_train_dummy ,
672
- y_train = y_train_constructor_dummy ,
673
- benign_indices = benign_indices_dummy ,
674
- final_feature_layer_name = "dummy_layer" ,
675
- misclassification_threshold = 0.1 ,
676
- )
665
+ self .extract_submodels_patcher = patch (
666
+ "art.defences.detector.poison.clustering_centroid_analysis.ClusteringCentroidAnalysisTensorFlowV2._extract_submodels"
667
+ )
668
+ self .mock_extract_submodels = self .extract_submodels_patcher .start ()
669
+
670
+ # Define the return value for mock_extract_submodels: two MagicMocks.
671
+ # These will become self.feature_representation_model and self.classifying_submodel
672
+ # on the self.defence instance.
673
+ mock_feature_model = MagicMock (spec = tf .keras .Model ) # Use spec for type safety
674
+ mock_classifying_model = MagicMock (spec = tf .keras .Model )
675
+ self .mock_extract_submodels .return_value = (mock_feature_model , mock_classifying_model )
676
+
677
+ self .defence = ClusteringCentroidAnalysisTensorFlowV2 (
678
+ classifier = MagicMock (),
679
+ x_train = x_train_dummy ,
680
+ y_train = y_train_constructor_dummy ,
681
+ benign_indices = benign_indices_dummy ,
682
+ final_feature_layer_name = "dummy_layer" ,
683
+ misclassification_threshold = 0.1 ,
684
+ )
677
685
678
686
self .feature_dim = 5
679
687
self .num_benign_samples_class_0 = 3
@@ -693,21 +701,32 @@ def setUp(self):
693
701
)
694
702
self .defence .unique_classes = {0 , 1 , 2 }
695
703
696
- self .defence .feature_representation_model = MagicMock (spec = tf .keras .Model )
704
+ # Mock the instance's _calculate_features attribute AFTER it has been set up in __init__
705
+ # It's now a tf.function, but MagicMock can replace tf.function objects too.
706
+ def mock_calc_features_side_effect (model_arg , data_tensor ): # model_arg is feature_representation_model
707
+ # Ensure it's a tensor for tf.shape, or convert if needed
708
+ if not tf .is_tensor (data_tensor ):
709
+ data_tensor = tf .convert_to_tensor (data_tensor , dtype = tf .float32 )
710
+
711
+ num_samples = tf .shape (data_tensor )[0 ].numpy ()
712
+
713
+ if num_samples == 0 :
714
+ print ("Debug: mock_calculate_features received EMPTY data_tensor. Returning empty features." )
715
+ return tf .constant ([], shape = (0 , self .feature_dim ), dtype = tf .float32 )
716
+
717
+ print (f"Debug: mock_calculate_features received data_tensor shape: { data_tensor .shape } . Returning features shape: ({ num_samples } , { self .feature_dim } )" )
718
+ return tf .random .uniform ((num_samples , self .feature_dim ), dtype = tf .float32 )
719
+
720
+ self .mock_calculate_features_instance = MagicMock (side_effect = mock_calc_features_side_effect )
721
+ self .defence ._calculate_features = self .mock_calculate_features_instance
722
+
723
+ # Mock feature_representation_model.predict as well, as it's used once for feature_shape
697
724
self .defence .feature_representation_model .predict .return_value = np .random .rand (
698
725
1 , self .feature_dim
699
- )
700
- self .defence .classifying_submodel = MagicMock (spec = tf .keras .Sequential )
701
-
702
- self .calculate_features_patcher = patch (
703
- "art.defences.detector.poison.clustering_centroid_analysis.ClusteringCentroidAnalysisTensorFlowV2._calculate_features"
704
- )
705
- self .mock_calculate_features = self .calculate_features_patcher .start ()
726
+ ).astype (np .float32 ) # Ensure correct dtype and shape
706
727
707
728
def tearDown (self ):
708
- self .calculate_features_patcher .stop ()
709
729
tf .config .run_functions_eagerly (self .original_eager_value ) # Restore original eager mode
710
- self .calculate_features_patcher .stop ()
711
730
712
731
def test_zero_misclassification (self ):
713
732
"""Test when no samples are misclassified."""
@@ -716,7 +735,7 @@ def test_zero_misclassification(self):
716
735
717
736
mock_features_class1 = np .random .rand (self .num_benign_samples_class_1 , self .feature_dim )
718
737
mock_features_class2 = np .random .rand (self .num_benign_samples_class_2 , self .feature_dim )
719
- self .mock_calculate_features .side_effect = [
738
+ self .mock_calculate_features_instance .side_effect = [
720
739
mock_features_class1 ,
721
740
mock_features_class2 ,
722
741
]
@@ -748,7 +767,7 @@ def mock_classifier_predict_side_effect(deviated_features, training=False):
748
767
749
768
rate = self .defence ._calculate_misclassification_rate (target_class_label , deviation_vector )
750
769
self .assertEqual (rate , 0.0 )
751
- self .assertEqual (self .mock_calculate_features .call_count , 2 )
770
+ self .assertEqual (self .mock_calculate_features_instance .call_count , 2 )
752
771
self .assertEqual (self .defence .classifying_submodel .call_count , 2 )
753
772
754
773
def test_full_misclassification (self ):
@@ -758,7 +777,7 @@ def test_full_misclassification(self):
758
777
759
778
mock_features_class0 = np .random .rand (self .num_benign_samples_class_0 , self .feature_dim )
760
779
mock_features_class2 = np .random .rand (self .num_benign_samples_class_2 , self .feature_dim )
761
- self .mock_calculate_features .side_effect = [
780
+ self .mock_calculate_features_instance .side_effect = [
762
781
mock_features_class0 ,
763
782
mock_features_class2 ,
764
783
]
@@ -784,21 +803,14 @@ def mock_classifier_predict_side_effect(deviated_features, training=False):
784
803
785
804
rate = self .defence ._calculate_misclassification_rate (target_class_label , deviation_vector )
786
805
self .assertEqual (rate , 1.0 )
787
- self .assertEqual (self .mock_calculate_features .call_count , 2 )
806
+ self .assertEqual (self .mock_calculate_features_instance .call_count , 2 )
788
807
self .assertEqual (self .defence .classifying_submodel .call_count , 2 )
789
808
790
809
def test_partial_misclassification (self ):
791
810
"""Test with a mix of misclassifications."""
792
811
target_class_label = 2
793
812
deviation_vector = np .random .rand (self .feature_dim )
794
813
795
- mock_features_class0 = np .random .rand (self .num_benign_samples_class_0 , self .feature_dim )
796
- mock_features_class1 = np .random .rand (self .num_benign_samples_class_1 , self .feature_dim )
797
- self .mock_calculate_features .side_effect = [
798
- mock_features_class0 ,
799
- mock_features_class1 ,
800
- ]
801
-
802
814
def mock_classifier_predict_side_effect (deviated_features , training = False ):
803
815
num_unique_classes = len (self .defence .unique_classes )
804
816
num_samples_concrete = tf .compat .v1 .dimension_value (
@@ -839,7 +851,7 @@ def mock_classifier_predict_side_effect(deviated_features, training=False):
839
851
2.0 / (self .num_benign_samples_class_0 + self .num_benign_samples_class_1 ),
840
852
places = 6 ,
841
853
)
842
- self .assertEqual (self .mock_calculate_features .call_count , 2 )
854
+ self .assertEqual (self .mock_calculate_features_instance .call_count , 2 )
843
855
self .assertEqual (self .defence .classifying_submodel .call_count , 2 )
844
856
845
857
def test_no_other_classes_exist (self ):
@@ -850,7 +862,7 @@ def test_no_other_classes_exist(self):
850
862
851
863
rate = self .defence ._calculate_misclassification_rate (target_class_label , deviation_vector )
852
864
self .assertEqual (rate , 0.0 )
853
- self .mock_calculate_features .assert_not_called ()
865
+ self .mock_calculate_features_instance .assert_not_called ()
854
866
self .defence .classifying_submodel .assert_not_called ()
855
867
856
868
def test_other_classes_exist_but_no_benign_samples (self ):
@@ -863,7 +875,7 @@ def test_other_classes_exist_but_no_benign_samples(self):
863
875
864
876
rate = self .defence ._calculate_misclassification_rate (target_class_label , deviation_vector )
865
877
self .assertEqual (rate , 0.0 )
866
- self .mock_calculate_features .assert_not_called ()
878
+ self .mock_calculate_features_instance .assert_not_called ()
867
879
self .defence .classifying_submodel .assert_not_called ()
868
880
869
881
def test_batching_multiple_batches_for_one_class (self ):
@@ -890,7 +902,7 @@ def test_batching_multiple_batches_for_one_class(self):
890
902
features_batch1_c1 = np .random .rand (128 , self .feature_dim )
891
903
features_batch2_c1 = np .random .rand (num_samples_class1_large - 128 , self .feature_dim )
892
904
features_batch1_c2 = np .random .rand (num_samples_class2_small , self .feature_dim )
893
- self .mock_calculate_features .side_effect = [
905
+ self .mock_calculate_features_instance .side_effect = [
894
906
features_batch1_c1 ,
895
907
features_batch2_c1 ,
896
908
features_batch1_c2 ,
@@ -921,7 +933,7 @@ def mock_classifier_predict_side_effect_for_batching_test(
921
933
922
934
rate = self .defence ._calculate_misclassification_rate (target_class_label , deviation_vector )
923
935
self .assertEqual (rate , 1.0 )
924
- self .assertEqual (self .mock_calculate_features .call_count , 3 )
936
+ self .assertEqual (self .mock_calculate_features_instance .call_count , 3 )
925
937
self .assertEqual (self .defence .classifying_submodel .call_count , 3 )
926
938
927
939
self .num_benign_samples_class_1 = original_num_benign_samples_class_1
0 commit comments