Skip to content

Commit 8e764e1

Browse files
committed
fix
1 parent 6a6b99d commit 8e764e1

File tree

10 files changed

+49
-38
lines changed

10 files changed

+49
-38
lines changed

.gitignore

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,4 +24,4 @@ wheels/
2424
**/workspace/*.bkp
2525

2626
# Dataset files
27-
examples/**/data/
27+
examples/**/data/

examples/ensemble_attack_example/README.md

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,5 +32,3 @@ TODO: include the illustrations
3232

3333
## Terminology
3434
To be added....
35-
36-

examples/ensemble_attack_example/config.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,12 +40,12 @@ data_processing_config:
4040
trans_json_file_path: ${base_example_dir}/data_configs/trans.json
4141
population_sample_size: 40000
4242

43-
# Training settings (placeholder)
43+
# Training settings (placeholder)
4444
shadow_training:
4545
epochs: 10
4646
learning_rate: 0.001
4747
batch_size: 64
4848
model_type: "tabddpm"
4949

5050
# General settings
51-
random_seed: 42
51+
random_seed: 42

examples/ensemble_attack_example/real_data_collection.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,9 @@
22
This data collection script is tailored to the structure of the provided folders in
33
MIDST competition.
44
"""
5+
56
from pathlib import Path
7+
68
import pandas as pd
79
from omegaconf import DictConfig
810

@@ -12,6 +14,15 @@
1214

1315

1416
def expand_ranges(ranges):
17+
"""
18+
Reads a list of tuples representing ranges and expands them into a flat list of integers.
19+
20+
Args:
21+
ranges: List of tuples, where each tuple contains two integers (start, end).
22+
23+
Returns:
24+
A flat list of integers covering the ranges.
25+
"""
1526
expanded = []
1627
for r in ranges:
1728
start, end = r
@@ -62,9 +73,7 @@ def collect_midst_attack_data(
6273

6374
df_real = pd.DataFrame()
6475
for i in data_id:
65-
data_dir_ith = (
66-
data_dir / attack_type / data_split / f"{generation_name}_{i}" / file_name
67-
)
76+
data_dir_ith = data_dir / attack_type / data_split / f"{generation_name}_{i}" / file_name
6877
df_real_ith = pd.read_csv(data_dir_ith)
6978
df_real = df_real_ith if i == 1 else pd.concat([df_real, df_real_ith])
7079

@@ -134,7 +143,6 @@ def collect_population_data_ensemble(
134143
Returns:
135144
The collected population data as a dataframe.
136145
"""
137-
138146
# Ensemble Attack collects train data of all the attack types (back box and white box)
139147
attack_types = data_processing_config.collect_attack_data_types
140148
df_population = collect_midst_data(

examples/ensemble_attack_example/run_attack.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,28 @@
1-
""" This file is an uncompleted example script for running the ensemble attack on MIDST challenge provided resources and data. """
1+
"""
2+
This file is an uncompleted example script for running the Ensemble Attack on MIDST challenge
3+
provided resources and data.
4+
"""
5+
26
from logging import INFO
7+
from pathlib import Path
8+
39
import hydra
410
from omegaconf import DictConfig
5-
from pathlib import Path
11+
12+
from examples.ensemble_attack_example.real_data_collection import collect_population_data_ensemble
613
from src.midst_toolkit.attacks.ensemble.process_split_data import process_split_data
714
from src.midst_toolkit.common.logger import log
8-
from examples.ensemble_attack_example.real_data_collection import collect_population_data_ensemble
915

1016

1117
@hydra.main(config_path=".", config_name="config", version_base=None)
1218
def main(cfg: DictConfig):
19+
"""
20+
Run the Ensemble Attack example pipeline.
21+
As the first step, data processing is done.
1322
23+
Args:
24+
cfg: Attack OmegaConf DictConfig object.
25+
"""
1426
if cfg.pipeline.run_data_processing:
1527
log(INFO, "Running data processing pipeline...")
1628
# Collect the real data from the MIDST challenge resources.

src/midst_toolkit/attacks/ensemble/process_split_data.py

Lines changed: 9 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,15 @@
11
from logging import INFO
22
from pathlib import Path
3+
34
import numpy as np
45
import pandas as pd
56
from sklearn.model_selection import train_test_split
67

7-
from midst_toolkit.attacks.ensemble.utils import (
8+
from midst_toolkit.common.logger import log
9+
from src.midst_toolkit.attacks.ensemble.utils import (
810
save_dataframe,
911
)
1012

11-
from midst_toolkit.common.logger import log
12-
1313

1414
def split_real_data(
1515
df_real: pd.DataFrame,
@@ -24,6 +24,7 @@ def split_real_data(
2424
column_to_stratify: Column name to use for stratified splitting.
2525
proportion: Proportions for train and validation splits.
2626
random_seed: Random seed for reproducibility.
27+
2728
Returns:
2829
A tuple containing the train, validation, and test dataframes.
2930
"""
@@ -41,8 +42,7 @@ def split_real_data(
4142
# Further split the control into val and test set:
4243
df_real_val, df_real_test = train_test_split(
4344
df_real_control,
44-
test_size=(1 - proportion["train"] - proportion["val"])
45-
/ (1 - proportion["train"]),
45+
test_size=(1 - proportion["train"] - proportion["val"]) / (1 - proportion["train"]),
4646
random_state=random_seed,
4747
stratify=df_real_control[column_to_stratify],
4848
)
@@ -121,9 +121,7 @@ def generate_val_test(
121121
ignore_index=True,
122122
)
123123

124-
df_test = df_test.sample(frac=1, random_state=random_seed).reset_index(
125-
drop=True
126-
)
124+
df_test = df_test.sample(frac=1, random_state=random_seed).reset_index(drop=True)
127125

128126
y_test = df_test["is_train"].values
129127
df_test = df_test.drop(columns=["is_train"])
@@ -148,20 +146,17 @@ def process_split_data(
148146
num_total_samples: Number os samples that I randomly selected from the population. Defaults to 40000.
149147
random_seed: Random seed used for reproducibility. Defaults to 42.
150148
"""
151-
152149
# Original Ensemble attack samples 40k data points to construct
153150
# 1) the main population (real data) used for training the synthetic data generator model,
154151
# 2) evaluation that is the meta train data used to train the meta classifier,
155152
# 3) test to evaluate the meta classifier.
156153

157-
df_real_data = all_population_data.sample(
158-
n=num_total_samples, random_state=random_seed
159-
)
154+
df_real_data = all_population_data.sample(n=num_total_samples, random_state=random_seed)
160155

161156
# Split the data. df_real_train is used for training the synthetic data generator model.
162157
df_real_train, df_real_val, df_real_test = split_real_data(
163158
df_real_data,
164-
column_to_stratify=column_to_stratify,
159+
column_to_stratify=column_to_stratify,
165160
random_seed=random_seed,
166161
)
167162
# Generate validation and test sets with labels. Validation is used for training the meta classifier
@@ -172,9 +167,7 @@ def process_split_data(
172167
df_real_train,
173168
df_real_val,
174169
df_real_test,
175-
stratify=df_real_train[
176-
column_to_stratify
177-
], # TODO: This value is not documented in the original codebase.
170+
stratify=df_real_train[column_to_stratify], # TODO: This value is not documented in the original codebase.
178171
random_seed=random_seed,
179172
)
180173

src/midst_toolkit/attacks/ensemble/utils.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
import pandas as pd
55

6-
from midst_toolkit.common.logger import log
6+
from src.midst_toolkit.common.logger import log
77

88

99
def save_dataframe(df: pd.DataFrame, file_path: Path, file_name: str) -> None:
@@ -14,11 +14,10 @@ def save_dataframe(df: pd.DataFrame, file_path: Path, file_name: str) -> None:
1414
df: DataFrame to be saved.
1515
file_path: Path where the file will be saved.
1616
file_name: Name of the file to save the DataFrame as.
17-
17+
1818
Returns:
1919
None
2020
"""
21-
2221
assert Path.exists(file_path), f"Path {file_path} does not exist."
2322
df.to_csv(file_path / file_name, index=False)
2423
log(INFO, f"DataFrame saved to {file_path / file_name}")
@@ -33,10 +32,10 @@ def load_dataframe(file_path: Path, file_name: str) -> pd.DataFrame:
3332
file_name: Name of the file to load the DataFrame from.
3433
3534
Returns:
36-
pd.DataFrame: Loaded dataframe.
35+
Loaded dataframe.
3736
"""
3837
full_path = file_path / file_name
3938
assert Path.exists(full_path), f"File {full_path} does not exist."
4039
df = pd.read_csv(full_path)
4140
log(INFO, f"DataFrame loaded from {full_path}")
42-
return df
41+
return df

tests/unit/attacks/ensemble/assets/population_data/all_population.csv

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,4 +97,4 @@ trans_id,account_id,trans_date,trans_type,operation,amount,balance,k_symbol,bank
9797
97776,334,752,2,4,1680.0,28091.8,1,0,0
9898
1177053,4035,1135,2,4,6300.0,38763.4,1,0,0
9999
980869,3347,1923,2,1,2624.0,10827.7,5,10,24763751
100-
712969,2436,1144,2,4,2040.0,42428.1,1,0,0
100+
712969,2436,1144,2,4,2040.0,42428.1,1,0,0

tests/unit/attacks/ensemble/test_config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,4 +15,4 @@ data_processing_config:
1515
population_sample_size: 80
1616

1717
# General settings
18-
random_seed: 42
18+
random_seed: 42

tests/unit/attacks/ensemble/test_process_data_split.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
from pathlib import Path
2-
from omegaconf import DictConfig
2+
33
import pytest
4-
from hydra import initialize, compose
4+
from hydra import compose, initialize
55
from omegaconf import DictConfig
6-
from src.midst_toolkit.attacks.ensemble.utils import load_dataframe
6+
77
from src.midst_toolkit.attacks.ensemble.process_split_data import process_split_data
8+
from src.midst_toolkit.attacks.ensemble.utils import load_dataframe
89

910

1011
@pytest.fixture(scope="session")

0 commit comments

Comments
 (0)