 import hydra
 from omegaconf import DictConfig
 
+from examples.common.utils import iterate_model_folders
 from midst_toolkit.attacks.ensemble.data_utils import load_dataframe, save_dataframe
 from midst_toolkit.attacks.ept.feature_extraction import extract_features
 from midst_toolkit.common.logger import log
@@ -32,7 +33,6 @@ def run_attribute_prediction(config: DictConfig) -> None:
     log(INFO, "Running attribute prediction model training.")
 
     diffusion_model_names = ["tabddpm", "tabsyn"] if config.attack_settings.single_table else ["clavaddpm"]
-    modes = ["train", "dev", "final"]
     input_data_path = Path(config.data_paths.input_data_path)
     output_features_path = Path(config.data_paths.output_data_path, "attribute_prediction_features")
 
@@ -48,41 +48,33 @@ def run_attribute_prediction(config: DictConfig) -> None:
 
     # TODO: Package iterating over competition structure (maybe into a utility function)
     # Iterating over directories specific to the shadow models folder structure in the competition
-    for model_name in diffusion_model_names:
-        model_path = input_data_path / f"{model_name}_black_box"
-        for mode in modes:
-            current_path = model_path / mode
+    for model_name, model_data_path, model_folder in iterate_model_folders(input_data_path, diffusion_model_names):
+        # Load the data files as dataframes
+        df_synthetic_data = load_dataframe(model_data_path, "trans_synthetic.csv")
+        df_challenge_data = load_dataframe(model_data_path, "challenge_with_id.csv")
 
-            model_folders = [entry.name for entry in current_path.iterdir() if entry.is_dir()]
-            for model_folder in model_folders:
-                # Load the data files as dataframes
-                model_data_path = current_path / model_folder
+        # Keep only the columns that are present in feature_column_types
+        columns_to_keep = feature_column_types["numerical"] + feature_column_types["categorical"]
+        df_synthetic_data = df_synthetic_data[columns_to_keep]
+        df_challenge_data = df_challenge_data[columns_to_keep]
 
-                df_synthetic_data = load_dataframe(model_data_path, "trans_synthetic.csv")
-                df_challenge_data = load_dataframe(model_data_path, "challenge_with_id.csv")
+        # Run feature extraction
+        df_extracted_features = extract_features(
+            synthetic_data=df_synthetic_data,
+            challenge_data=df_challenge_data,
+            column_types=feature_column_types,
+            random_seed=config.random_seed,
+        )
 
-                # Keep only the columns that are present in feature_column_types
-                columns_to_keep = feature_column_types["numerical"] + feature_column_types["categorical"]
-                df_synthetic_data = df_synthetic_data[columns_to_keep]
-                df_challenge_data = df_challenge_data[columns_to_keep]
+        final_output_dir = output_features_path / f"{model_name}_black_box"
 
-                # Run feature extraction
-                df_extracted_features = extract_features(
-                    synthetic_data=df_synthetic_data,
-                    challenge_data=df_challenge_data,
-                    column_types=feature_column_types,
-                    random_seed=config.random_seed,
-                )
+        final_output_dir.mkdir(parents=True, exist_ok=True)
 
-                final_output_dir = output_features_path / f"{model_name}_black_box"
+        # Extract the number at the end of model_folder
+        model_folder_number = int(model_folder.split("_")[-1])
+        file_name = f"attribute_prediction_features_{model_folder_number}.csv"
 
-                final_output_dir.mkdir(parents=True, exist_ok=True)
-
-                # Extract the number at the end of model_folder
-                model_folder_number = int(model_folder.split("_")[-1])
-                file_name = f"attribute_prediction_features_{model_folder_number}.csv"
-
-                save_dataframe(df=df_extracted_features, file_path=final_output_dir, file_name=file_name)
+        save_dataframe(df=df_extracted_features, file_path=final_output_dir, file_name=file_name)
 
 
 @hydra.main(config_path=".", config_name="config", version_base=None)
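
For context, here is a minimal sketch of what `examples.common.utils.iterate_model_folders` is assumed to do, reconstructed only from the nested loops this change removes; the actual helper in the repository may differ in signature and behavior.

```python
# Hypothetical reconstruction of iterate_model_folders, inferred from the loop
# logic removed in this diff; not the actual implementation.
from collections.abc import Iterator
from pathlib import Path


def iterate_model_folders(
    input_data_path: Path, diffusion_model_names: list[str]
) -> Iterator[tuple[str, Path, str]]:
    """Yield (model_name, model_data_path, model_folder) for each shadow model
    directory in the competition folder structure."""
    modes = ["train", "dev", "final"]
    for model_name in diffusion_model_names:
        model_path = input_data_path / f"{model_name}_black_box"
        for mode in modes:
            current_path = model_path / mode
            for entry in current_path.iterdir():
                if entry.is_dir():
                    # entry is a single model folder; its name ends in the index
                    # used above to build the output file name.
                    yield model_name, entry, entry.name
```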