Skip to content

Commit 1813d26

Browse files
committed
Packaged directory iterations, switched to the library's ENUM, resolved other comments.
1 parent 779c7c3 commit 1813d26

File tree

4 files changed

+59
-43
lines changed

4 files changed

+59
-43
lines changed

examples/common/utils.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
def iterate_model_folders(input_data_path: Path, diffusion_model_names: list[str]) -> Generator[tuple[str, Path, str]]:
    """
    Iterates over the competition's shadow model folder structure and yields model information.

    Args:
        input_data_path: The base path for the input data.
        diffusion_model_names: A list of diffusion model names to iterate over.

    Yields:
        A tuple containing the model name, the path to the model's data, and the model folder name.
    """
    for model_name in diffusion_model_names:
        # Each model's data lives under "<model_name>_black_box" in the competition layout.
        base_path = input_data_path / f"{model_name}_black_box"
        for split in ("train", "dev", "final"):
            split_path = base_path / split
            # Some splits may be absent for a given model; skip them silently.
            if not split_path.exists():
                continue
            # Only subdirectories are model folders; plain files are ignored.
            for entry in split_path.iterdir():
                if entry.is_dir():
                    yield model_name, entry, entry.name

examples/ept_attack/run_ept_attack.py

Lines changed: 22 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import hydra
1414
from omegaconf import DictConfig
1515

16+
from examples.common.utils import iterate_model_folders
1617
from midst_toolkit.attacks.ensemble.data_utils import load_dataframe, save_dataframe
1718
from midst_toolkit.attacks.ept.feature_extraction import extract_features
1819
from midst_toolkit.common.logger import log
@@ -32,7 +33,6 @@ def run_attribute_prediction(config: DictConfig) -> None:
3233
log(INFO, "Running attribute prediction model training.")
3334

3435
diffusion_model_names = ["tabddpm", "tabsyn"] if config.attack_settings.single_table else ["clavaddpm"]
35-
modes = ["train", "dev", "final"]
3636
input_data_path = Path(config.data_paths.input_data_path)
3737
output_features_path = Path(config.data_paths.output_data_path, "attribute_prediction_features")
3838

@@ -48,41 +48,33 @@ def run_attribute_prediction(config: DictConfig) -> None:
4848

4949
# TODO: Package iterating over competition structure (maybe into a utility function)
5050
# Iterating over directories specific to the shadow models folder structure in the competition
51-
for model_name in diffusion_model_names:
52-
model_path = input_data_path / f"{model_name}_black_box"
53-
for mode in modes:
54-
current_path = model_path / mode
51+
for model_name, model_data_path, model_folder in iterate_model_folders(input_data_path, diffusion_model_names):
52+
# Load the data files as dataframes
53+
df_synthetic_data = load_dataframe(model_data_path, "trans_synthetic.csv")
54+
df_challenge_data = load_dataframe(model_data_path, "challenge_with_id.csv")
5555

56-
model_folders = [entry.name for entry in current_path.iterdir() if entry.is_dir()]
57-
for model_folder in model_folders:
58-
# Load the data files as dataframes
59-
model_data_path = current_path / model_folder
56+
# Keep only the columns that are present in feature_column_types
57+
columns_to_keep = feature_column_types["numerical"] + feature_column_types["categorical"]
58+
df_synthetic_data = df_synthetic_data[columns_to_keep]
59+
df_challenge_data = df_challenge_data[columns_to_keep]
6060

61-
df_synthetic_data = load_dataframe(model_data_path, "trans_synthetic.csv")
62-
df_challenge_data = load_dataframe(model_data_path, "challenge_with_id.csv")
61+
# Run feature extraction
62+
df_extracted_features = extract_features(
63+
synthetic_data=df_synthetic_data,
64+
challenge_data=df_challenge_data,
65+
column_types=feature_column_types,
66+
random_seed=config.random_seed,
67+
)
6368

64-
# Keep only the columns that are present in feature_column_types
65-
columns_to_keep = feature_column_types["numerical"] + feature_column_types["categorical"]
66-
df_synthetic_data = df_synthetic_data[columns_to_keep]
67-
df_challenge_data = df_challenge_data[columns_to_keep]
69+
final_output_dir = output_features_path / f"{model_name}_black_box"
6870

69-
# Run feature extraction
70-
df_extracted_features = extract_features(
71-
synthetic_data=df_synthetic_data,
72-
challenge_data=df_challenge_data,
73-
column_types=feature_column_types,
74-
random_seed=config.random_seed,
75-
)
71+
final_output_dir.mkdir(parents=True, exist_ok=True)
7672

77-
final_output_dir = output_features_path / f"{model_name}_black_box"
73+
# Extract the number at the end of model_folder
74+
model_folder_number = int(model_folder.split("_")[-1])
75+
file_name = f"attribute_prediction_features_{model_folder_number}.csv"
7876

79-
final_output_dir.mkdir(parents=True, exist_ok=True)
80-
81-
# Extract the number at the end of model_folder
82-
model_folder_number = int(model_folder.split("_")[-1])
83-
file_name = f"attribute_prediction_features_{model_folder_number}.csv"
84-
85-
save_dataframe(df=df_extracted_features, file_path=final_output_dir, file_name=file_name)
77+
save_dataframe(df=df_extracted_features, file_path=final_output_dir, file_name=file_name)
8678

8779

8880
@hydra.main(config_path=".", config_name="config", version_base=None)

src/midst_toolkit/attacks/ept/feature_extraction.py

Lines changed: 6 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
66
"""
77

8-
from enum import Enum
98
from logging import INFO
109

1110
import numpy as np
@@ -15,14 +14,10 @@
1514
from sklearn.pipeline import Pipeline
1615
from sklearn.preprocessing import OneHotEncoder, StandardScaler
1716

17+
from midst_toolkit.common.enumerations import TaskType
1818
from midst_toolkit.common.logger import log
1919

2020

21-
class TaskType(Enum):
22-
CLASSIFICATION = "classification"
23-
REGRESSION = "regression"
24-
25-
2621
def preprocess_train_predict(
2722
train_points: pd.DataFrame,
2823
test_points: pd.DataFrame,
@@ -77,7 +72,7 @@ def preprocess_train_predict(
7772
"The union of numeric_columns and categorical_columns must match the columns in the combined dataframe"
7873
)
7974

80-
task_type = TaskType.CLASSIFICATION if target_col in categorical_columns else TaskType.REGRESSION
75+
task_type = TaskType.MULTICLASS_CLASSIFICATION if target_col in categorical_columns else TaskType.REGRESSION
8176

8277
# Remove target column from feature columns
8378
numeric_columns = [col for col in numeric_columns if col != target_col]
@@ -95,7 +90,7 @@ def preprocess_train_predict(
9590

9691
model = (
9792
RandomForestClassifier(random_state=random_seed)
98-
if task_type == TaskType.CLASSIFICATION
93+
if task_type == TaskType.MULTICLASS_CLASSIFICATION
9994
else RandomForestRegressor(random_state=random_seed)
10095
)
10196

@@ -124,8 +119,8 @@ def extract_features(
124119
4. Compile the results into a DataFrame.
125120
126121
Args:
127-
synthetic_data: Synthetic data generated by the target model without ID columns; the data we want to
128-
extract features from.
122+
synthetic_data: Synthetic data to extract features from. Note: This data should not contain any identifier
123+
columns, as the function will attempt to train a prediction model for every column included.
129124
challenge_data: The data the predictions are compared against, to compute prediction accuracy/errors.
130125
column_types: A dictionary specifying the types of columns (numerical or categorical) in the data.
131126
random_seed: Random seed for reproducibility. Defaults to None.
@@ -160,7 +155,7 @@ def extract_features(
160155
features.append(y_test)
161156
columns.append(column)
162157

163-
if task_type == TaskType.CLASSIFICATION:
158+
if task_type == TaskType.MULTICLASS_CLASSIFICATION:
164159
# TODO: Maybe change the variable name from accuracy to correctness
165160
# Calculate accuracy
166161
accuracy = predictions == y_test

tests/unit/attacks/ept_attack/test_feature_extraction.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@
22
import pandas as pd
33
import pytest
44

5-
from midst_toolkit.attacks.ept.feature_extraction import TaskType, extract_features, preprocess_train_predict
5+
from midst_toolkit.attacks.ept.feature_extraction import extract_features, preprocess_train_predict
6+
from midst_toolkit.common.enumerations import TaskType
67

78

89
@pytest.fixture
@@ -51,7 +52,7 @@ def test_preprocess_train_predict_classification(sample_dataframes, sample_colum
5152
random_seed=42,
5253
)
5354

54-
assert task_type == TaskType.CLASSIFICATION
55+
assert task_type == TaskType.MULTICLASS_CLASSIFICATION
5556
assert len(predictions) == len(test_df)
5657
assert predictions.dtype == "object" # RandomForestClassifier predicts original class
5758
pd.testing.assert_series_equal(y_test, test_df[target_col], check_dtype=False, check_names=False)

0 commit comments

Comments
 (0)