Skip to content

Commit 779c7c3

Browse files
committed
Fix test
1 parent 515f20c commit 779c7c3

File tree

3 files changed

+28
-13
lines changed

3 files changed

+28
-13
lines changed

examples/ept_attack/config.yaml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,5 +16,8 @@ pipeline:
1616
run_feature_extraction: true # Whether to run attribute prediction model training
1717
run_attack_classifier_training: false # Whether to run attack classifier training
1818

19+
attack_settings:
20+
single_table: true # Whether the data is single-table
21+
1922
# General settings
2023
random_seed: 42

src/midst_toolkit/attacks/ept/feature_extraction.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,6 @@ def preprocess_train_predict(
5454
y_test: True values for the target column on the test data.
5555
task_type: Whether the attribution prediction model was a classification or regression model.
5656
"""
57-
5857
assert target_col in train_points.columns, f"Target column '{target_col}' not found in train_points."
5958
assert target_col in test_points.columns, f"Target column '{target_col}' not found in test_points."
6059

@@ -162,7 +161,7 @@ def extract_features(
162161
columns.append(column)
163162

164163
if task_type == TaskType.CLASSIFICATION:
165-
#TODO: Maybe change the variable name from accuracy to correctness
164+
# TODO: Maybe change the variable name from accuracy to correctness
166165
# Calculate accuracy
167166
accuracy = predictions == y_test
168167
accuracy = accuracy.astype(int)

tests/unit/attacks/ept_attack/test_feature_extraction.py

Lines changed: 24 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -81,20 +81,33 @@ def test_preprocess_train_predict_assertions(sample_dataframes, sample_column_ty
8181
# Tests that the assertions within preprocess_train_predict fire correctly.
8282

8383
train_df, test_df = sample_dataframes
84+
target_col = "num_col_1"
8485

85-
# Test mismatching columns
86-
test_df_mismatch = test_df.drop(columns=["num_col_1"])
87-
with pytest.raises(AssertionError, match="Columns in df_train and df_test do not match"):
88-
preprocess_train_predict(train_df, test_df_mismatch, "cat_col_1", sample_column_types)
86+
# 1. Target column not in train_points
87+
with pytest.raises(AssertionError, match="Target column 'non_existent_col' not found in train_points."):
88+
preprocess_train_predict(train_df, test_df, "non_existent_col", sample_column_types)
8989

90-
# Test target_col not in column_types
91-
with pytest.raises(AssertionError, match="must appear exactly once"):
92-
preprocess_train_predict(train_df, test_df, "missing_col", sample_column_types)
90+
# 2. Target column not in test_points
91+
test_df_missing_target = test_df.drop(columns=[target_col])
92+
with pytest.raises(AssertionError, match=f"Target column '{target_col}' not found in test_points."):
93+
preprocess_train_predict(train_df, test_df_missing_target, target_col, sample_column_types)
9394

94-
# Test column_types not matching dataframe columns
95-
bad_column_types = {"numerical": ["num_col_1"], "categorical": []}
96-
with pytest.raises(AssertionError, match="must match the columns in the combined dataframe"):
97-
preprocess_train_predict(train_df, test_df, "num_col_1", bad_column_types)
95+
# 3. Mismatched columns between train and test dataframes
96+
test_df_mismatched = test_df.rename(columns={"num_col_2": "new_col_name"})
97+
with pytest.raises(AssertionError, match="Columns in df_train and df_test do not match"):
98+
preprocess_train_predict(train_df, test_df_mismatched, target_col, sample_column_types)
99+
100+
# 4. Target column appears more than once in column_types
101+
column_types_duplicate = sample_column_types.copy()
102+
column_types_duplicate["categorical"] = column_types_duplicate["categorical"] + [target_col]
103+
with pytest.raises(AssertionError, match=f"The target column '{target_col}' must appear exactly once"):
104+
preprocess_train_predict(train_df, test_df, target_col, column_types_duplicate)
105+
106+
# 5. Mismatch between dataframe columns and column_types
107+
column_types_mismatch = sample_column_types.copy()
108+
column_types_mismatch["numerical"] = ["num_col_1"] # Missing num_col_2
109+
with pytest.raises(AssertionError, match="The union of numeric_columns and categorical_columns must match"):
110+
preprocess_train_predict(train_df, test_df, target_col, column_types_mismatch)
98111

99112

100113
def test_main_feature_extraction(sample_dataframes, sample_column_types):

0 commit comments

Comments
 (0)