Skip to content

Commit 5c813e9

Browse files
committed
Tidy remaining fraud detection review fixes
1 parent 668aab3 commit 5c813e9

File tree

3 files changed

+3
-18
lines changed

3 files changed

+3
-18
lines changed

research/fsi-fraud-detection/stats/client.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,8 +75,6 @@ def load_data(self, fl_ctx: FLContext) -> Dict[str, pd.DataFrame]:
7575
)
7676
for path in test_data_paths:
7777
self.log_info(fl_ctx, f" - {path}")
78-
79-
assert len(test_data_paths) == 4, "Expected 4 test files, got " + str(len(test_data_paths))
8078
else:
8179
# Single test file
8280
if not os.path.isfile(test_data_path_pattern):

research/fsi-fraud-detection/train/central_train.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -110,11 +110,6 @@ def main():
110110
print(f"Found {len(test_data_paths)} test files matching pattern: {test_data_path_pattern}")
111111
for path in test_data_paths:
112112
print(f" - {path}")
113-
114-
# assert len(test_data_paths) == 25, "Expected 25 test files, got " + str(len(test_data_paths))
115-
assert len(test_data_paths) == 20, "Expected 20 test files, got " + str(
116-
len(test_data_paths)
117-
) # Datasets later than 3/2/2026
118113
else:
119114
# Single test file
120115
if not os.path.isfile(test_data_path_pattern):

research/fsi-fraud-detection/train/client.py

Lines changed: 3 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -41,13 +41,6 @@
4141
from nvflare.app_opt.pt.fedproxloss import PTFedProxLoss
4242
from nvflare.client.tracking import MLflowWriter
4343

44-
# install dependencies
45-
# os.system("python -m pip install opacus")
46-
# os.system("python -m pip install captum")
47-
# result = os.system("python -m pip install numpy==1.26.4")
48-
# print(f"Pip Install Result: {result}")
49-
50-
5144
PATH = "pt_model.weights.pth"
5245

5346

@@ -154,9 +147,6 @@ def main():
154147
print(f"Found {len(test_data_paths)} test files matching pattern: {test_data_path_pattern}")
155148
for path in test_data_paths:
156149
print(f" - {path}")
157-
158-
# assert len(test_data_paths) == 5, "Expected 5 test files, got " + str(len(test_data_paths))
159-
assert len(test_data_paths) == 4, "Expected 4 test files, got " + str(len(test_data_paths)) # new data 3/3/2026
160150
else:
161151
# Single test file
162152
if not os.path.isfile(test_data_path_pattern):
@@ -566,8 +556,10 @@ def main():
566556
print("[WARNING] Skip SHAP with DP")
567557
if shap_metrics:
568558
print(f"SHAP computation completed. Used {shap_metrics['shap_samples_used']} samples.")
569-
else:
559+
elif run_shap:
570560
print("SHAP computation failed. Skipping SHAP metrics.")
561+
else:
562+
print("SHAP computation skipped for this round.")
571563
metrics["shap_metrics"] = shap_metrics
572564

573565
# (6) construct trained FL model (A dict of {parameter name: parameter weights} from the PyTorch model)

0 commit comments

Comments
 (0)