Commit f17c349

option to calc surface dice per component
1 parent ad03103 commit f17c349

File tree

4 files changed, +120 -46 lines

4 files changed

+120
-46
lines changed
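As orientation for the change, a hypothetical sketch (not part of the commit) of driving the new per-component evaluation from Python, using the collect_results and save_results functions introduced in scripts/cooper/revision/surface_dice.py below; the folder paths and output filename are placeholders:

# Hypothetical driver; assumes it runs next to scripts/cooper/revision/surface_dice.py.
from surface_dice import collect_results, save_results

results = collect_results(
    input_folder="/path/to/predictions",  # placeholder: .h5 files with predictions/az/seg_v<version>
    gt_folder="/path/to/ground_truth",    # placeholder: .h5 files with /labels/az_merged and raw
    version="7",
    check=False,        # True opens napari to inspect the skeletons
    global_eval=False,  # False: one result row per GT connected component
)
save_results(results, "v7_surface_dice_per_gt_component.xlsx")  # placeholder filename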

.gitignore

Lines changed: 2 additions & 1 deletion
@@ -15,4 +15,5 @@ scripts/cooper/training/find_rec_testset.py
 synapse-net-models/
 scripts/portal/upscale_tomo.py
 analysis_results/
-scripts/cooper/revision/evaluation_results/
+scripts/cooper/revision/evaluation_results/
+scripts/cooper/revision/export_tif_to_h5.py

run_sbatch_revision.sbatch

Lines changed: 8 additions & 4 deletions
@@ -1,12 +1,16 @@
 #! /bin/bash
 #SBATCH -c 4 #4 #8
-#SBATCH --mem 120G #120G #32G #64G #256G
+#SBATCH --mem 256G #120G #32G #64G #256G
 #SBATCH -p grete:shared #grete:shared #grete-h100:shared
-#SBATCH -t 4:00:00 #6:00:00 #48:00:00
+#SBATCH -t 6:00:00 #6:00:00 #48:00:00
 #SBATCH -G A100:1 #V100:1 #2 #A100:1 #gtx1080:2 #v100:1 #H100:1
 #SBATCH --output=/user/muth9/u12095/synapse-net/slurm_revision/slurm-%j.out
-#SBATCH -A nim00007 #SBATCH --constraint 80gb
+#SBATCH -A nim00007
+#SBATCH --constraint 80gb
 
 source ~/.bashrc
 conda activate synapse-net
-python scripts/cooper/revision/surface_dice.py -i /mnt/ceph-hdd/cold/nim00007/AZ_prediction_new/endbulb_of_held/ -gt /mnt/ceph-hdd/cold/nim00007/new_AZ_train_data/endbulb_of_held/ -v 7
+python /user/muth9/u12095/synapse-net/scripts/cooper/revision/updated_data_analysis/run_data_analysis.py \
+    -i /mnt/lustre-emmy-hdd/projects/nim00007/data/synaptic-reconstruction/cooper/20241102_TOMO_DATA_Imig2014/exported/SNAP25/ \
+    -o /mnt/lustre-emmy-hdd/projects/nim00007/data/synaptic-reconstruction/cooper/20241102_TOMO_DATA_Imig2014/afterRevision_analysis/boundaryT0_9_constantins_presynapticFiltering --store \
+    -s ./analysis_results/man_subset
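Note that the old line packed two directives together (#SBATCH -A nim00007 #SBATCH --constraint 80gb); since sbatch treats everything from a trailing # onward as a comment, the constraint was presumably never applied, and splitting it onto its own #SBATCH line makes it take effect.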

scripts/cooper/revision/surface_dice.py

Lines changed: 102 additions & 33 deletions
@@ -11,20 +11,20 @@
 import pandas as pd
 from tqdm import tqdm
 import numpy as np
+from scipy.ndimage import label
+from skimage.measure import regionprops
 
 from membrain_seg.segmentation.skeletonize import skeletonization
 from membrain_seg.benchmark.metrics import masked_surface_dice
 
 
 def load_segmentation(file_path, key):
-    """Load a dataset from an HDF5 file."""
     with h5py.File(file_path, "r") as f:
         data = f[key][:]
     return data
 
 
 def evaluate_surface_dice(pred, gt, raw, check):
-    """Skeletonize predictions and GT, compute surface dice."""
     gt_skeleton = skeletonization(gt == 1, batch_size=100000)
     pred_skeleton = skeletonization(pred, batch_size=100000)
     mask = gt != 2
@@ -33,10 +33,10 @@ def evaluate_surface_dice(pred, gt, raw, check):
         import napari
         v = napari.Viewer()
         v.add_image(raw)
-        v.add_labels(gt, name= f"gt")
-        v.add_labels(gt_skeleton.astype(np.uint16), name= f"gt_skeleton")
-        v.add_labels(pred, name= f"pred")
-        v.add_labels(pred_skeleton.astype(np.uint16), name= f"pred_skeleton")
+        v.add_labels(gt, name="gt")
+        v.add_labels(gt_skeleton.astype(np.uint16), name="gt_skeleton")
+        v.add_labels(pred, name="pred")
+        v.add_labels(pred_skeleton.astype(np.uint16), name="pred_skeleton")
         napari.run()
 
     surf_dice, confusion_dict = masked_surface_dice(
@@ -45,28 +45,80 @@ def evaluate_surface_dice(pred, gt, raw, check):
     return surf_dice, confusion_dict
 
 
-def process_file(pred_path, gt_path, seg_key, gt_key, check):
-    """Process a single prediction/GT file pair."""
+def process_file(pred_path, gt_path, seg_key, gt_key, check,
+                 min_bb_shape=(32, 384, 384), min_thinning_size=2500,
+                 global_eval=False):
     try:
         pred = load_segmentation(pred_path, seg_key)
         gt = load_segmentation(gt_path, gt_key)
         raw = load_segmentation(gt_path, "raw")
-        surf_dice, confusion = evaluate_surface_dice(pred, gt, raw, check)
 
-        result = {
-            "tomo_name": os.path.basename(pred_path),
-            "surface_dice": surf_dice,
-            **confusion,
-        }
-        return result
+        if global_eval:
+            gt_bin = (gt == 1).astype(np.uint8)
+            pred_bin = pred.astype(np.uint8)
+
+            dice, confusion = evaluate_surface_dice(pred_bin, gt_bin, raw, check)
+            return [{
+                "tomo_name": os.path.basename(pred_path),
+                "gt_component_id": -1,  # -1 indicates global eval
+                "surface_dice": dice,
+                **confusion
+            }]
+
+        labeled_gt, _ = label(gt == 1)
+        props = regionprops(labeled_gt)
+        results = []
+
+        for prop in props:
+            if prop.area < min_thinning_size:
+                continue
+
+            comp_id = prop.label
+            bbox_start = prop.bbox[:3]
+            bbox_end = prop.bbox[3:]
+            bbox = tuple(slice(start, stop) for start, stop in zip(bbox_start, bbox_end))
+
+            pad_width = [
+                max(min_shape - (sl.stop - sl.start), 0) // 2
+                for sl, min_shape in zip(bbox, min_bb_shape)
+            ]
+
+            expanded_bbox = tuple(
+                slice(
+                    max(sl.start - pw, 0),
+                    min(sl.stop + pw, dim)
+                )
+                for sl, pw, dim in zip(bbox, pad_width, gt.shape)
+            )
+
+            gt_crop = (labeled_gt[expanded_bbox] == comp_id).astype(np.uint8)
+            pred_crop = pred[expanded_bbox].astype(np.uint8)
+            raw_crop = raw[expanded_bbox]
+
+            try:
+                dice, confusion = evaluate_surface_dice(pred_crop, gt_crop, raw_crop, check)
+            except Exception as e:
+                print(f"Error computing Dice for GT component {comp_id} in {pred_path}: {e}")
+                continue
+
+            result = {
+                "tomo_name": os.path.basename(pred_path),
+                "gt_component_id": comp_id,
+                "surface_dice": dice,
+                **confusion
+            }
+            results.append(result)
+
+        return results
 
     except Exception as e:
         print(f"Error processing {pred_path}: {e}")
-        return None
+        return []
 
 
-def collect_results(input_folder, gt_folder, version, check=False):
-    """Loop through prediction files and compute metrics."""
+def collect_results(input_folder, gt_folder, version, check=False,
+                    min_bb_shape=(32, 384, 384), min_thinning_size=2500,
+                    global_eval=False):
     results = []
     seg_key = f"predictions/az/seg_v{version}"
     gt_key = "/labels/az_merged"
@@ -83,29 +135,32 @@ def collect_results(input_folder, gt_folder, version, check=False):
             print(f"Warning: Ground truth file not found for {fname}")
             continue
 
-        result = process_file(pred_path, gt_path, seg_key, gt_key, check)
-        if result:
-            result["input_folder"] = input_folder_name
-            results.append(result)
+        file_results = process_file(
+            pred_path, gt_path, seg_key, gt_key, check,
+            min_bb_shape=min_bb_shape,
+            min_thinning_size=min_thinning_size,
+            global_eval=global_eval
+        )
+
+        for res in file_results:
+            res["input_folder"] = input_folder_name
+            results.append(res)
 
     return results
 
 
 def save_results(results, output_file):
-    """Append results to an Excel file, updating rows with matching tomo_name and input_folder."""
     new_df = pd.DataFrame(results)
 
     if os.path.exists(output_file):
         existing_df = pd.read_excel(output_file)
 
-        # Drop rows where tomo_name and input_folder match any in new_df
         combined_df = existing_df[
-            ~existing_df.set_index(["tomo_name", "input_folder"]).index.isin(
-                new_df.set_index(["tomo_name", "input_folder"]).index
+            ~existing_df.set_index(["tomo_name", "input_folder", "gt_component_id"]).index.isin(
+                new_df.set_index(["tomo_name", "input_folder", "gt_component_id"]).index
             )
         ]
 
-        # Append new data and reset index
         final_df = pd.concat([combined_df, new_df], ignore_index=True)
     else:
         final_df = new_df
@@ -114,20 +169,34 @@ def save_results(results, output_file)
     print(f"Results saved to {output_file}")
 
 
-
 def main():
-    parser = argparse.ArgumentParser(description="Compute surface dice for AZ segmentations.")
+    parser = argparse.ArgumentParser(description="Compute surface dice per GT component or globally for AZ segmentations.")
     parser.add_argument("--input_folder", "-i", required=True, help="Folder with predicted segmentations (.h5)")
     parser.add_argument("--gt_folder", "-gt", required=True, help="Folder with ground truth segmentations (.h5)")
     parser.add_argument("--version", "-v", required=True, help="Version string used in prediction key")
-    parser.add_argument("--check", action="store_true", help="Version string used in prediction key")
+    parser.add_argument("--check", action="store_true", help="Visualize intermediate outputs in Napari")
+    parser.add_argument("--global_eval", action="store_true", help="If set, compute global surface dice instead of per-component")
 
     args = parser.parse_args()
 
-    output_file = f"/user/muth9/u12095/synapse-net/scripts/cooper/revision/evaluation_results/v{args.version}_surface_dice.xlsx"
-    results = collect_results(args.input_folder, args.gt_folder, args.version, args.check)
+    min_bb_shape = (32, 384, 384)
+    min_thinning_size = 2500
+
+    suffix = "global" if args.global_eval else "per_gt_component"
+    output_file = f"/user/muth9/u12095/synapse-net/scripts/cooper/revision/evaluation_results/v{args.version}_surface_dice_{suffix}.xlsx"
+
+    results = collect_results(
+        args.input_folder,
+        args.gt_folder,
+        args.version,
+        args.check,
+        min_bb_shape=min_bb_shape,
+        min_thinning_size=min_thinning_size,
+        global_eval=args.global_eval
+    )
+
     save_results(results, output_file)
 
 
 if __name__ == "__main__":
-    main()
+    main()
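The core of the new per-component mode is a label-and-crop pattern: connected components of the ground truth are labeled, components below min_thinning_size are skipped, and each remaining component's bounding box is padded symmetrically toward min_bb_shape (clipped at the volume border) before the surface dice is computed on the crop. A self-contained sketch of just that pattern on a toy volume, with the metric call left out:

import numpy as np
from scipy.ndimage import label
from skimage.measure import regionprops

# Toy 3D "ground truth" with two blobs; the real script uses the AZ label volume.
gt = np.zeros((16, 64, 64), dtype=np.uint8)
gt[4:8, 10:20, 10:20] = 1
gt[8:12, 40:60, 40:60] = 1

min_bb_shape = (8, 32, 32)  # plays the role of the script's (32, 384, 384)
min_size = 100              # plays the role of min_thinning_size=2500

labeled, _ = label(gt == 1)
for prop in regionprops(labeled):
    if prop.area < min_size:
        continue
    bbox = tuple(slice(a, b) for a, b in zip(prop.bbox[:3], prop.bbox[3:]))
    # Symmetric padding toward the minimum crop shape, clipped to the volume.
    pad = [max(m - (sl.stop - sl.start), 0) // 2 for sl, m in zip(bbox, min_bb_shape)]
    crop = tuple(
        slice(max(sl.start - p, 0), min(sl.stop + p, dim))
        for sl, p, dim in zip(bbox, pad, gt.shape)
    )
    comp_mask = (labeled[crop] == prop.label).astype(np.uint8)
    print(prop.label, comp_mask.shape, int(comp_mask.sum()))  # would feed evaluate_surface_dice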

scripts/cooper/revision/updated_data_analysis/store_results.py

Lines changed: 8 additions & 8 deletions
@@ -73,20 +73,20 @@ def save_filtered_dataframes(output_dir, tomogram_name, df):
         'AZ_distances_within_200': 200,
         'AZ_distances_within_100': 100,
         'AZ_distances_within_40': 40,
-        'AZ_distances_within_40_with_diameters': 40,
-        'AZ_distances_within_40_only_diameters': 40,
+        'AZ_distances_within_100_with_diameters': 100,
+        'AZ_distances_within_100_only_diameters': 100,
     }
 
     for filename, max_dist in thresholds.items():
         file_path = os.path.join(output_dir, f"{filename}.xlsx")
         filtered_df = df if max_dist is None else df[df['distance'] <= max_dist]
 
-        if filename == 'AZ_distances_within_40_with_diameters':
+        if filename == 'AZ_distances_within_100_with_diameters':
            data = pd.DataFrame({
                f"{tomogram_name}_distance": filtered_df['distance'].values,
                f"{tomogram_name}_diameter": filtered_df['diameter'].values
            })
-        elif filename == 'AZ_distances_within_40_only_diameters':
+        elif filename == 'AZ_distances_within_100_only_diameters':
            data = pd.DataFrame({
                f"{tomogram_name}_diameter": filtered_df['diameter'].values
            })
@@ -110,8 +110,8 @@ def save_filtered_dataframes_with_seg_id(output_dir, tomogram_name, df):
         'AZ_distances_within_200_with_seg_id': 200,
         'AZ_distances_within_100_with_seg_id': 100,
         'AZ_distances_within_40_with_seg_id': 40,
-        'AZ_distances_within_40_with_diameters_and_seg_id': 40,
-        'AZ_distances_within_40_only_diameters_and_seg_id': 40,
+        'AZ_distances_within_100_with_diameters_and_seg_id': 100,
+        'AZ_distances_within_100_only_diameters_and_seg_id': 100,
     }
 
     with_segID_dir = os.path.join(output_dir, "with_segID")
@@ -121,13 +121,13 @@ def save_filtered_dataframes_with_seg_id(output_dir, tomogram_name, df):
         file_path = os.path.join(with_segID_dir, f"{filename}.xlsx")
         filtered_df = df if max_dist is None else df[df['distance'] <= max_dist]
 
-        if filename == 'AZ_distances_within_40_with_diameters_and_seg_id':
+        if filename == 'AZ_distances_within_100_with_diameters_and_seg_id':
            data = pd.DataFrame({
                f"{tomogram_name}_seg_id": filtered_df['seg_id'].values,
                f"{tomogram_name}_distance": filtered_df['distance'].values,
                f"{tomogram_name}_diameter": filtered_df['diameter'].values
            })
-        elif filename == 'AZ_distances_within_40_only_diameters_and_seg_id':
+        elif filename == 'AZ_distances_within_100_only_diameters_and_seg_id':
            data = pd.DataFrame({
                f"{tomogram_name}_seg_id": filtered_df['seg_id'].values,
                f"{tomogram_name}_diameter": filtered_df['diameter'].values
