@@ -64,7 +64,7 @@ def split_real_data(
6464 )
6565
6666
67- def generate_val_test (
67+ def generate_train_test_challenge_splits (
6868 df_real_train : pd .DataFrame ,
6969 df_real_control_val : pd .DataFrame ,
7070 df_real_control_test : pd .DataFrame ,
@@ -114,7 +114,10 @@ def generate_val_test(
114114 df_val = df_val .drop (columns = ["is_train" ])
115115
116116 # Test set
117- # df_temp will be assigned as our test set if it has the same size as df_real_control_test.
117+ # `df_temp` will be assigned as our test set if it has the same size as `df_real_control_test`,
118+ # otherwise, we further split `df_temp` to get a test set of the same size as `df_real_control_test`.
119+ # This is because we want to take a train split of same size as `df_real_control_test` to ensure
120+ # balanced classes in the final test set.
118121 if len (df_temp ) == len (df_real_control_test ):
119122 df_real_train_test = df_temp
120123 else :
@@ -161,22 +164,28 @@ def process_split_data(
161164 """
162165 # Original Ensemble attack samples 40k data points to construct
163166 # 1) the main population (real data) used for training the synthetic data generator model,
164- # 2) evaluation that is the meta train data used to train the meta classifier,
165- # 3) test to evaluate the meta classifier.
167+ # 2) evaluation that is the meta train data (membership classification train dataset) used to train
168+ # the meta classifier,
169+ # 3) test (membership classification test dataset) to evaluate the meta classifier.
166170
167171 df_real_data = all_population_data .sample (n = num_total_samples , random_state = random_seed )
168172
169- # Split the data. df_real_train is used for training the synthetic data generator model.
173+ # ` df_real_train` is used for training the synthetic data generator model.
170174 df_real_train , df_real_val , df_real_test = split_real_data (
171175 df_real_data ,
172176 column_to_stratify = column_to_stratify ,
173177 random_seed = random_seed ,
174178 )
175- # Generate validation and test sets with labels. Validation is used for training the meta classifier
176- # and test is used for meta classifier evaluation.
177- # Half of the df_real_train will be assigned to validation and the other half to test with
179+ # Generate challenge datasets:
180+ # `df_val` is used for training the meta classifier (membership classification train dataset).
181+ # and `df_test` is used for meta classifier evaluation (membership classification test dataset).
182+ # A part of the `df_real_train` will be assigned to `df_val` and a another part to `df_test` with
178183 # their "is_train" column set to 1 meaning that these samples are in the models training corpus.
179- df_val , y_val , df_test , y_test = generate_val_test (
184+ # Because `df_real_train` will be used to train a synthetic model, we're including some of it in
185+ # `df_val` and `df_test` sets to create the challenges assuming the `df_real_val` and `df_real_test`
186+ # data will not be part of the training data.
187+ # This code makes sure `is_train` classes are balanced in the challenge datasets.
188+ df_val , y_val , df_test , y_test = generate_train_test_challenge_splits (
180189 df_real_train ,
181190 df_real_val ,
182191 df_real_test ,
0 commit comments