Skip to content

Commit be2ac7d

Browse files
committed
Improved comments and function name
1 parent 5516989 commit be2ac7d

File tree

1 file changed

+18
-9
lines changed

1 file changed

+18
-9
lines changed

src/midst_toolkit/attacks/ensemble/process_split_data.py

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ def split_real_data(
6464
)
6565

6666

67-
def generate_val_test(
67+
def generate_train_test_challenge_splits(
6868
df_real_train: pd.DataFrame,
6969
df_real_control_val: pd.DataFrame,
7070
df_real_control_test: pd.DataFrame,
@@ -114,7 +114,10 @@ def generate_val_test(
114114
df_val = df_val.drop(columns=["is_train"])
115115

116116
# Test set
117-
# df_temp will be assigned as our test set if it has the same size as df_real_control_test.
117+
# `df_temp` will be assigned as our test set if it has the same size as `df_real_control_test`;
118+
# otherwise, we further split `df_temp` to get a test set of the same size as `df_real_control_test`.
119+
# This is because we want to take a train split of the same size as `df_real_control_test` to ensure
120+
# balanced classes in the final test set.
118121
if len(df_temp) == len(df_real_control_test):
119122
df_real_train_test = df_temp
120123
else:
@@ -161,22 +164,28 @@ def process_split_data(
161164
"""
162165
# Original Ensemble attack samples 40k data points to construct
163166
# 1) the main population (real data) used for training the synthetic data generator model,
164-
# 2) evaluation that is the meta train data used to train the meta classifier,
165-
# 3) test to evaluate the meta classifier.
167+
# 2) the evaluation set, i.e. the meta train data (membership classification train dataset) used to train
168+
# the meta classifier,
169+
# 3) the test set (membership classification test dataset) used to evaluate the meta classifier.
166170

167171
df_real_data = all_population_data.sample(n=num_total_samples, random_state=random_seed)
168172

169-
# Split the data. df_real_train is used for training the synthetic data generator model.
173+
# `df_real_train` is used for training the synthetic data generator model.
170174
df_real_train, df_real_val, df_real_test = split_real_data(
171175
df_real_data,
172176
column_to_stratify=column_to_stratify,
173177
random_seed=random_seed,
174178
)
175-
# Generate validation and test sets with labels. Validation is used for training the meta classifier
176-
# and test is used for meta classifier evaluation.
177-
# Half of the df_real_train will be assigned to validation and the other half to test with
179+
# Generate challenge datasets:
180+
# `df_val` is used for training the meta classifier (membership classification train dataset),
181+
# and `df_test` is used for meta classifier evaluation (membership classification test dataset).
182+
# A part of the `df_real_train` will be assigned to `df_val` and another part to `df_test` with
178183
# their "is_train" column set to 1 meaning that these samples are in the model's training corpus.
179-
df_val, y_val, df_test, y_test = generate_val_test(
184+
# Because `df_real_train` will be used to train a synthetic model, we're including some of it in
185+
# `df_val` and `df_test` sets to create the challenges assuming the `df_real_val` and `df_real_test`
186+
# data will not be part of the training data.
187+
# This code makes sure `is_train` classes are balanced in the challenge datasets.
188+
df_val, y_val, df_test, y_test = generate_train_test_challenge_splits(
180189
df_real_train,
181190
df_real_val,
182191
df_real_test,

0 commit comments

Comments
 (0)