@@ -1,15 +1,15 @@
 from logging import INFO
 from pathlib import Path
+
 import numpy as np
 import pandas as pd
 from sklearn.model_selection import train_test_split
 
-from midst_toolkit.attacks.ensemble.utils import (
+from midst_toolkit.common.logger import log
+from src.midst_toolkit.attacks.ensemble.utils import (
     save_dataframe,
 )
 
-from midst_toolkit.common.logger import log
-
 
 def split_real_data(
     df_real: pd.DataFrame,
@@ -24,6 +24,7 @@ def split_real_data(
         column_to_stratify: Column name to use for stratified splitting.
         proportion: Proportions for train and validation splits.
         random_seed: Random seed for reproducibility.
+
     Returns:
         A tuple containing the train, validation, and test dataframes.
     """
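As a usage sketch based on the docstring above: the `proportion` keys follow the `test_size` expression later in this diff, while the import path, the split fractions, and the toy DataFrame are assumptions for illustration only.

```python
import pandas as pd

# Module path assumed for illustration; adjust to wherever split_real_data lives.
from src.midst_toolkit.attacks.ensemble.data_processing import split_real_data

# Placeholder population table with a binary column to stratify on.
df_real = pd.DataFrame({"feature": range(100), "label": [i % 2 for i in range(100)]})

df_train, df_val, df_test = split_real_data(
    df_real,
    column_to_stratify="label",
    proportion={"train": 0.5, "val": 0.25},  # the remaining 0.25 becomes the test split
    random_seed=42,
)
```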
@@ -41,8 +42,7 @@ def split_real_data(
     # Further split the control into val and test set:
     df_real_val, df_real_test = train_test_split(
         df_real_control,
-        test_size=(1 - proportion["train"] - proportion["val"])
-        / (1 - proportion["train"]),
+        test_size=(1 - proportion["train"] - proportion["val"]) / (1 - proportion["train"]),
         random_state=random_seed,
         stratify=df_real_control[column_to_stratify],
     )
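The joined `test_size` expression is a rescaling step: `train_test_split` interprets `test_size` relative to the control subset it is given, not the full dataset, so the intended overall test fraction `1 - train - val` has to be divided by the control fraction `1 - train`. A minimal sketch of that arithmetic, with illustrative proportions rather than values from the repository:

```python
# Illustrative proportions only; the real defaults are not visible in this diff.
proportion = {"train": 0.5, "val": 0.25}

control_fraction = 1 - proportion["train"]                            # 0.50 of the data remains after the train split
overall_test_fraction = 1 - proportion["train"] - proportion["val"]   # 0.25 of the full data should end up as test

# Rescale so that train_test_split, applied to the control subset, yields the intended overall fraction.
test_size = overall_test_fraction / control_fraction                  # 0.25 / 0.50 = 0.5
print(test_size)
```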
@@ -121,9 +121,7 @@ def generate_val_test(
         ignore_index=True,
     )
 
-    df_test = df_test.sample(frac=1, random_state=random_seed).reset_index(
-        drop=True
-    )
+    df_test = df_test.sample(frac=1, random_state=random_seed).reset_index(drop=True)
 
     y_test = df_test["is_train"].values
     df_test = df_test.drop(columns=["is_train"])
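The hunk above uses the common pandas idiom of shuffling with `sample(frac=1)` and then peeling off the label column. A toy illustration follows; only `is_train` is taken from the diff (it presumably marks rows drawn from the generator's training split), and the other column and values are made up.

```python
import pandas as pd

# Made-up member / non-member rows; is_train is the membership label.
members = pd.DataFrame({"feature": [1.0, 2.0], "is_train": [1, 1]})
non_members = pd.DataFrame({"feature": [3.0, 4.0], "is_train": [0, 0]})

df_test = pd.concat([members, non_members], ignore_index=True)

# sample(frac=1) shuffles every row; reset_index(drop=True) discards the permuted index.
df_test = df_test.sample(frac=1, random_state=42).reset_index(drop=True)

y_test = df_test["is_train"].values           # label vector as a NumPy array
df_test = df_test.drop(columns=["is_train"])  # feature table without the label
```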
@@ -148,20 +146,17 @@ def process_split_data(
         num_total_samples: Number of samples randomly selected from the population. Defaults to 40000.
         random_seed: Random seed used for reproducibility. Defaults to 42.
     """
-
     # Original Ensemble attack samples 40k data points to construct
     # 1) the main population (real data) used for training the synthetic data generator model,
     # 2) evaluation that is the meta train data used to train the meta classifier,
     # 3) test to evaluate the meta classifier.
 
-    df_real_data = all_population_data.sample(
-        n=num_total_samples, random_state=random_seed
-    )
+    df_real_data = all_population_data.sample(n=num_total_samples, random_state=random_seed)
 
     # Split the data. df_real_train is used for training the synthetic data generator model.
     df_real_train, df_real_val, df_real_test = split_real_data(
         df_real_data,
-        column_to_stratify=column_to_stratify,
+        column_to_stratify=column_to_stratify,
         random_seed=random_seed,
     )
     # Generate validation and test sets with labels. Validation is used for training the meta classifier
@@ -172,9 +167,7 @@ def process_split_data(
         df_real_train,
         df_real_val,
         df_real_test,
-        stratify=df_real_train[
-            column_to_stratify
-        ],  # TODO: This value is not documented in the original codebase.
+        stratify=df_real_train[column_to_stratify],  # TODO: This value is not documented in the original codebase.
         random_seed=random_seed,
     )
 
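Taken together, `process_split_data` appears to be the entry point that samples the population, splits it, and builds the labelled evaluation sets. A hedged usage sketch follows; the import path and the stand-in DataFrame are assumptions, and further parameters (e.g. output paths for `save_dataframe`) may exist but are not visible in this diff.

```python
import pandas as pd

# Module path assumed for illustration; adjust to wherever process_split_data lives.
from src.midst_toolkit.attacks.ensemble.data_processing import process_split_data

# Synthetic stand-in for the full population table.
all_population_data = pd.DataFrame(
    {"feature": range(50_000), "label": [i % 2 for i in range(50_000)]}
)

# Keyword names follow the docstring and hunks above.
process_split_data(
    all_population_data,
    column_to_stratify="label",
    num_total_samples=40_000,
    random_seed=42,
)
```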