@@ -35,12 +35,14 @@ def test_membership_leakage_decision_tree(art_warning, decision_tree_estimator,
3535 extra_classifier = decision_tree_estimator ()
3636 (x_train , y_train ), _ = get_iris_dataset
3737 prev = classifier .model .tree_
38- leakage = PDTP (classifier , extra_classifier , x_train , y_train )
39- logger .info ("Average PDTP leakage: %.2f" , (np .average (leakage )))
40- logger .info ("Max PDTP leakage: %.2f" , (np .max (leakage )))
38+ avg_leakage , worse_leakage = PDTP (classifier , extra_classifier , x_train , y_train )
39+ logger .info ("Average PDTP leakage: %.2f" , (np .average (avg_leakage )))
40+ logger .info ("Max PDTP leakage: %.2f" , (np .max (avg_leakage )))
4141 assert classifier .model .tree_ == prev
42- assert np .all (leakage >= 1.0 )
43- assert leakage .shape [0 ] == x_train .shape [0 ]
42+ assert np .all (avg_leakage >= 1.0 )
43+ assert np .all (worse_leakage >= avg_leakage )
44+ assert avg_leakage .shape [0 ] == x_train .shape [0 ]
45+ assert worse_leakage .shape [0 ] == x_train .shape [0 ]
4446 except ARTTestException as e :
4547 art_warning (e )
4648
@@ -51,11 +53,13 @@ def test_membership_leakage_tabular(art_warning, tabular_dl_estimator, get_iris_
5153 classifier = tabular_dl_estimator ()
5254 extra_classifier = tabular_dl_estimator ()
5355 (x_train , y_train ), _ = get_iris_dataset
54- leakage = PDTP (classifier , extra_classifier , x_train , y_train )
55- logger .info ("Average PDTP leakage: %.2f" , (np .average (leakage )))
56- logger .info ("Max PDTP leakage: %.2f" , (np .max (leakage )))
57- assert np .all (leakage >= 1.0 )
58- assert leakage .shape [0 ] == x_train .shape [0 ]
56+ avg_leakage , worse_leakage = PDTP (classifier , extra_classifier , x_train , y_train )
57+ logger .info ("Average PDTP leakage: %.2f" , (np .average (avg_leakage )))
58+ logger .info ("Max PDTP leakage: %.2f" , (np .max (avg_leakage )))
59+ assert np .all (avg_leakage >= 1.0 )
60+ assert np .all (worse_leakage >= avg_leakage )
61+ assert avg_leakage .shape [0 ] == x_train .shape [0 ]
62+ assert worse_leakage .shape [0 ] == x_train .shape [0 ]
5963 except ARTTestException as e :
6064 art_warning (e )
6165
@@ -67,11 +71,13 @@ def test_membership_leakage_image(art_warning, image_dl_estimator, get_default_m
6771 extra_classifier , _ = image_dl_estimator ()
6872 (x_train , y_train ), _ = get_default_mnist_subset
6973 indexes = random .sample (range (x_train .shape [0 ]), 100 )
70- leakage = PDTP (classifier , extra_classifier , x_train , y_train , indexes = indexes , num_iter = 1 )
71- logger .info ("Average PDTP leakage: %.2f" , (np .average (leakage )))
72- logger .info ("Max PDTP leakage: %.2f" , (np .max (leakage )))
73- assert np .all (leakage >= 1.0 )
74- assert leakage .shape [0 ] == len (indexes )
74+ avg_leakage , worse_leakage = PDTP (classifier , extra_classifier , x_train , y_train , indexes = indexes , num_iter = 1 )
75+ logger .info ("Average PDTP leakage: %.2f" , (np .average (avg_leakage )))
76+ logger .info ("Max PDTP leakage: %.2f" , (np .max (avg_leakage )))
77+ assert np .all (avg_leakage >= 1.0 )
78+ assert np .all (worse_leakage >= avg_leakage )
79+ assert avg_leakage .shape [0 ] == x_train .shape [0 ]
80+ assert worse_leakage .shape [0 ] == x_train .shape [0 ]
7581 except ARTTestException as e :
7682 art_warning (e )
7783
0 commit comments