Commits (52; changes shown from 39 commits)
27fa9c8
Fixed 2 bugs: shadow synth data size, and var name
fatemetkl Nov 12, 2025
f64b650
Remove dependency on target’s training result object; attack now only…
fatemetkl Nov 18, 2025
0d4a3b6
Added test script that works with the trained tabddpm models in diffe…
fatemetkl Nov 18, 2025
9c1217d
Updated test
fatemetkl Nov 18, 2025
0a12084
pre-commit checks
fatemetkl Nov 18, 2025
81c0bb3
Merged main into this branch, addressed conflicts
fatemetkl Nov 18, 2025
a19a595
Minor fixes
fatemetkl Nov 18, 2025
e93b3c1
Minor fixes
fatemetkl Nov 18, 2025
70cb1af
Removed extra line
fatemetkl Nov 18, 2025
942475d
Sara's comments
fatemetkl Nov 19, 2025
de1cb95
Addressed Marcelo's comments
fatemetkl Nov 20, 2025
819826c
Merged remote
fatemetkl Nov 20, 2025
59e52d4
Ensemble experiments: SLURM script (#97)
fatemetkl Nov 20, 2025
cb7e65c
Small fix
fatemetkl Nov 20, 2025
1b76b08
Added the success calculation script
fatemetkl Nov 21, 2025
3e301f2
Finalized the script
fatemetkl Nov 21, 2025
9545f4a
David Comments first pass
fatemetkl Nov 25, 2025
9c750ae
Removed extra comment
fatemetkl Nov 25, 2025
e79f9cc
Just saving experiment configs for my own reference
fatemetkl Nov 25, 2025
ecf22ab
Latest changes
fatemetkl Dec 2, 2025
22c028c
Fixed population dataset for other settings
fatemetkl Dec 3, 2025
19b67e8
Fix test shadow train data
fatemetkl Dec 4, 2025
4977d1e
Avoid saving all TrainingResult to reduce memory footprint
fatemetkl Dec 8, 2025
8bd62ee
Updated test
fatemetkl Dec 16, 2025
bcea5d0
Cleaned the code
fatemetkl Dec 17, 2025
f691a8b
Added experiment scripts
fatemetkl Dec 17, 2025
4956fd4
Merge branch 'main' into ft/ensemble_changes and resolved conflicts
fatemetkl Dec 17, 2025
4596949
Removed extra experiment scripts
fatemetkl Dec 17, 2025
44b4164
Fixed typing issues
fatemetkl Dec 23, 2025
eeef708
Fixed unit tests
fatemetkl Jan 7, 2026
f8962f8
Fixed integration tests as shadow model training now only saves synth…
fatemetkl Jan 7, 2026
37323c6
Fixed mypy errors
fatemetkl Jan 7, 2026
ee4b659
mypy and ruff fixes
fatemetkl Jan 8, 2026
be7e494
Merge branch 'main' into ft/ensemble_changes
fatemetkl Jan 8, 2026
eea6ea5
Small change
fatemetkl Jan 8, 2026
d598206
extra line
fatemetkl Jan 8, 2026
d14b22f
Merged main into the branch
fatemetkl Jan 8, 2026
161c71b
Merge remote-tracking branch 'origin/main' into ft/ensemble_changes
fatemetkl Jan 8, 2026
9afda5d
End of file extra line
fatemetkl Jan 8, 2026
72c0e90
Improved documentation a bit
fatemetkl Jan 9, 2026
34d81ba
Added more documentations
fatemetkl Jan 12, 2026
01572d3
Cleaning
fatemetkl Jan 12, 2026
951abfb
Merge remote-tracking branch 'origin/main' into ft/ensemble_changes
fatemetkl Jan 12, 2026
720ddcf
coderabbitai comments
fatemetkl Jan 12, 2026
b87bd57
Merge branch 'main' into ft/ensemble_changes
fatemetkl Jan 13, 2026
7a3810a
Merge branch 'main' into ft/ensemble_changes
emersodb Jan 21, 2026
90acab1
Merge branch 'main' into ft/ensemble_changes
emersodb Jan 26, 2026
d64b908
addressed David's comments
fatemetkl Jan 27, 2026
c283d7a
Merge branch 'ft/ensemble_changes' of https://github.com/VectorInstit…
fatemetkl Jan 27, 2026
00f1e39
minor bug
fatemetkl Jan 27, 2026
fd2d622
Merge remote-tracking branch 'origin/main' into ft/ensemble_changes
fatemetkl Jan 27, 2026
00e98a0
Final comments and cleaning
fatemetkl Jan 27, 2026
43 changes: 23 additions & 20 deletions examples/ensemble_attack/configs/experiment_config.yaml
@@ -1,34 +1,35 @@
# Ensemble experiment configuration
@fatemetkl (Collaborator, Author), Jan 9, 2026:

The current single config is hard to understand because it mixes many variables and data paths with unclear names inherited from the original attack code. Splitting it into multiple pipeline-specific configs would improve clarity and maintainability, even if it adds some overhead. Alternatively, improving variable naming within one config could be helpful.

Collaborator:

Perhaps (if you haven't already) you could create a clickup ticket with this in there as a next step?

Collaborator:

I think doing both (splitting and better naming) would be a worthwhile endeavor.
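
A minimal sketch of what the split could look like, assuming hypothetical per-pipeline files (base.yaml, data_processing.yaml, shadow_training.yaml, metaclassifier.yaml) merged with OmegaConf; none of these file names exist in this PR:

from omegaconf import OmegaConf

# Hypothetical per-pipeline configs; the file names are illustrative only.
base = OmegaConf.load("configs/base.yaml")               # shared paths and seed
data = OmegaConf.load("configs/data_processing.yaml")    # data_processing_config keys
shadow = OmegaConf.load("configs/shadow_training.yaml")  # shadow_training keys
meta = OmegaConf.load("configs/metaclassifier.yaml")     # metaclassifier keys

# Later arguments win on conflicting keys, so pipeline configs override base.
config = OmegaConf.merge(base, data, shadow, meta)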

# This config can be used to run both the Ensemble attack training phase (``run_attack.py``) and the testing phase (``test_attack_model.py``).
base_experiment_dir: examples/ensemble_attack/tabddpm_20k_experiment_data # Processed data, and experiment artifacts will be stored here
base_data_config_dir: examples/ensemble_attack/data_configs # Training and data type configs are saved under this directory
base_experiment_dir: /projects/midst-experiments/ensemble_attack/tabddpm_10k_experiment_data/10k/ # Processed data, and experiment artifacts will be stored under this directory.
base_data_config_dir: examples/ensemble_attack/data_configs # Training and data type configs are saved under this directory.

# Pipeline control
# Training Pipeline Control
pipeline:
run_data_processing: true # Set this to false if you have already saved the processed data
run_shadow_model_training: true # Set this to false if shadow models are already trained and saved
run_metaclassifier_training: true

target_model: # This is only used for testing the attack on a real target model.
# This is for models trained on 20k data and generating 20k synthetic data
target_model_directory: /projects/midst-experiments/all_tabddpms/tabddpm_trained_with_20k/train/
target_model_directory: /projects/midst-experiments/all_tabddpms/tabddpm_trained_with_10k/test/
target_model_id: 21 # Will be overridden per SLURM array task
target_model_name: tabddpm_${target_model.target_model_id}
target_synthetic_data_path: ${target_model.target_model_directory}/${target_model.target_model_name}/synthetic_data/20k/20k.csv
target_synthetic_data_path: ${target_model.target_model_directory}/${target_model.target_model_name}/synthetic_data/10k/10k.csv
challenge_data_path: ${target_model.target_model_directory}/${target_model.target_model_name}/challenge_with_id.csv
challenge_label_path: ${target_model.target_model_directory}/${target_model.target_model_name}/challenge_label.csv

target_attack_artifact_dir: ${base_experiment_dir}/target_${target_model.target_model_id}_attack_artifacts/
@fatemetkl (Collaborator, Author):

This directory was extra and can be removed.

attack_probabilities_result_path: ${target_model.target_attack_artifact_dir}/attack_model_${target_model.target_model_id}_proba
target_shadow_models_output_path: ${target_model.target_attack_artifact_dir}/tabddpm_${target_model.target_model_id}_shadows_dir
target_shadow_models_output_path: ${base_experiment_dir}/test_all_targets # Sub-directory to store test shadows and results
attack_probabilities_result_path: ${target_model.target_shadow_models_output_path}/test_probabilities/attack_model_${target_model.target_model_id}_proba
attack_rmia_shadow_training_data_choice: "combined" # Options: "combined", "only_challenge", "only_train". This determines which data to use for training RMIA attack model in testing phase.
@fatemetkl (Collaborator, Author):

This is a new config variable. You can read more about the options in select_challenge_data_for_training()'s docstring.

Collaborator:

Maybe include something like See select_challenge_data_for_training()'s docstring in the comment to direct a user to this as well?
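
A hedged sketch of how the three options might behave; the authoritative semantics live in select_challenge_data_for_training()'s real docstring, and this signature is an assumption:

import pandas as pd

def select_challenge_data_for_training(
    choice: str, challenge_df: pd.DataFrame, train_df: pd.DataFrame
) -> pd.DataFrame:
    # Illustrative only: the toolkit's docstring is the source of truth.
    if choice == "combined":
        # Union of challenge and train points, deduplicated.
        return pd.concat([challenge_df, train_df]).drop_duplicates()
    if choice == "only_challenge":
        return challenge_df
    if choice == "only_train":
        return train_df
    raise ValueError(f"Unknown attack_rmia_shadow_training_data_choice: {choice}")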



# Data paths
data_paths:
midst_data_path: /projects/midst-experiments/all_tabddpms # Used to collect the data
population_path: ${base_experiment_dir}/population_data # Path where the collected population data will be stored
processed_attack_data_path: ${base_experiment_dir}/attack_data # Path where the processed attack real train and evaluation data is stored
attack_evaluation_result_path: ${base_experiment_dir}/evaluation_results # Path where the attack evaluation results will be stored
midst_data_path: /projects/midst-experiments/all_tabddpms/ # Used to collect the data (input) as defined in data_processing_config
processed_base_data_dir: ${base_experiment_dir} # To save new processed data for training, or read from previously collected and processed data (testing phase).
population_path: ${data_paths.processed_base_data_dir}/population_data # Path where the collected population data will be stored (output/input)
processed_attack_data_path: ${data_paths.processed_base_data_dir}/attack_data # Path where the processed attack real train and evaluation data is stored (output/input)
attack_evaluation_result_path: ${base_experiment_dir}/evaluation_results # Path where the attack (train phase) evaluation results will be stored (output)


model_paths:
metaclassifier_model_path: ${base_experiment_dir}/trained_models # Path where the trained metaclassifier model will be saved
Expand All @@ -38,23 +39,25 @@ model_paths:
data_processing_config:
population_attack_data_types_to_collect:
[
"tabddpm_trained_with_20k",
"tabddpm_trained_with_10k",
]
challenge_attack_data_types_to_collect:
[
"tabddpm_trained_with_20k",
"tabddpm_trained_with_10k",
]
population_splits: ["train"] # Data splits to be collected for population data
challenge_splits: ["train"] # Data splits to be collected for challenge points
challenge_splits: ["train" , "test"] # Data splits to be collected for challenge points
original_population_data_path: /projects/midst-experiments/ensemble_attack/competition/population_data/ #Attack's collected population for DOMIAS
Collaborator:

Super minor, but space after the #

Collaborator:

Is this where the population will be written, read from, or both?

@fatemetkl (Collaborator, Author), Jan 19, 2026:

Added to the comment that it is where the population will be read from.

# The column name in the data to be used for stratified splitting.
column_to_stratify: "trans_type" # Attention: This value is not documented in the original codebase.
folder_ranges: #Specify folder ranges for any of the mentioned splits.
train: [[1, 20]] # Folders to be used for train data collection in the experiments
train: [[1, 21]] # Folders to be used for train data collection in the experiments
test: [[21, 31] , [31, 41]]
# File names in MIDST data directories.
single_table_train_data_file_name: "train_with_id.csv"
multi_table_train_data_file_name: "trans.csv"
challenge_data_file_name: "challenge_with_id.csv"
population_sample_size: 40000 # Population size is the total data that your attack has access to.
population_sample_size: 20000 # Population size is the total data that your attack has access to.
# In experiments, this is sampled out of all the collected training data in case the available data
# is more than this number. Note that, half of this data is actually used for training, the other half
# is used for evaluation. For example, with 40k population size, only 20k is used for training the attack model.
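
To make the half-and-half comment concrete, a sketch of the implied sampling and split, assuming a stratified 50/50 split on column_to_stratify; the real logic lives in process_split_data and may differ:

import pandas as pd
from sklearn.model_selection import train_test_split

population = pd.read_csv("population_all_no_id.csv")  # illustrative path
# Cap at population_sample_size when more data was collected than needed.
sample = population.sample(n=min(20000, len(population)), random_state=42)
# Half trains the attack model, half is held out for evaluation.
attack_train, attack_eval = train_test_split(
    sample, test_size=0.5, stratify=sample["trans_type"], random_state=42
)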
@@ -86,7 +89,7 @@ shadow_training:
fine_tune_diffusion_iterations: 200000 # Original code: 200000
fine_tune_classifier_iterations: 20000 # Original code: 20000
pre_train_data_size: 60000 # Original code: 60000
number_of_points_to_synthesize: 20000 # Number of synthetic data samples to be generated by shadow models.
number_of_points_to_synthesize: 10000 # Number of synthetic data samples to be generated by shadow models.
# Original code: 20000


@@ -104,7 +107,7 @@ metaclassifier:
meta_classifier_model_name: ${metaclassifier.model_type}_metaclassifier_model

attack_success_computation:
target_ids_to_test: [21,22,23] # List of target model IDs to compute the attack success for.
target_ids_to_test: [21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40] # List of target model IDs to compute the attack success for.

# General settings
random_seed: 42 # Set to null for no seed, or an integer for a fixed seed
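
Since target_model_id is "overridden per SLURM array task" and downstream values such as target_model_name and target_synthetic_data_path are OmegaConf interpolations, a sketch of how the SLURM script might inject the task id; the environment-variable handling here is an assumption, not code from this PR:

import os
from omegaconf import OmegaConf

cfg = OmegaConf.load("examples/ensemble_attack/configs/experiment_config.yaml")
task_id = os.environ.get("SLURM_ARRAY_TASK_ID")
if task_id is not None:
    cfg.target_model.target_model_id = int(task_id)
# Interpolations resolve lazily on access, so paths built from
# ${target_model.target_model_id} automatically pick up the override.
print(cfg.target_model.target_synthetic_data_path)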
42 changes: 35 additions & 7 deletions examples/ensemble_attack/real_data_collection.py
@@ -4,12 +4,14 @@
"""

from enum import Enum
from logging import INFO
from pathlib import Path

import pandas as pd
from omegaconf import DictConfig

from midst_toolkit.attacks.ensemble.data_utils import load_dataframe, save_dataframe
from midst_toolkit.common.logger import log


class AttackType(Enum):
@@ -66,11 +68,11 @@ def collect_midst_attack_data(
Returns:
pd.DataFrame: The specified dataset in this setting.
"""
assert data_split in [
"train",
"dev",
"final",
], "data_split should be one of 'train', 'dev', or 'final'."
# assert data_split in [
# "train",
# "dev",
# "final",
# ], "data_split should be one of 'train', 'dev', or 'final'."
# `data_id` is the folder numbering of each training or challenge dataset,
# and is defined with the provided config.
data_id = expand_ranges(data_processing_config.folder_ranges[data_split])
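
expand_ranges itself is not shown in this diff. Given the half-open ranges in the config (train [[1, 21]] covering folders 1-20, test [[21, 31], [31, 41]] covering folders 21-40, which matches target_ids_to_test), a plausible sketch under that assumption:

def expand_ranges(ranges: list[list[int]]) -> list[int]:
    # Assumed half-open [start, end) semantics, inferred from the config values.
    ids: list[int] = []
    for start, end in ranges:
        ids.extend(range(start, end))
    return ids

assert expand_ranges([[21, 31], [31, 41]]) == list(range(21, 41))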
@@ -133,7 +135,7 @@ def collect_midst_data(
data_processing_config=data_processing_config,
)

population.append(df_real)
population.append(df_real)
@fatemetkl (Collaborator, Author):

This was a bug! Thank you for catching this, Sara!
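
The two diff lines above differ only in indentation, which this rendering flattens. A minimal illustration of the kind of bug being fixed, assuming the append had fallen outside the collection loop (collect() is a stand-in for collect_midst_attack_data):

import pandas as pd

def collect(split: str) -> pd.DataFrame:
    # Stand-in that returns a one-row frame per split.
    return pd.DataFrame({"split": [split]})

# Buggy: the append sits outside the loop, so only the last frame survives.
frames = []
for split in ["train", "test"]:
    df_real = collect(split)
frames.append(df_real)
assert len(pd.concat(frames)) == 1

# Fixed: appending inside the loop keeps every collected frame.
frames = []
for split in ["train", "test"]:
    df_real = collect(split)
    frames.append(df_real)
assert len(pd.concat(frames)) == 2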


return pd.concat(population).drop_duplicates()

@@ -142,6 +144,7 @@ def collect_population_data_ensemble(
midst_data_input_dir: Path,
data_processing_config: DictConfig,
save_dir: Path,
original_repo_population: pd.DataFrame,
population_splits: list[str] | None = None,
challenge_splits: list[str] | None = None,
) -> pd.DataFrame:
@@ -156,6 +159,13 @@
midst_data_input_dir: The path where the MIDST data folders are stored.
data_processing_config: Configuration dictionary containing data information and file names.
save_dir: The path where the collected population data should be saved.
original_repo_population: The original population data collected from the MIDST challenge repository.
Collaborator:

I may be missing something, but aren't we collecting data in this function? That is, what do we mean by original here?

Collaborator:

Just to clarify: is this from the MIDST challenge repository or the ensemble attack repository?

@fatemetkl (Collaborator, Author), Jan 15, 2026:

This points to the 800k dataset that was collected by the original attack implementation based on the available folders in the MIDST challenge repository.
So, perhaps, a better comment would be: "The original population data collected by the original attack implementation from the MIDST challenge repository."
Does this look better?

Collaborator:

Yup :)

population_splits: A list indicating the data splits to be collected for population data.
Could be any of `train`, `dev`, or `final` data splits. If None, the default list of ``["train"]``
is set in the function based on the original attack implementation.
challenge_splits: A list indicating the data splits to be collected for challenge points.
Could be any of `train`, `dev`, or `final` data splits. If None, the default list of
``["train", "dev", "final"]`` is set in the function based on the original attack implementation.
@@ -169,6 +179,15 @@
# Population data will be saved under ``save_dir``.
save_dir.mkdir(parents=True, exist_ok=True)

if population_splits is None:
population_splits = ["train"]
if challenge_splits is None:
# Original Ensemble collects all the challenge points from train, dev and final of "tabddpm_black_box" attack.
challenge_splits = ["train", "dev", "final"]

@@ -180,19 +199,27 @@
# Provided attack name are valid based on AttackType enum
population_attack_types: list[AttackType] = [AttackType(attack_name) for attack_name in attack_names]

df_population = collect_midst_data(
df_population_experiment = collect_midst_data(
midst_data_input_dir,
population_attack_types,
data_splits=population_splits,
dataset="train",
data_processing_config=data_processing_config,
)

log(INFO, f"Collected experiment population data length before concatenation: {len(df_population_experiment)}")

df_population = pd.concat([df_population_experiment, original_repo_population]).drop_duplicates()
log(INFO, f"Concatenated population data length: {len(df_population)}")

# Drop ids.
df_population_no_id = df_population.drop(columns=["trans_id", "account_id"])
# Save the population data
save_dataframe(df_population, save_dir, "population_all.csv")
save_dataframe(df_population_no_id, save_dir, "population_all_no_id.csv")

challenge_attack_names = data_processing_config.challenge_attack_data_types_to_collect
challenge_attack_types = [AttackType(attack_name) for attack_name in challenge_attack_names]
df_challenge = collect_midst_data(
@@ -202,6 +229,7 @@
dataset="challenge",
data_processing_config=data_processing_config,
)
log(INFO, f"Collected challenge data length: {len(df_challenge)} from splits: {challenge_splits}")
# Save the challenge points
save_dataframe(df_challenge, save_dir, "challenge_points_all.csv")

15 changes: 14 additions & 1 deletion examples/ensemble_attack/run_attack.py
@@ -11,6 +11,7 @@
from omegaconf import DictConfig

from examples.ensemble_attack.real_data_collection import collect_population_data_ensemble
from midst_toolkit.attacks.ensemble.data_utils import load_dataframe
from midst_toolkit.attacks.ensemble.process_split_data import process_split_data
from midst_toolkit.common.logger import log
from midst_toolkit.common.random import set_all_random_seeds
@@ -23,15 +24,23 @@ def run_data_processing(config: DictConfig) -> None:
Args:
config: Configuration object set in config.yaml.
"""
# Load original repo's population

original_population_data = load_dataframe(
Path(config.data_processing_config.original_population_data_path),
"population_all_with_challenge.csv",
)
log(INFO, "Running data processing pipeline...")
# Collect the real data from the MIDST challenge resources.
population_data = collect_population_data_ensemble(
midst_data_input_dir=Path(config.data_paths.midst_data_path),
data_processing_config=config.data_processing_config,
save_dir=Path(config.data_paths.population_path),
original_repo_population=original_population_data,
population_splits=config.data_processing_config.population_splits,
challenge_splits=config.data_processing_config.challenge_splits,
)

# The following function saves the required dataframe splits in the specified processed_attack_data_path path.
process_split_data(
all_population_data=population_data,
@@ -67,7 +76,11 @@ def main(config: DictConfig) -> None:
# TODO: Investigate the source of error.
Collaborator:

Is this TODO still a TODO?

@fatemetkl (Collaborator, Author):

No longer! That error seems to be fixed now.

if config.pipeline.run_shadow_model_training:
shadow_pipeline = importlib.import_module("examples.ensemble_attack.run_shadow_model_training")
shadow_data_paths = shadow_pipeline.run_shadow_model_training(config)
df_master_challenge_train = load_dataframe(
Path(config.data_paths.processed_attack_data_path),
"master_challenge_train.csv",
)
shadow_data_paths = shadow_pipeline.run_shadow_model_training(config, df_master_challenge_train)
shadow_data_paths = [Path(path) for path in shadow_data_paths]

target_model_synthetic_path = shadow_pipeline.run_target_model_training(config)
11 changes: 11 additions & 0 deletions examples/ensemble_attack/run_metaclassifier_training.py
@@ -25,6 +25,8 @@ def run_metaclassifier_training(
The list should contain three paths, one for each set of shadow models.
target_model_synthetic_path: Path to the target model's synthetic data. This is all we need from a target
model to train the metaclassifier in the black-box setting.
"""
log(INFO, "Running metaclassifier training...")

@@ -63,6 +65,7 @@
with open(model_path, "rb") as f:
shadow_data_and_result = pickle.load(f)
shadow_data_collection.append(shadow_data_and_result)
log(INFO, f"Shadow model data loaded from {model_path}.")

assert Path(target_model_synthetic_path).exists(), (
f"No file found at {target_model_synthetic_path}. "
@@ -71,6 +74,10 @@

# Load the target model's synthetic data
target_synthetic_data = pd.read_csv(target_model_synthetic_path)
log(
INFO,
f"Target model's synthetic data loaded from {target_model_synthetic_path} with size {len(target_synthetic_data)}.",
)

assert target_synthetic_data is not None, "Target model's synthetic data is missing."
target_synthetic_data = target_synthetic_data.copy()
@@ -79,6 +86,10 @@
Path(config.data_paths.population_path),
"population_all_with_challenge_no_id.csv",
)
log(
INFO,
f"Reference population data loaded from f{config.data_paths.population_path} with size {len(df_reference)}.",
)

# Extract trans_id from both train and test dataframes
assert "trans_id" in df_meta_train.columns, "Meta train data must have trans_id column"
Expand Down
14 changes: 6 additions & 8 deletions examples/ensemble_attack/run_shadow_model_training.py
@@ -2,6 +2,7 @@
from logging import INFO
from pathlib import Path

import pandas as pd
from omegaconf import DictConfig

from midst_toolkit.attacks.ensemble.data_utils import load_dataframe
@@ -79,12 +80,13 @@ def run_target_model_training(config: DictConfig) -> Path:
return target_model_synthetic_path


def run_shadow_model_training(config: DictConfig) -> list[Path]:
def run_shadow_model_training(config: DictConfig, df_challenge_train: pd.DataFrame) -> list[Path]:
"""
Function to run the shadow model training for RMIA attack.

Args:
config: Configuration object set in config.yaml.
df_challenge_train: DataFrame containing the data that is used to train RMIA shadow models.

Returns:
Paths to the saved shadow model results for the three sets of shadow models. For more details,
@@ -95,27 +97,23 @@
# Load the required dataframes for shadow model training.
# For shadow model training we need master_challenge_train and population data.
# Master challenge is the main training (or fine-tuning) data for the shadow models.
df_master_challenge_train = load_dataframe(
@fatemetkl (Collaborator, Author):

Instead of loading the data here, it is passed to the function.

Path(config.data_paths.processed_attack_data_path),
"master_challenge_train.csv",
)
# Population data is used to pre-train some of the shadow models.
df_population_with_challenge = load_dataframe(
Path(config.data_paths.population_path),
"population_all_with_challenge.csv",
)
# Make sure master challenge train and population data have the "trans_id" column.
assert "trans_id" in df_master_challenge_train.columns, (
assert "trans_id" in df_challenge_train.columns, (
"trans_id column should be present in master train data for the shadow model pipeline."
)
assert "trans_id" in df_population_with_challenge.columns
assert "trans_id" in df_master_challenge_train.columns
assert "trans_id" in df_challenge_train.columns
# ``population_data`` in ensemble attack is used for shadow pre-training, and
# ``master_challenge_df`` is used for fine-tuning for half of the shadow models.
# For the other half of the shadow models, only ``master_challenge_df`` is used for training.
first_set_result_path, second_set_result_path, third_set_result_path = train_three_sets_of_shadow_models(
population_data=df_population_with_challenge,
master_challenge_data=df_master_challenge_train,
master_challenge_data=df_challenge_train,
shadow_models_output_path=Path(config.shadow_training.shadow_models_output_path),
training_json_config_paths=config.shadow_training.training_json_config_paths,
fine_tuning_config=config.shadow_training.fine_tuning_config,