@@ -66,7 +66,7 @@ def test_membership_leakage_tabular(art_warning, tabular_dl_estimator, get_iris_
6666 art_warning (e )
6767
6868
69- @pytest .mark .skip_framework ("keras" , "kerastf" , "tensorflow1" , "tensorflow2v1" , "mxnet" )
69+ @pytest .mark .skip_framework ("scikitlearn" , " keras" , "kerastf" , "tensorflow1" , "tensorflow2v1" , "mxnet" )
7070def test_membership_leakage_image (art_warning , image_dl_estimator , get_default_mnist_subset ):
7171 try :
7272 classifier , _ = image_dl_estimator ()
@@ -80,14 +80,14 @@ def test_membership_leakage_image(art_warning, image_dl_estimator, get_default_m
8080 logger .info ("Max PDTP leakage: %.2f" , (np .max (avg_leakage )))
8181 assert np .all (avg_leakage >= 1.0 )
8282 assert np .all (worse_leakage >= avg_leakage )
83- assert avg_leakage .shape [0 ] == x_train . shape [ 0 ]
84- assert worse_leakage .shape [0 ] == x_train . shape [ 0 ]
85- assert std_dev .shape [0 ] == x_train . shape [ 0 ]
83+ assert avg_leakage .shape [0 ] == 100
84+ assert worse_leakage .shape [0 ] == 100
85+ assert std_dev .shape [0 ] == 100
8686 except ARTTestException as e :
8787 art_warning (e )
8888
8989
90- @pytest .mark .skip_framework ("keras" , "kerastf" , "tensorflow1" , "mxnet" )
90+ @pytest .mark .skip_framework ("scikitlearn" , " keras" , "kerastf" , "tensorflow1" , "mxnet" )
9191def test_errors (art_warning , tabular_dl_estimator , get_iris_dataset , image_data_generator ):
9292 try :
9393 classifier = tabular_dl_estimator ()
0 commit comments