Skip to content

Commit 5b793a4

Browse files
author
Behnoosh Zamanlooy
committed
addressed some code rabbit comments
1 parent b0f26b3 commit 5b793a4

File tree

3 files changed

+19
-10
lines changed

3 files changed

+19
-10
lines changed

src/midst_toolkit/attacks/tf/classification.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -117,10 +117,12 @@ def evaluate_model(model, x, y):
117117
optimizer = optim.Adam(regression_model.parameters(), lr=learning_rate)
118118
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
119119

120+
has_validation = x_val is not None
120121
x_train = torch.tensor(x_train, dtype=torch.float32).to(device)
121122
y_train = torch.tensor(x_train_label, dtype=torch.float32).to(device)
122-
x_val = torch.tensor(x_val, dtype=torch.float32).to(device)
123-
y_test = torch.tensor(x_val_label, dtype=torch.float32).to(device)
123+
if has_validation:
124+
x_val = torch.tensor(x_val, dtype=torch.float32).to(device)
125+
y_val = torch.tensor(x_val_label, dtype=torch.float32).to(device)
124126

125127
indices = torch.randperm(x_train.size(0))
126128
x_train, y_train = x_train[indices], y_train[indices]
@@ -137,7 +139,7 @@ def evaluate_model(model, x, y):
137139
if (epoch + 1) % 100 == 0:
138140
train_loss, train_tpr = evaluate_model(regression_model, x_train, y_train)
139141
if x_val is not None:
140-
test_loss, test_tpr = evaluate_model(regression_model, x_val, y_test)
142+
test_loss, test_tpr = evaluate_model(regression_model, x_val, y_val)
141143
if test_tpr > best_tpr:
142144
best_tpr = test_tpr
143145
save_best_model(regression_model, best_model_path)
@@ -153,7 +155,7 @@ def evaluate_model(model, x, y):
153155
load_best_model(regression_model, best_model_path, device)
154156

155157
if x_val is not None:
156-
test_loss, test_tpr = evaluate_model(regression_model, x_val, y_test)
158+
test_loss, test_tpr = evaluate_model(regression_model, x_val, y_val)
157159
print(f"Final best loss: {test_loss}, best TPR: {test_tpr}")
158160

159161
return regression_model

src/midst_toolkit/attacks/tf/data_utils.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -176,7 +176,11 @@ def get_tpr_at_fpr(true_membership: list[int], predictions: list[float], max_fpr
176176
Calculates the best True Positive Rate when the False Positive Rate is at most `max_fpr`.
177177
"""
178178
fpr, tpr, _ = roc_curve(true_membership, predictions)
179-
return max(tpr[fpr < max_fpr])
179+
180+
valid_tpr = tpr[fpr <= max_fpr]
181+
if len(valid_tpr) == 0:
182+
raise ValueError("No valid TPR values found for the given max FPR.")
183+
return float(max(valid_tpr))
180184

181185

182186
def evaluate_attack_performance(

tests/integration/attacks/tf/test_tf_attack.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ def test_tf_attack_whitebox_small_config():
3434
"timesteps_list": [5],
3535
"addt_value_list": [0],
3636
"predictions_file_format": "test",
37-
"results_path": Path("/h/behnzaman/midst-toolkit/tests/integration/attacks/tf/test_tf_attack_results"),
37+
"results_path": Path("tests/integration/attacks/tf/test_tf_attack_results"),
3838
"use_best_checkpoint": True,
3939
"test_indices": [5],
4040
"train_indices": [1, 2],
@@ -44,10 +44,13 @@ def test_tf_attack_whitebox_small_config():
4444
mia_performance_train, mia_performance_val, mia_performance_test = tf_attack(**config)
4545
tpr_at_fpr_train, roc_auc_train = mia_performance_train.values()
4646
tpr_at_fpr_val, roc_auc_val = mia_performance_val.values()
47-
(
48-
tpr_at_fpr_test,
49-
roc_auc_test,
50-
) = mia_performance_test.values()
47+
tpr_at_fpr_test, roc_auc_test = mia_performance_test.values()
48+
tpr_at_fpr_train = mia_performance_train["max_tpr"]
49+
roc_auc_train = mia_performance_train["roc_auc"]
50+
tpr_at_fpr_val = mia_performance_val["max_tpr"]
51+
roc_auc_val = mia_performance_val["roc_auc"]
52+
tpr_at_fpr_test = mia_performance_test["max_tpr"]
53+
roc_auc_test = mia_performance_test["roc_auc"]
5154

5255
assert roc_auc_train == pytest.approx(0.48133750000000003, abs=1e-8)
5356
assert tpr_at_fpr_train == pytest.approx(0.125, abs=1e-8)

0 commit comments

Comments
 (0)