Skip to content

Commit 0ca4f16

Browse files
committed
Pytest marks added
Signed-off-by: Álvaro Bacca Peña <[email protected]>
1 parent 38a9f55 commit 0ca4f16

File tree

1 file changed

+13
-18
lines changed

1 file changed

+13
-18
lines changed

tests/defences/detector/poison/test_clustering_centroid_analysis.py

Lines changed: 13 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,12 @@ def fit_predict(self, x, **kwargs):
5454
return self.cluster_labels_to_return
5555

5656

57+
@pytest.mark.skip_framework(
58+
"keras",
59+
"kerastf",
60+
"pytorch",
61+
"non_dl_frameworks",
62+
)
5763
class CCAUDTestCaseBase(unittest.TestCase):
5864

5965
def setUp(self):
@@ -98,12 +104,6 @@ def setUp(self):
98104
)
99105

100106

101-
@pytest.mark.skip_framework(
102-
"keras",
103-
"kerastf",
104-
"pytorch",
105-
"non_dl_frameworks",
106-
)
107107
class TestInitialization(CCAUDTestCaseBase):
108108
"""
109109
Unit tests for the ClusteringCentroidAnalysis class, focusing on
@@ -235,6 +235,7 @@ def test_init_invalid_layer_non_relu(self):
235235
)
236236

237237

238+
@pytest.mark.framework_agnostic
238239
class TestEncodeLabels(unittest.TestCase):
239240
"""
240241
Unit tests for the ClusteringCentroidAnalysis label encoding, needed for proper clustering.
@@ -260,12 +261,6 @@ def test_encode_multi_labels(self):
260261
self.assertEqual({"A": 0, "B": 1, "C": 2, "D": 3}, reverse_mapping)
261262

262263

263-
@pytest.mark.skip_framework(
264-
"keras",
265-
"kerastf",
266-
"pytorch",
267-
"non_dl_frameworks",
268-
)
269264
class TestCalculateCentroid(CCAUDTestCaseBase):
270265
"""
271266
Unit tests for the ClusteringCentroidAnalysis centroid calculations, needed for PCD.
@@ -591,12 +586,6 @@ def fit_predict(self, x, **kwargs):
591586
self.assertEqual(len(cluster_class_mapping), 0)
592587

593588

594-
@pytest.mark.skip_framework(
595-
"keras",
596-
"kerastf",
597-
"pytorch",
598-
"non_dl_frameworkslearn",
599-
)
600589
class TestFeatureExtraction(CCAUDTestCaseBase):
601590
"""Unit tests for the _feature_extraction function."""
602591

@@ -1208,6 +1197,12 @@ def mock_misclass_rate(class_label, deviation):
12081197
self.assertLess(np.mean(self.y_train[np.where(is_clean_np == 1)]), 0.2)
12091198

12101199

1200+
@pytest.mark.skip_framework(
1201+
"keras",
1202+
"kerastf",
1203+
"pytorch",
1204+
"non_dl_frameworks",
1205+
)
12111206
class TestEvaluateDefence(unittest.TestCase):
12121207
"""
12131208
Unit tests for the evaluate_defence method of the ClusteringCentroidAnalysis class.

0 commit comments

Comments
 (0)