@@ -30,11 +30,11 @@ def collect_midst_attack_data(
3030 Collect the real data in a specific setting of the provided MIDST challenge resources.
3131
3232 Args:
33- attack_type (str) : The attack setting.
34- data_dir (Path) : The path where the data is stored.
35- data_split (str) : Indicates if this is train, dev, or final data.
36- dataset (str) : The dataset to be collected. Either "train" or "challenge".
37- data_config (dict) : Configuration dictionary containing data paths and file names .
33+ attack_type: The attack setting.
34+ data_dir: The path where the data is stored.
35+ data_split: Indicates if this is train, dev, or final data.
36+ dataset: The dataset to be collected. Either "train" or "challenge".
37+ data_processing_config : Configuration dictionary containing data specific information .
3838
3939 Returns:
4040 pd.DataFrame: The specified dataset in this setting.
@@ -77,21 +77,22 @@ def collect_midst_data(
7777 attack_types : list [str ],
7878 data_splits : list [str ],
7979 dataset : str ,
80- data_config : DictConfig ,
80+ data_processing_config : DictConfig ,
8181) -> pd .DataFrame :
8282 """
8383 Collect train or challenge data of the specified attack type from the provided data folders
8484 in the MIDST competition.
8585
8686 Args:
87- attack_types (list[str]): List of attack names to be collected.
88- data_splits (list[str]): A list indicating the data split to be collected.
87+ midst_data_input_dir: The path where the MIDST data folders are stored.
88+ attack_types: List of attack names for data collection.
89+ data_splits: A list indicating the data split to be collected.
8990 Could be any of train, dev, or final data splits.
90- dataset (str) : The dataset to be collected. Either "train" or "challenge".
91- data_config (dict) : Configuration dictionary containing data paths and file names.
91+ dataset: The dataset to be collected. Either "train" or "challenge".
92+ data_processing_config : Configuration dictionary containing data paths and file names.
9293
9394 Returns:
94- pd.DataFrame: Collected train or challenge data as a DataFrame .
95+ Collected train or challenge data as a dataframe .
9596 """
9697 assert dataset in [
9798 "train" ,
@@ -105,7 +106,7 @@ def collect_midst_data(
105106 data_dir = midst_data_input_dir ,
106107 data_split = data_split ,
107108 dataset = dataset ,
108- data_processing_config = data_config ,
109+ data_processing_config = data_processing_config ,
109110 )
110111
111112 population .append (df_real )
@@ -119,19 +120,19 @@ def collect_population_data_ensemble(
119120 save_dir : Path ,
120121) -> pd .DataFrame :
121122 """
122- Collect the population data from the MIDST competition based on ensemble mia implementation.
123+ Collect the population data from the MIDST competition based on Ensemble Attack implementation.
123124 Returns real data population that consists of the train data of all the attacks
124125 (black box and white box), and challenge points from train, dev and final of
125126 "tabddpm_black_box" attack. The population data is saved in the provided path,
126127 and returned as a dataframe.
127128
128129 Args:
129- data_config (dict): Configuration dictionary containing data paths and file names .
130- attack_types (list[str] | None): List of attack names to be collected .
131- If None, all the attacks are collected based on ensemble mia implementation .
130+ midst_data_input_dir: The path where the MIDST data folders are stored .
131+ data_processing_config: Configuration dictionary containing data information and file names .
132+ save_dir: The path where the collected population data should be saved .
132133
133134 Returns:
134- pd.DataFrame: The collected population data.
135+ The collected population data as a dataframe .
135136 """
136137
137138 # Ensemble Attack collects train data of all the attack types (back box and white box)
@@ -141,7 +142,7 @@ def collect_population_data_ensemble(
141142 attack_types ,
142143 data_splits = ["train" ],
143144 dataset = "train" ,
144- data_config = data_processing_config ,
145+ data_processing_config = data_processing_config ,
145146 )
146147 # Drop ids.
147148 df_population_no_id = df_population .drop (columns = ["trans_id" , "account_id" ])
@@ -156,7 +157,7 @@ def collect_population_data_ensemble(
156157 attack_types = challenge_attack_types ,
157158 data_splits = ["train" , "dev" , "final" ],
158159 dataset = "challenge" ,
159- data_config = data_processing_config ,
160+ data_cdata_processing_configonfig = data_processing_config ,
160161 )
161162 # Save the challenge points
162163 save_dataframe (df_challenge , save_dir , "challenge_points_all.csv" )
0 commit comments