37
37
from tensorflow .keras .layers import Dense
38
38
from umap import UMAP
39
39
40
- from art .defences .detector .poison .clustering_centroid_analysis import get_reducer , get_clusterer , \
41
- ClusteringCentroidAnalysis , _calculate_centroid , _class_clustering , _feature_extraction , _cluster_classes , \
42
- _encode_labels
40
+ from art .defences .detector .poison .clustering_centroid_analysis import ClusteringCentroidAnalysis , _calculate_centroid , _class_clustering , _feature_extraction , _cluster_classes , _encode_labels
43
41
from art .defences .detector .poison .utils import ReducerType , ClustererType
44
42
45
43
logger = logging .getLogger (__name__ )
@@ -276,7 +274,7 @@ def test_init_invalid_layer_name(self):
276
274
277
275
def test_init_invalid_layer_non_relu (self ):
278
276
"""Test __init__ with an invalid layer that does not have ReLu activation. Check that it raises error."""
279
- with self .assertRaises ( ValueError ) as e :
277
+ with self .assertWarns ( UserWarning ) as w :
280
278
ClusteringCentroidAnalysis (
281
279
classifier = self .mock_classifier ,
282
280
x_train = self .x_train ,
@@ -285,7 +283,9 @@ def test_init_invalid_layer_non_relu(self):
285
283
final_feature_layer_name = self .non_relu_intermediate_layer_name ,
286
284
misclassification_threshold = self .misclassification_threshold
287
285
)
288
- self .assertEqual (f"Final feature layer '{ self .non_relu_intermediate_layer_name } ' must have a ReLU activation." , str (e .exception ))
286
+ self .assertEqual (1 , len (w .warnings ))
287
+ self .assertEqual (f"Final feature layer '{ self .non_relu_intermediate_layer_name } ' must have a ReLU activation." ,
288
+ str (w .warnings [0 ].message ))
289
289
290
290
class TestEncodeLabels (unittest .TestCase ):
291
291
@@ -650,44 +650,6 @@ def test_integration_with_real_model(self):
650
650
self .assertIsInstance (result , np .ndarray )
651
651
652
652
653
- class TestReducersClusterers (unittest .TestCase ):
654
- """
655
- Suite of tests for the valid and invalid utils used in :class: ``ClusteringCentroidAnalysis``
656
- """
657
-
658
- def test_get_reducer_valid (self ):
659
- reducer_cases = [
660
- (ReducerType .FASTICA , FastICA ),
661
- (ReducerType .PCA , PCA ),
662
- (ReducerType .UMAP , UMAP ),
663
- ]
664
- for reducer_type , expected in reducer_cases :
665
- with self .subTest (reducer = reducer_type ):
666
- reducer = get_reducer (reducer_type , nb_dims = 5 )
667
- self .assertIsInstance (reducer , expected )
668
-
669
- def test_get_reducer_invalid (self ):
670
- for invalid in ["INVALID" , None ]:
671
- with self .subTest (invalid = invalid ):
672
- with self .assertRaises (ValueError ):
673
- get_reducer (invalid , nb_dims = 5 )
674
-
675
- def test_get_clusterer_valid (self ):
676
- clusterer_cases = [
677
- (ClustererType .DBSCAN , DBSCAN ),
678
- ]
679
- for clusterer_type , expected in clusterer_cases :
680
- with self .subTest (clusterer = clusterer_type ):
681
- clusterer = get_clusterer (clusterer_type )
682
- self .assertIsInstance (clusterer , expected )
683
-
684
- def test_get_clusterer_invalid (self ):
685
- for invalid in ["INVALID" , None ]:
686
- with self .subTest (invalid = invalid ):
687
- with self .assertRaises (ValueError ):
688
- get_clusterer (invalid )
689
-
690
-
691
653
class TestDetectPoison (unittest .TestCase ):
692
654
"""
693
655
Unit tests for the detect_poison method in ClusteringCentroidAnalysis
0 commit comments