|
| 1 | +import os |
| 2 | +from glob import glob |
| 3 | + |
1 | 4 | from micro_sam.evaluation.inference import run_inference_with_iterative_prompting |
| 5 | +from micro_sam.evaluation.evaluation import run_evaluation |
| 6 | + |
2 | 7 | from util import get_checkpoint, get_paths |
3 | 8 |
|
| 9 | +LIVECELL_GT_ROOT = "/scratch-grete/projects/nim00007/data/LiveCELL/annotations_corrected/livecell_test_images" |
| 10 | +# TODO update to make fit other models |
| 11 | +PREDICTION_ROOT = "./pred_interactive_prompting" |
| 12 | + |
| 13 | + |
| 14 | +def run_interactive_prompting(): |
| 15 | + prediction_root = PREDICTION_ROOT |
4 | 16 |
|
5 | | -def main(): |
6 | 17 | checkpoint, model_type = get_checkpoint("vit_b") |
7 | 18 | image_paths, gt_paths = get_paths() |
8 | 19 |
|
9 | | - prediction_root = "./pred_interactive_prompting" |
10 | | - |
11 | 20 | run_inference_with_iterative_prompting( |
12 | 21 | checkpoint, model_type, image_paths, gt_paths, |
13 | 22 | prediction_root, use_boxes=False, batch_size=16, |
14 | 23 | ) |
15 | 24 |
|
16 | 25 |
|
| 26 | +def get_pg_paths(pred_folder): |
| 27 | + pred_paths = sorted(glob(os.path.join(pred_folder, "*.tif"))) |
| 28 | + names = [os.path.split(path)[1] for path in pred_paths] |
| 29 | + gt_paths = [ |
| 30 | + os.path.join(LIVECELL_GT_ROOT, name.split("_")[0], name) for name in names |
| 31 | + ] |
| 32 | + assert all(os.path.exists(pp) for pp in gt_paths) |
| 33 | + return pred_paths, gt_paths |
| 34 | + |
| 35 | + |
| 36 | +def evaluate_interactive_prompting(): |
| 37 | + prediction_root = PREDICTION_ROOT |
| 38 | + prediction_folders = sorted(glob(os.path.join(prediction_root, "iteration*"))) |
| 39 | + for pred_folder in prediction_folders: |
| 40 | + print("Evaluating", pred_folder) |
| 41 | + pred_paths, gt_paths = get_pg_paths(pred_folder) |
| 42 | + res = run_evaluation(gt_paths, pred_paths, save_path=None) |
| 43 | + print(res) |
| 44 | + |
| 45 | + |
| 46 | +def main(): |
| 47 | + # run_interactive_prompting() |
| 48 | + evaluate_interactive_prompting() |
| 49 | + |
| 50 | + |
17 | 51 | if __name__ == "__main__": |
18 | 52 | main() |
0 commit comments