Skip to content

Commit 89d978d

Browse files
Update iterative prompting livecell experiment
1 parent f984d9c commit 89d978d

File tree

1 file changed

+37
-3
lines changed

1 file changed

+37
-3
lines changed
Lines changed: 37 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,52 @@
1+
import os
2+
from glob import glob
3+
14
from micro_sam.evaluation.inference import run_inference_with_iterative_prompting
5+
from micro_sam.evaluation.evaluation import run_evaluation
6+
27
from util import get_checkpoint, get_paths
38

9+
LIVECELL_GT_ROOT = "/scratch-grete/projects/nim00007/data/LiveCELL/annotations_corrected/livecell_test_images"
10+
# TODO update to make fit other models
11+
PREDICTION_ROOT = "./pred_interactive_prompting"
12+
13+
14+
def run_interactive_prompting():
15+
prediction_root = PREDICTION_ROOT
416

5-
def main():
617
checkpoint, model_type = get_checkpoint("vit_b")
718
image_paths, gt_paths = get_paths()
819

9-
prediction_root = "./pred_interactive_prompting"
10-
1120
run_inference_with_iterative_prompting(
1221
checkpoint, model_type, image_paths, gt_paths,
1322
prediction_root, use_boxes=False, batch_size=16,
1423
)
1524

1625

26+
def get_pg_paths(pred_folder):
27+
pred_paths = sorted(glob(os.path.join(pred_folder, "*.tif")))
28+
names = [os.path.split(path)[1] for path in pred_paths]
29+
gt_paths = [
30+
os.path.join(LIVECELL_GT_ROOT, name.split("_")[0], name) for name in names
31+
]
32+
assert all(os.path.exists(pp) for pp in gt_paths)
33+
return pred_paths, gt_paths
34+
35+
36+
def evaluate_interactive_prompting():
37+
prediction_root = PREDICTION_ROOT
38+
prediction_folders = sorted(glob(os.path.join(prediction_root, "iteration*")))
39+
for pred_folder in prediction_folders:
40+
print("Evaluating", pred_folder)
41+
pred_paths, gt_paths = get_pg_paths(pred_folder)
42+
res = run_evaluation(gt_paths, pred_paths, save_path=None)
43+
print(res)
44+
45+
46+
def main():
47+
# run_interactive_prompting()
48+
evaluate_interactive_prompting()
49+
50+
1751
if __name__ == "__main__":
1852
main()

0 commit comments

Comments
 (0)