Commit d966ea3

minor things for analysis; 1st implementation of surface dice for eval

1 parent 7b83139

5 files changed: +129 / -13 lines

.gitignore

Lines changed: 1 addition & 0 deletions

@@ -15,3 +15,4 @@ scripts/cooper/training/find_rec_testset.py
 synapse-net-models/
 scripts/portal/upscale_tomo.py
 analysis_results/
+scripts/cooper/revision/evaluation_results/

run_sbatch_revision.sbatch

Lines changed: 5 additions & 5 deletions

@@ -1,15 +1,15 @@
 #! /bin/bash
 #SBATCH -c 4 #4 #8
-#SBATCH --mem 256G #120G #32G #64G #256G
+#SBATCH --mem 120G #120G #32G #64G #256G
 #SBATCH -p grete:shared #grete:shared #grete-h100:shared
 #SBATCH -t 4:00:00 #6:00:00 #48:00:00
 #SBATCH -G A100:1 #V100:1 #2 #A100:1 #gtx1080:2 #v100:1 #H100:1
 #SBATCH --output=/user/muth9/u12095/synapse-net/slurm_revision/slurm-%j.out
-#SBATCH -A nim00007
-#SBATCH --constraint 80gb
+#SBATCH -A nim00007 #SBATCH --constraint 80gb

 source ~/.bashrc
 conda activate synapse-net
 python /user/muth9/u12095/synapse-net/scripts/cooper/revision/updated_data_analysis/run_data_analysis.py \
-  -i /mnt/lustre-emmy-hdd/projects/nim00007/data/synaptic-reconstruction/cooper/20241102_TOMO_DATA_Imig2014/exported/SNAP25/ \
-  -o /mnt/lustre-emmy-hdd/projects/nim00007/data/synaptic-reconstruction/cooper/20241102_TOMO_DATA_Imig2014/afterRevision_analysis/boundaryT0_9_constantins_presynapticFiltering --store
+  -i /mnt/lustre-emmy-hdd/projects/nim00007/data/synaptic-reconstruction/cooper/20241102_TOMO_DATA_Imig2014/final_Imig2014_seg_autoComp/SNAP25/ \
+  -o /mnt/lustre-emmy-hdd/projects/nim00007/data/synaptic-reconstruction/cooper/20241102_TOMO_DATA_Imig2014/afterRevision_analysis/boundaryT0_9_constantins_presynapticFiltering/full_dataset --store \
+  -s ./analysis_results/full_dataset
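The batch script above is submitted with SLURM's sbatch from the shell. When submissions are driven from Python instead (e.g. to queue several datasets in a loop), a minimal sketch, assuming sbatch is on PATH and the script sits in the current working directory:

import subprocess

# Submit the batch script; on success sbatch prints "Submitted batch job <id>".
result = subprocess.run(
    ["sbatch", "run_sbatch_revision.sbatch"],
    capture_output=True, text=True, check=True,
)
job_id = result.stdout.strip().split()[-1]  # keep the job id for bookkeeping
print(f"Submitted SLURM job {job_id}")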
New file (surface dice evaluation script)

Lines changed: 114 additions & 0 deletions

@@ -0,0 +1,114 @@
import sys
import os

# Add membrain-seg to Python path
MEMBRAIN_SEG_PATH = "/user/muth9/u12095/membrain-seg/src"
if MEMBRAIN_SEG_PATH not in sys.path:
    sys.path.insert(0, MEMBRAIN_SEG_PATH)

import argparse
import h5py
import pandas as pd
from tqdm import tqdm
import numpy as np

from membrain_seg.segmentation.skeletonize import skeletonization
from membrain_seg.benchmark.metrics import masked_surface_dice


def load_segmentation(file_path, key):
    """Load a dataset from an HDF5 file."""
    with h5py.File(file_path, "r") as f:
        data = f[key][:]
    return data


def evaluate_surface_dice(pred, gt, raw, check):
    """Skeletonize predictions and GT, compute surface dice."""
    gt_skeleton = skeletonization(gt == 1, batch_size=100000)
    pred_skeleton = skeletonization(pred, batch_size=100000)
    mask = gt != 2

    if check:
        import napari
        v = napari.Viewer()
        v.add_image(raw)
        v.add_labels(gt, name="gt")
        v.add_labels(gt_skeleton.astype(np.uint16), name="gt_skeleton")
        v.add_labels(pred, name="pred")
        v.add_labels(pred_skeleton.astype(np.uint16), name="pred_skeleton")
        napari.run()

    surf_dice, confusion_dict = masked_surface_dice(
        pred_skeleton, gt_skeleton, pred, gt, mask
    )
    return surf_dice, confusion_dict


def process_file(pred_path, gt_path, seg_key, gt_key, check):
    """Process a single prediction/GT file pair."""
    try:
        pred = load_segmentation(pred_path, seg_key)
        gt = load_segmentation(gt_path, gt_key)
        raw = load_segmentation(gt_path, "raw")
        surf_dice, confusion = evaluate_surface_dice(pred, gt, raw, check)

        result = {
            "tomo_name": os.path.basename(pred_path),
            "surface_dice": surf_dice,
            **confusion,
        }
        return result

    except Exception as e:
        print(f"Error processing {pred_path}: {e}")
        return None


def collect_results(input_folder, gt_folder, version, check=False):
    """Loop through prediction files and compute metrics."""
    results = []
    seg_key = f"predictions/az/seg_v{version}"
    gt_key = "/labels/az_merged"

    for fname in tqdm(os.listdir(input_folder), desc="Processing segmentations"):
        if not fname.endswith(".h5"):
            continue

        pred_path = os.path.join(input_folder, fname)
        gt_path = os.path.join(gt_folder, fname)

        if not os.path.exists(gt_path):
            print(f"Warning: Ground truth file not found for {fname}")
            continue

        result = process_file(pred_path, gt_path, seg_key, gt_key, check)
        if result:
            results.append(result)

    return results


def save_results(results, output_file):
    """Save results as an Excel file."""
    df = pd.DataFrame(results)
    df.to_excel(output_file, index=False)
    print(f"Results saved to {output_file}")


def main():
    parser = argparse.ArgumentParser(description="Compute surface dice for AZ segmentations.")
    parser.add_argument("--input_folder", "-i", required=True, help="Folder with predicted segmentations (.h5)")
    parser.add_argument("--gt_folder", "-gt", required=True, help="Folder with ground truth segmentations (.h5)")
    parser.add_argument("--version", "-v", required=True, help="Version string used in prediction key")
    parser.add_argument("--check", action="store_true", help="Open the volumes in napari to inspect the skeletons before scoring")

    args = parser.parse_args()

    output_file = f"/user/muth9/u12095/synapse-net/scripts/cooper/revision/evaluation_results/v{args.version}_surface_dice.xlsx"
    results = collect_results(args.input_folder, args.gt_folder, args.version, args.check)
    save_results(results, output_file)


if __name__ == "__main__":
    main()
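The heavy lifting in this script is done by membrain-seg's skeletonization and masked_surface_dice. As an illustration of what a skeleton-based surface dice measures, here is a hypothetical NumPy sketch modeled on the clDice formulation; the function name and the exact definition are assumptions, and the real masked_surface_dice additionally returns a confusion dict:

import numpy as np

def surface_dice_sketch(pred_skel, gt_skel, pred_seg, gt_seg, mask):
    # Restrict both skeletons to the evaluation mask (here: gt != 2).
    pred_skel = pred_skel.astype(bool) & mask
    gt_skel = gt_skel.astype(bool) & mask
    # Topology precision: fraction of the predicted skeleton inside the GT segmentation.
    tprec = (pred_skel & (gt_seg == 1)).sum() / max(pred_skel.sum(), 1)
    # Topology sensitivity: fraction of the GT skeleton inside the predicted segmentation.
    tsens = (gt_skel & pred_seg.astype(bool)).sum() / max(gt_skel.sum(), 1)
    # Harmonic mean, analogous to a dice score over skeleton agreement.
    return 2 * tprec * tsens / max(tprec + tsens, 1e-8)

Matching the arrays produced in evaluate_surface_dice, a call would look like surface_dice_sketch(pred_skeleton, gt_skeleton, pred, gt, gt != 2).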

scripts/cooper/revision/updated_data_analysis/analysis_segmentations.py

Lines changed: 8 additions & 7 deletions

@@ -62,8 +62,8 @@ def SV_pred(raw: np.ndarray, SV_model: str, output_path: str = None, store: bool

     use_existing_seg = False
     #checking if segmentation is already in output path and if so, use it
-    if output_path:
-        with h5py.File(output_path, "a") as f:
+    if output_path and os.path.exists(output_path):
+        with h5py.File(output_path, "r") as f:
             if seg_key in f:
                 seg = f[seg_key][:]
                 use_existing_seg = True

@@ -108,10 +108,11 @@ def compartment_pred(raw: np.ndarray, compartment_model: str, output_path: str =

     use_existing_seg = False
     #checking if segmentation is already in output path and if so, use it
-    if output_path:
-        with h5py.File(output_path, "a") as f:
-            if seg_key in f:
+    if output_path and os.path.exists(output_path):
+        with h5py.File(output_path, "r") as f:
+            if seg_key in f and pred_key in f:
                 seg = f[seg_key][:]
+                pred = f[pred_key][:]
                 use_existing_seg = True
                 print(f"Using existing compartment seg in {output_path}")

@@ -152,8 +153,8 @@ def AZ_pred(raw: np.ndarray, AZ_model: str, output_path: str = None, store: bool

     use_existing_seg = False
     #checking if segmentation is already in output path and if so, use it
-    if output_path:
-        with h5py.File(output_path, "a") as f:
+    if output_path and os.path.exists(output_path):
+        with h5py.File(output_path, "r") as f:
             if seg_key in f:
                 seg = f[seg_key][:]
                 use_existing_seg = True
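All three hunks apply the same fix: opening with h5py mode "a" creates an empty HDF5 file as a side effect whenever output_path does not exist yet, so the existence check plus read-only open avoids that. The shared pattern could be factored into one helper; a hypothetical sketch, not part of this commit:

import os
import h5py

def load_cached(output_path, *keys):
    """Return the cached datasets under `keys`, or None if any is missing."""
    if not (output_path and os.path.exists(output_path)):
        return None
    # Read-only open: never creates a file as a side effect, unlike mode "a".
    with h5py.File(output_path, "r") as f:
        if all(k in f for k in keys):
            return tuple(f[k][:] for k in keys)
    return None

SV_pred and AZ_pred would call load_cached(output_path, seg_key), while compartment_pred would pass both seg_key and pred_key, mirroring the second hunk above.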

scripts/cooper/revision/updated_data_analysis/run_data_analysis.py

Lines changed: 1 addition & 1 deletion

@@ -80,7 +80,7 @@ def main():
         run_data_analysis(input_path, output_path, store, resolution, analysis_output)

     elif os.path.isdir(input_path):
-        h5_files = [file for file in os.listdir(input_path) if file.endswith(".h5")]
+        h5_files = sorted([file for file in os.listdir(input_path) if file.endswith(".h5")])
         for file in tqdm(h5_files, desc="Processing files"):
             full_input_path = os.path.join(input_path, file)
             output_path = os.path.join(output_folder, file) if output_folder else None
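os.listdir returns entries in arbitrary, filesystem-dependent order, so sorting makes the processing order (and any order-sensitive logs or aggregated outputs) reproducible across runs and machines. The same idea as a small reusable helper (hypothetical, not in the commit):

import os

def list_h5_files(folder):
    # Sort for a deterministic, platform-independent processing order.
    return sorted(f for f in os.listdir(folder) if f.endswith(".h5"))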
