Skip to content

Commit 6a6b99d

Browse files
committed
Fixed docstrings
1 parent ba22d92 commit 6a6b99d

File tree

3 files changed

+41
-34
lines changed

3 files changed

+41
-34
lines changed

examples/ensemble_attack_example/real_data_collection.py

Lines changed: 20 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -30,11 +30,11 @@ def collect_midst_attack_data(
3030
Collect the real data in a specific setting of the provided MIDST challenge resources.
3131
3232
Args:
33-
attack_type (str): The attack setting.
34-
data_dir (Path): The path where the data is stored.
35-
data_split (str): Indicates if this is train, dev, or final data.
36-
dataset (str): The dataset to be collected. Either "train" or "challenge".
37-
data_config (dict): Configuration dictionary containing data paths and file names.
33+
attack_type: The attack setting.
34+
data_dir: The path where the data is stored.
35+
data_split: Indicates if this is train, dev, or final data.
36+
dataset: The dataset to be collected. Either "train" or "challenge".
37+
data_processing_config: Configuration dictionary containing data specific information.
3838
3939
Returns:
4040
pd.DataFrame: The specified dataset in this setting.
@@ -77,21 +77,22 @@ def collect_midst_data(
7777
attack_types: list[str],
7878
data_splits: list[str],
7979
dataset: str,
80-
data_config: DictConfig,
80+
data_processing_config: DictConfig,
8181
) -> pd.DataFrame:
8282
"""
8383
Collect train or challenge data of the specified attack type from the provided data folders
8484
in the MIDST competition.
8585
8686
Args:
87-
attack_types (list[str]): List of attack names to be collected.
88-
data_splits (list[str]): A list indicating the data split to be collected.
87+
midst_data_input_dir: The path where the MIDST data folders are stored.
88+
attack_types: List of attack names for data collection.
89+
data_splits: A list indicating the data split to be collected.
8990
Could be any of train, dev, or final data splits.
90-
dataset (str): The dataset to be collected. Either "train" or "challenge".
91-
data_config (dict): Configuration dictionary containing data paths and file names.
91+
dataset: The dataset to be collected. Either "train" or "challenge".
92+
data_processing_config: Configuration dictionary containing data paths and file names.
9293
9394
Returns:
94-
pd.DataFrame: Collected train or challenge data as a DataFrame.
95+
Collected train or challenge data as a dataframe.
9596
"""
9697
assert dataset in [
9798
"train",
@@ -105,7 +106,7 @@ def collect_midst_data(
105106
data_dir=midst_data_input_dir,
106107
data_split=data_split,
107108
dataset=dataset,
108-
data_processing_config=data_config,
109+
data_processing_config=data_processing_config,
109110
)
110111

111112
population.append(df_real)
@@ -119,19 +120,19 @@ def collect_population_data_ensemble(
119120
save_dir: Path,
120121
) -> pd.DataFrame:
121122
"""
122-
Collect the population data from the MIDST competition based on ensemble mia implementation.
123+
Collect the population data from the MIDST competition based on Ensemble Attack implementation.
123124
Returns real data population that consists of the train data of all the attacks
124125
(black box and white box), and challenge points from train, dev and final of
125126
"tabddpm_black_box" attack. The population data is saved in the provided path,
126127
and returned as a dataframe.
127128
128129
Args:
129-
data_config (dict): Configuration dictionary containing data paths and file names.
130-
attack_types (list[str] | None): List of attack names to be collected.
131-
If None, all the attacks are collected based on ensemble mia implementation.
130+
midst_data_input_dir: The path where the MIDST data folders are stored.
131+
data_processing_config: Configuration dictionary containing data information and file names.
132+
save_dir: The path where the collected population data should be saved.
132133
133134
Returns:
134-
pd.DataFrame: The collected population data.
135+
The collected population data as a dataframe.
135136
"""
136137

137138
# Ensemble Attack collects train data of all the attack types (back box and white box)
@@ -141,7 +142,7 @@ def collect_population_data_ensemble(
141142
attack_types,
142143
data_splits=["train"],
143144
dataset="train",
144-
data_config=data_processing_config,
145+
data_processing_config=data_processing_config,
145146
)
146147
# Drop ids.
147148
df_population_no_id = df_population.drop(columns=["trans_id", "account_id"])
@@ -156,7 +157,7 @@ def collect_population_data_ensemble(
156157
attack_types=challenge_attack_types,
157158
data_splits=["train", "dev", "final"],
158159
dataset="challenge",
159-
data_config=data_processing_config,
160+
data_cdata_processing_configonfig=data_processing_config,
160161
)
161162
# Save the challenge points
162163
save_dataframe(df_challenge, save_dir, "challenge_points_all.csv")

src/midst_toolkit/attacks/ensemble/process_split_data.py

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -20,13 +20,12 @@ def split_real_data(
2020
"""Splits a real dataset into train, validation, and test sets, saves them as CSV files, and returns the splits.
2121
2222
Args:
23-
df_real (pd.DataFrame): The input real dataset to be split.
24-
column_to_stratify (str, optional): Column name to use for stratified splitting. Defaults to None.
25-
proportion (dict, optional): Proportions for train and validation splits.
26-
random_seed (int, optional): Random seed for reproducibility. Defaults to None.
27-
23+
df_real: The input real dataset to be split.
24+
column_to_stratify: Column name to use for stratified splitting.
25+
proportion: Proportions for train and validation splits.
26+
random_seed: Random seed for reproducibility.
2827
Returns:
29-
Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]: A tuple containing the train, validation, and test DataFrames.
28+
A tuple containing the train, validation, and test dataframes.
3029
"""
3130
if proportion is None:
3231
proportion = {"train": 0.5, "val": 0.25}
@@ -67,14 +66,14 @@ def generate_val_test(
6766
The resulting validation and test sets are used for meta classifier training and evaluation, respectively.
6867
6968
Args:
70-
df_real_train (pd.DataFrame): Real training data.
71-
df_real_control_val (pd.DataFrame): Real control data for validation.
72-
df_real_control_test (pd.DataFrame): Real control data for final evaluation.
73-
stratify (pd.Series): Series used to stratify the real training data.
74-
random_seed (int): Random seed for reproducibility.
69+
df_real_train: Real training data.
70+
df_real_control_val: Real control data for validation.
71+
df_real_control_test: Real control data for final evaluation.
72+
stratify: Series used to stratify the real training data.
73+
random_seed: Random seed for reproducibility.
7574
7675
Returns:
77-
Tuple[pd.DataFrame, np.ndarray, pd.DataFrame, np.ndarray]: Features and labels for validation and test sets.
76+
Features and labels for validation and test sets, respectively.
7877
"""
7978
df_real_train["stratify"] = stratify
8079

@@ -141,6 +140,13 @@ def process_split_data(
141140
) -> None:
142141
"""
143142
Splits the data into train, validation, and test sets according to the attack design.
143+
144+
Args:
145+
all_population_data: The total population data that the attacker has access to as a DataFrame.
146+
processed_attack_data_path: Path where the processed attack data will be saved.
147+
column_to_stratify: Column name to use for stratified splitting.
148+
num_total_samples: Number os samples that I randomly selected from the population. Defaults to 40000.
149+
random_seed: Random seed used for reproducibility. Defaults to 42.
144150
"""
145151

146152
# Original Ensemble attack samples 40k data points to construct

src/midst_toolkit/attacks/ensemble/utils.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,11 +29,11 @@ def load_dataframe(file_path: Path, file_name: str) -> pd.DataFrame:
2929
Load a DataFrame from a CSV file.
3030
3131
Args:
32-
file_path (str): Path where the file is stored.
33-
file_name (str): Name of the file to load the DataFrame from.
32+
file_path: Path where the file is stored.
33+
file_name: Name of the file to load the DataFrame from.
3434
3535
Returns:
36-
pd.DataFrame: Loaded DataFrame.
36+
pd.DataFrame: Loaded dataframe.
3737
"""
3838
full_path = file_path / file_name
3939
assert Path.exists(full_path), f"File {full_path} does not exist."

0 commit comments

Comments
 (0)