Skip to content

Commit 4f963b3

Browse files
committed
Small fixes
1 parent 7c4d47c commit 4f963b3

File tree

8 files changed

+36
-43
lines changed

8 files changed

+36
-43
lines changed

src/midst_toolkit/attacks/black_box_single_table/ensemble_mia/config.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,14 @@
11
from pathlib import Path
22

3+
34
BASE_DATA_DIR = Path("midst_toolkit/attacks/black_box_single_table/ensemble_mia/data")
45

56
DATA_CONFIG = {
67
# Data processing paths and file names
78
## Input directories:
8-
"midst_data_path": (
9-
BASE_DATA_DIR / "midst_data_all_attacks"
10-
), # Used only for reading the data
9+
"midst_data_path": (BASE_DATA_DIR / "midst_data_all_attacks"), # Used only for reading the data
1110
## Output directories:
12-
"population_path": BASE_DATA_DIR
13-
/ "population_data", # Path where the population data is stored
11+
"population_path": BASE_DATA_DIR / "population_data", # Path where the population data is stored
1412
"processed_attack_data_path": (
1513
BASE_DATA_DIR / "attack_data"
1614
), # Path where the processed attack real train and evaluation data is stored

src/midst_toolkit/attacks/black_box_single_table/ensemble_mia/data_processing/data_utils.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
from logging import INFO
22
from pathlib import Path
3-
from midst_toolkit.common.logger import log
3+
44
import pandas as pd
55

6+
from midst_toolkit.common.logger import log
7+
68

79
def save_dataframe(df: pd.DataFrame, file_path: Path, file_name: str) -> None:
810
"""
@@ -47,6 +49,7 @@ def collect_midst_attack_data(
4749
data_dir (Path): The path where the data is stored.
4850
data_split (str): Indicates if this is train, dev, or final data.
4951
dataset (str): The dataset to be collected. Either "train" or "challenge".
52+
data_config (dict): Configuration dictionary containing data paths and file names.
5053
5154
Returns:
5255
pd.DataFrame: The specified dataset in this setting.
@@ -67,16 +70,14 @@ def collect_midst_attack_data(
6770
# Multi-table attacks have different file names.
6871
file_name = (
6972
data_config["multi_table_train_data_file_name"]
70-
if "clavaddpm" == generation_name
73+
if generation_name == "clavaddpm"
7174
else data_config["single_table_train_data_file_name"]
7275
)
7376
assert file_name.split(".")[-1] == "csv", "File name should end with .csv."
7477

7578
df_real = pd.DataFrame()
7679
for i in data_id:
77-
data_dir_ith = (
78-
data_dir / attack_type / data_split / f"{generation_name}_{i}" / file_name
79-
)
80+
data_dir_ith = data_dir / attack_type / data_split / f"{generation_name}_{i}" / file_name
8081
df_real_ith = pd.read_csv(data_dir_ith)
8182
df_real = df_real_ith if i == 1 else pd.concat([df_real, df_real_ith])
8283

src/midst_toolkit/attacks/black_box_single_table/ensemble_mia/data_processing/process_split_data.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,9 @@
11
from logging import INFO
2-
from pathlib import Path
32

43
import numpy as np
54
import pandas as pd
65
from sklearn.model_selection import train_test_split
76

8-
from midst_toolkit.common.logger import log
97
from midst_toolkit.attacks.black_box_single_table.ensemble_mia.config import (
108
seed,
119
)
@@ -15,6 +13,7 @@
1513
from midst_toolkit.attacks.black_box_single_table.ensemble_mia.data_processing.real_data_collection import (
1614
collect_population_data_ensemble_mia,
1715
)
16+
from midst_toolkit.common.logger import log
1817

1918

2019
def split_real_data(
@@ -68,7 +67,7 @@ def generate_val_test(
6867
seed: int,
6968
) -> tuple[pd.DataFrame, np.ndarray, pd.DataFrame, np.ndarray]:
7069
"""
71-
Generates the validation and test sets with labels.
70+
Generates the validation and test sets with labels.
7271
The resulting validation and test sets are used for meta classifier training and evaluation, respectively.
7372
7473
Args:
@@ -208,4 +207,5 @@ def process_split_data(
208207
from midst_toolkit.attacks.black_box_single_table.ensemble_mia.config import (
209208
DATA_CONFIG,
210209
)
210+
211211
process_split_data(data_config=DATA_CONFIG)

src/midst_toolkit/attacks/black_box_single_table/ensemble_mia/data_processing/real_data_collection.py

Lines changed: 13 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,10 @@ def collect_midst_data(
2121
2222
Args:
2323
attack_types (list[str]): List of attack names to be collected.
24-
data_splits (list[str]): A list indicating the data split to be collected.
24+
data_splits (list[str]): A list indicating the data split to be collected.
2525
Could be any of train, dev, or final data splits.
2626
dataset (str): The dataset to be collected. Either "train" or "challenge".
27+
data_config (dict): Configuration dictionary containing data paths and file names.
2728
2829
Returns:
2930
pd.DataFrame: Collected train or challenge data as a DataFrame.
@@ -60,10 +61,10 @@ def collect_population_data_ensemble_mia(
6061
and returned as a dataframe.
6162
6263
Args:
63-
data_processing_config (dict): Configuration dictionary containing data paths and file names.
64+
data_config (dict): Configuration dictionary containing data paths and file names.
6465
attack_types (list[str] | None): List of attack names to be collected.
6566
If None, all the attacks are collected based on ensemble mia implementation.
66-
67+
6768
Returns:
6869
pd.DataFrame: The collected population data.
6970
"""
@@ -72,17 +73,15 @@ def collect_population_data_ensemble_mia(
7273
# Collect train data of all the attacks (black box and white box)
7374
if attack_types is None:
7475
attack_types = [
75-
"tabddpm_black_box",
76-
"tabsyn_black_box",
77-
"tabddpm_white_box",
78-
"tabsyn_white_box",
79-
"clavaddpm_black_box",
80-
"clavaddpm_white_box",
81-
]
82-
83-
df_population = collect_midst_data(
84-
attack_types, data_splits=["train"], dataset="train", data_config=data_config
85-
)
76+
"tabddpm_black_box",
77+
"tabsyn_black_box",
78+
"tabddpm_white_box",
79+
"tabsyn_white_box",
80+
"clavaddpm_black_box",
81+
"clavaddpm_white_box",
82+
]
83+
84+
df_population = collect_midst_data(attack_types, data_splits=["train"], dataset="train", data_config=data_config)
8685
# Drop ids.
8786
df_population_no_id = df_population.drop(columns=["trans_id", "account_id"])
8887
# Save the population data
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
{"trans_date": {"size": 2191, "type": "continuous"}, "trans_type": {"size": 3, "type": "discrete"}, "operation": {"size": 6, "type": "discrete"}, "amount": {"size": 40400, "type": "continuous"}, "balance": {"size": 542739, "type": "continuous"}, "k_symbol": {"size": 9, "type": "discrete"}, "bank": {"size": 14, "type": "discrete"}, "account": {"size": 7665, "type": "continuous"}}
1+
{"trans_date": {"size": 2191, "type": "continuous"}, "trans_type": {"size": 3, "type": "discrete"}, "operation": {"size": 6, "type": "discrete"}, "amount": {"size": 40400, "type": "continuous"}, "balance": {"size": 542739, "type": "continuous"}, "k_symbol": {"size": 9, "type": "discrete"}, "bank": {"size": 14, "type": "discrete"}, "account": {"size": 7665, "type": "continuous"}}

tests/unit/attacks/ensemble_mia/config.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,14 @@
11
from pathlib import Path
22

3+
34
BASE_DATA_DIR = Path("tests/unit/attacks/ensemble_mia/assets")
45

56
DATA_CONFIG = {
67
# Data processing paths and file names
78
## Input directories:
89
"midst_data_path": BASE_DATA_DIR / "midst_data_all_attacks", # Used only for reading the data
910
## Output directories:
10-
"population_path": BASE_DATA_DIR
11-
/ "population_data", # Path where the population data is stored
11+
"population_path": BASE_DATA_DIR / "population_data", # Path where the population data is stored
1212
"processed_attack_data_path": (
1313
BASE_DATA_DIR / "attack_data"
1414
), # Path where the processed attack real train and evaluation data is stored

tests/unit/attacks/ensemble_mia/test_data_collection.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
from pathlib import Path
2+
23
from src.midst_toolkit.attacks.black_box_single_table.ensemble_mia.data_processing.real_data_collection import (
34
collect_midst_data,
45
collect_population_data_ensemble_mia,
56
)
67
from tests.unit.attacks.ensemble_mia.config import DATA_CONFIG
78

9+
810
def test_collect_population_data_ensemble_mia(tmp_path: Path) -> None:
911
# Comment the next line to update population data stored in DATA_CONFIG["population_path"].
1012
DATA_CONFIG["population_path"] = tmp_path
@@ -21,13 +23,9 @@ def test_collect_population_data_ensemble_mia(tmp_path: Path) -> None:
2123

2224
assert (DATA_CONFIG["population_path"] / "population_all_no_challenge.csv").exists()
2325

24-
assert (
25-
DATA_CONFIG["population_path"] / "population_all_with_challenge.csv"
26-
).exists()
26+
assert (DATA_CONFIG["population_path"] / "population_all_with_challenge.csv").exists()
2727

28-
assert (
29-
DATA_CONFIG["population_path"] / "population_all_with_challenge_no_id.csv"
30-
).exists()
28+
assert (DATA_CONFIG["population_path"] / "population_all_with_challenge_no_id.csv").exists()
3129

3230

3331
def test_collect_midst_data() -> None:

tests/unit/attacks/ensemble_mia/test_process_data_split.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
from pathlib import Path
2+
3+
from src.midst_toolkit.attacks.black_box_single_table.ensemble_mia.data_processing.data_utils import load_dataframe
24
from src.midst_toolkit.attacks.black_box_single_table.ensemble_mia.data_processing.process_split_data import (
35
process_split_data,
46
)
57
from tests.unit.attacks.ensemble_mia.config import DATA_CONFIG
68

7-
from src.midst_toolkit.attacks.black_box_single_table.ensemble_mia.data_processing.data_utils import load_dataframe
89

910
def test_process_split_data(tmp_path: Path) -> None:
1011
# Comment the next line to update processed attack data stored in DATA_CONFIG["processed_attack_data_path"].
@@ -38,15 +39,11 @@ def test_process_split_data(tmp_path: Path) -> None:
3839
# Recall that `master_challenge_train` consists of two halves: one half (10k) from `real_val` data
3940
# with their "is_train" column set to 0, and the other half (10k) from the real train data (`real_train`)
4041
# with their "is_train" column set to 1. Note that ["is_train"] column is dropped in the final dataframes.
41-
master_challenge_train = load_dataframe(
42-
DATA_CONFIG["processed_attack_data_path"], "master_challenge_train.csv"
43-
)
42+
master_challenge_train = load_dataframe(DATA_CONFIG["processed_attack_data_path"], "master_challenge_train.csv")
4443
assert master_challenge_train.shape == (20000, 10), f" Shape is {master_challenge_train.shape}"
4544

4645
# Recall that `master_challenge_test` consists of two halves: one half (10k) from `real_test` data
4746
# with their "is_train" column set to 0, and the other half (10k) from the real train data (`real_train`)
4847
# with their "is_train" column set to 1. Note that ["is_train"] column is dropped in the final dataframes.
49-
master_challenge_test = load_dataframe(
50-
DATA_CONFIG["processed_attack_data_path"], "master_challenge_test.csv"
51-
)
48+
master_challenge_test = load_dataframe(DATA_CONFIG["processed_attack_data_path"], "master_challenge_test.csv")
5249
assert master_challenge_test.shape == (20000, 10), f" Shape is {master_challenge_test.shape}"

0 commit comments

Comments
 (0)