
Commit f7449e1
3D segmentation with cellpose-sam
1 parent: bfaddcb

3 files changed (+86 -25 lines)


scripts/baselines/cellpose-sam_IHC.py

Lines changed: 1 addition & 1 deletion
@@ -54,7 +54,7 @@
     timer_output = os.path.join(out_dir, f"{basename}_timer.json")

     masks, flows, styles = model.eval(img, batch_size=32, flow_threshold=flow_threshold,
-                                      cellprob_threshold=cellprob_threshold,
+                                      cellprob_threshold=cellprob_threshold, do_3D=True, z_axis=0,
                                       normalize={"tile_norm_blocksize": tile_norm_blocksize})
     io.imsave(out_path, masks)

scripts/baselines/cellpose-sam_SGN.py

Lines changed: 1 addition & 1 deletion
@@ -54,7 +54,7 @@
     timer_output = os.path.join(out_dir, f"{basename}_timer.json")

     masks, flows, styles = model.eval(img, batch_size=32, flow_threshold=flow_threshold,
-                                      cellprob_threshold=cellprob_threshold,
+                                      cellprob_threshold=cellprob_threshold, do_3D=True, z_axis=0,
                                       normalize={"tile_norm_blocksize": tile_norm_blocksize})
     io.imsave(out_path, masks)
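
Both cellpose-sam_IHC.py and cellpose-sam_SGN.py gain the same two keyword arguments: do_3D=True switches Cellpose from slice-by-slice 2D evaluation to the volumetric pipeline, and z_axis=0 declares that the first array axis of the input stack is z. Below is a minimal, self-contained sketch of the resulting call; the model construction, file paths, and threshold values are illustrative assumptions, not taken from these scripts.

    from cellpose import models, io

    # illustrative inputs; the baseline scripts define their own paths and thresholds
    img_path = "volume.tif"                 # z-stack with axis order (Z, Y, X)
    out_path = "volume_seg.tif"
    flow_threshold, cellprob_threshold, tile_norm_blocksize = 0.4, 0.0, 0

    model = models.CellposeModel(gpu=True)  # Cellpose 4 loads the Cellpose-SAM ("cpsam") weights by default
    img = io.imread(img_path)

    # do_3D=True runs volumetric segmentation; z_axis=0 marks the stack axis of img
    masks, flows, styles = model.eval(img, batch_size=32, flow_threshold=flow_threshold,
                                      cellprob_threshold=cellprob_threshold, do_3D=True, z_axis=0,
                                      normalize={"tile_norm_blocksize": tile_norm_blocksize})

    io.imsave(out_path, masks)              # label volume with the same (Z, Y, X) shape as img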

scripts/baselines/eval_baseline.py

Lines changed: 84 additions & 23 deletions
@@ -19,12 +19,10 @@ def filter_seg(seg_arr: np.typing.ArrayLike, min_count: int = 3000, max_count: i
     Returns:
         Filtered segmentation
     """
-    segmentation_ids = np.unique(seg_arr)[1:]
-    seg_counts = [np.count_nonzero(seg_arr == seg_id) for seg_id in segmentation_ids]
-    seg_filtered = [idx for idx, seg_count in zip(segmentation_ids, seg_counts) if min_count <= seg_count <= max_count]
-    for s in segmentation_ids:
-        if s not in seg_filtered:
-            seg_arr[seg_arr == s] = 0
+    labels, counts = np.unique(seg_arr, return_counts=True)
+    valid = labels[(counts >= min_count) & (counts <= max_count)]
+    mask = np.isin(seg_arr, valid)
+    seg_arr[~mask] = 0
     return seg_arr

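The rewrite of filter_seg trades a Python loop that rescans the array once per label for a single np.unique(..., return_counts=True) pass plus one np.isin mask. Unlike the old np.unique(seg_arr)[1:], the new code also counts the background label 0, but background pixels are 0 either way, so the output is unchanged. A small self-contained check, using toy thresholds rather than the function's defaults:

    import numpy as np

    def filter_seg(seg_arr, min_count, max_count):
        # keep only labels whose pixel count lies in [min_count, max_count]
        labels, counts = np.unique(seg_arr, return_counts=True)
        valid = labels[(counts >= min_count) & (counts <= max_count)]
        mask = np.isin(seg_arr, valid)
        seg_arr[~mask] = 0
        return seg_arr

    toy = np.zeros((64, 64), dtype=np.uint16)
    toy[:10, :10] = 1     # 100 px   -> kept
    toy[:2, 20:22] = 2    # 4 px     -> too small, removed
    toy[20:, 20:] = 3     # 1936 px  -> too large, removed
    print(np.unique(filter_seg(toy, min_count=50, max_count=500)))   # [0 1]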

@@ -62,14 +60,19 @@ def eval_all_sgn():
                                   "final_consensus_annotations")

     baselines = [
+        "spiner2D",
         "cellpose3",
-        "cellpose-sam",
+        "cellpose-sam_2025-10",
         "distance_unet",
         "micro-sam",
-        "stardist"]
+        "stardist",
+    ]

     for baseline in baselines:
-        eval_segmentation(os.path.join(seg_dir, baseline), annotation_dir=annotation_dir)
+        if "spiner" in baseline:
+            eval_segmentation_spiner(os.path.join(seg_dir, baseline), annotation_dir=annotation_dir)
+        else:
+            eval_segmentation(os.path.join(seg_dir, baseline), annotation_dir=annotation_dir)


 def eval_all_ihc():
@@ -80,9 +83,10 @@ def eval_all_ihc():
     annotation_dir = os.path.join(cochlea_dir, "AnnotatedImageCrops/F1ValidationIHCs/consensus_annotation")
     baselines = [
         "cellpose3",
-        "cellpose-sam",
-        "distance_unet_v3",
-        "micro-sam"]
+        # "cellpose-sam_2025-11",
+        "distance_unet_v4b",
+        "micro-sam",
+    ]

     for baseline in baselines:
         eval_segmentation(os.path.join(seg_dir, baseline), annotation_dir=annotation_dir)
@@ -98,10 +102,10 @@ def eval_segmentation(seg_dir, annotation_dir):
         basename = os.path.basename(seg)
         basename = ".".join(basename.split(".")[:-1])
         basename = "".join(basename.split("_seg")[0])
-        print(basename)
-        print("Annotation_dir", annotation_dir)
+        # print("Annotation_dir", annotation_dir)
         dic_out = os.path.join(seg_dir, f"{basename}_dic.json")
         if not os.path.isfile(dic_out):
+            print(basename)

             df_path = os.path.join(annotation_dir, f"{basename}.csv")
             df = pd.read_csv(df_path, sep=",")
@@ -110,6 +114,7 @@ def eval_segmentation(seg_dir, annotation_dir):
                 timer_dic = json.load(f)

             seg_arr = imageio.imread(seg)
+            print(f"shape {seg_arr.shape}")
             seg_filtered = filter_seg(seg_arr=seg_arr)

             seg_dic = compute_matches_for_annotated_slice(segmentation=seg_filtered,
@@ -123,7 +128,54 @@ def eval_segmentation(seg_dir, annotation_dir):

             seg_dicts.append(seg_dic)
         else:
-            print(f"Dictionary {dic_out} already exists")
+            print(f"Dictionary for {basename} already exists")
+
+    json_out = os.path.join(seg_dir, "eval_seg.json")
+    with open(json_out, "w") as f:
+        json.dump(seg_dicts, f, indent='\t', separators=(',', ': '))
+
+
+def eval_segmentation_spiner(seg_dir, annotation_dir):
+    print(f"Evaluating segmentation in directory {seg_dir}")
+    annots = [entry.path for entry in os.scandir(seg_dir)
+              if entry.is_file() and ".csv" in entry.path]
+
+    seg_dicts = []
+    for annot in annots:
+
+        basename = os.path.basename(annot)
+        basename = ".".join(basename.split(".")[:-1])
+        basename = "".join(basename.split("_annot")[0])
+        dic_out = os.path.join(seg_dir, f"{basename}_dic.json")
+        if not os.path.isfile(dic_out):
+
+            df_path = os.path.join(annotation_dir, f"{basename}.csv")
+            df = pd.read_csv(df_path, sep=",")
+
+            image_spiner = os.path.join(seg_dir, f"{basename}.tif")
+            img = imageio.imread(image_spiner)
+            seg_arr = np.zeros(img.shape)
+
+            df_annot = pd.read_csv(annot, sep=",")
+            for num, row in df_annot.iterrows():
+                x1 = int(row["x1"])
+                x2 = int(row["x2"])
+                y1 = int(row["y1"])
+                y2 = int(row["y2"])
+                seg_arr[x1:x2, y1:y2] = num + 1
+
+            seg_dic = compute_matches_for_annotated_slice(segmentation=seg_arr,
+                                                          annotations=df,
+                                                          matching_tolerance=5)
+            seg_dic["annotation_length"] = len(df)
+            seg_dic["crop_name"] = basename
+            seg_dic["time"] = None
+
+            eval_seg_dict(seg_dic, dic_out)
+
+            seg_dicts.append(seg_dic)
+        else:
+            print(f"Dictionary for {basename} already exists")

     json_out = os.path.join(seg_dir, "eval_seg.json")
     with open(json_out, "w") as f:
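
The new eval_segmentation_spiner has no label image for the prediction, only a CSV of bounding boxes per crop, so it rasterizes each box into an empty array with label num + 1 before the usual matching step. A toy illustration of that rasterization, with an invented two-box annotation table; note that x1/x2 index the first array axis and that later boxes overwrite earlier ones where they overlap:

    import numpy as np
    import pandas as pd

    # invented boxes; the real values come from the *.csv files in seg_dir
    df_annot = pd.DataFrame({"x1": [2, 5], "x2": [6, 9], "y1": [2, 5], "y2": [6, 9]})

    seg_arr = np.zeros((12, 12))
    for num, row in df_annot.iterrows():
        x1, x2 = int(row["x1"]), int(row["x2"])
        y1, y2 = int(row["y1"]), int(row["y2"])
        seg_arr[x1:x2, y1:y2] = num + 1   # the box index + 1 becomes the label id

    print(np.unique(seg_arr))   # [0. 1. 2.]; the overlap at [5:6, 5:6] carries label 2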
@@ -145,6 +197,10 @@ def print_accuracy(eval_dir):
         fp = len(d["fp"])
         fn = len(d["fn"])
         time = d["time"]
+        if time is None:
+            show_time = False
+        else:
+            show_time = True

         if tp + fp != 0:
             precision = tp / (tp + fp)
@@ -163,9 +219,13 @@ def print_accuracy(eval_dir):
         recall_list.append(recall)
         f1_score_list.append(f1_score)
         time_list.append(time)
-
-    names = ["Precision", "Recall", "F1 score", "Time"]
-    for num, lis in enumerate([precision_list, recall_list, f1_score_list, time_list]):
+    if show_time:
+        param_list = [precision_list, recall_list, f1_score_list, time_list]
+        names = ["Precision", "Recall", "F1 score", "Time"]
+    else:
+        param_list = [precision_list, recall_list, f1_score_list]
+        names = ["Precision", "Recall", "F1 score"]
+    for num, lis in enumerate(param_list):
         print(names[num], sum(lis) / len(lis))
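print_accuracy averages per-crop precision, recall, and F1 (and runtime when it is available); the spiner2D dictionaries carry time=None, so the Time average is skipped for them. A minimal sketch of the per-dictionary computation, assuming the dict layout visible above (match lists under "tp"/"fp"/"fn", optional "time") and the standard F1 definition; the numbers are invented:

    # one invented evaluation dict in the layout used above
    d = {"tp": list(range(42)), "fp": list(range(6)), "fn": list(range(9)), "time": None}

    tp, fp, fn = len(d["tp"]), len(d["fp"]), len(d["fn"])
    precision = tp / (tp + fp) if tp + fp else 0
    recall = tp / (tp + fn) if tp + fn else 0
    f1_score = 2 * precision * recall / (precision + recall) if precision + recall else 0
    show_time = d["time"] is not None   # False here, so "Time" would not be printed

    print(round(precision, 3), round(recall, 3), round(f1_score, 3), show_time)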

@@ -176,8 +236,9 @@ def print_accuracy_sgn():
     cochlea_dir = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet"
     seg_dir = os.path.join(cochlea_dir, "predictions/val_sgn")
     baselines = [
+        "spiner2D",
         "cellpose3",
-        "cellpose-sam",
+        "cellpose-sam_2025-10",
         "distance_unet",
         "micro-sam",
         "stardist"]
@@ -194,8 +255,8 @@ def print_accuracy_ihc():
     seg_dir = os.path.join(cochlea_dir, "predictions/val_ihc")
     baselines = [
         "cellpose3",
-        "cellpose-sam",
-        "distance_unet_v3",
+        # "cellpose-sam_2025-11",
+        "distance_unet_v4b",
         "micro-sam"]

     for baseline in baselines:
@@ -204,9 +265,9 @@ def print_accuracy_ihc():


 def main():
-    eval_all_sgn()
+    # eval_all_sgn()
     eval_all_ihc()
-    print_accuracy_sgn()
+    #print_accuracy_sgn()
     print_accuracy_ihc()
