Skip to content

Commit 4a0674b

Browse files
committed
Second set of test fix
1 parent ce534ee commit 4a0674b

File tree

2 files changed

+8
-3
lines changed

2 files changed

+8
-3
lines changed

examples/ept_attack/run_ept_attack.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from midst_toolkit.attacks.ensemble.data_utils import load_dataframe, save_dataframe
2222
from midst_toolkit.attacks.ept.classification import train_attack_classifier
2323
from midst_toolkit.attacks.ept.feature_extraction import extract_features
24+
from midst_toolkit.common.random import set_all_random_seeds
2425
from midst_toolkit.common.logger import log
2526

2627

@@ -255,6 +256,10 @@ def main(config: DictConfig) -> None:
255256
"""
256257
log(INFO, "Running EPT-MIA Attack Example Pipeline.")
257258

259+
if config.random_seed is not None:
260+
set_all_random_seeds(seed=config.random_seed)
261+
log(INFO, f"Training phase random seed set to {config.random_seed}.")
262+
258263
if config.attack_settings.single_table:
259264
log(INFO, "Data: Single-table.")
260265
else:

tests/unit/evaluation/quality/test_mean_f1_score_difference.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ def test_mean_f1_score_diff_with_preprocess() -> None:
5050
assert pytest.approx(-0.06789999999999999, abs=1e-8) == score["mean_f1_difference"]
5151
else:
5252
assert pytest.approx(0.7656, abs=1e-8) == score["random_forest_real_train_f1"]
53-
assert pytest.approx(-0.06829999999999997, abs=1e-8) == score["mean_f1_difference"]
53+
assert pytest.approx(-0.06759999999999997, abs=1e-8) == score["mean_f1_difference"]
5454
assert pytest.approx(0.5008, abs=1e-8) == score["random_forest_synthetic_train_f1"]
5555
assert pytest.approx(0.5, abs=1e-8) == score["adaboost_real_train_f1"]
5656
assert pytest.approx(0.49879999999999997, abs=1e-8) == score["adaboost_synthetic_train_f1"]
@@ -78,7 +78,7 @@ def test_mean_f1_score_diff_with_no_categorical() -> None:
7878
if is_apple_silicon():
7979
assert pytest.approx(-0.0792, abs=1e-8) == score["mean_f1_difference"]
8080
else:
81-
assert pytest.approx(-0.0793, abs=1e-8) == score["mean_f1_difference"]
81+
assert pytest.approx(-0.0794, abs=1e-8) == score["mean_f1_difference"]
8282
unset_all_random_seeds()
8383

8484

@@ -103,7 +103,7 @@ def test_mean_f1_score_diff_with_holdout_difference_f1() -> None:
103103
assert pytest.approx(-0.17912194553312424, abs=1e-8) == score["mean_f1_difference"]
104104
else:
105105
assert pytest.approx(0.7655658771879633, abs=1e-8) == score["random_forest_real_train_f1"]
106-
assert pytest.approx(-0.1795346169714074, abs=1e-8) == score["mean_f1_difference"]
106+
assert pytest.approx(-0.17883409886106752, abs=1e-8) == score["mean_f1_difference"]
107107
assert pytest.approx(0.40831722022666145, abs=1e-8) == score["random_forest_synthetic_train_f1"]
108108
assert pytest.approx(0.3632940727026944, abs=1e-8) == score["adaboost_real_train_f1"]
109109
assert pytest.approx(0.33490261584802905, abs=1e-8) == score["adaboost_synthetic_train_f1"]

0 commit comments

Comments
 (0)