1+ import json
12import os
3+ import warnings
24from glob import glob
35from pathlib import Path
46
57import pandas as pd
68from micro_sam .evaluation import (
7- inference , evaluation ,
9+ automatic_mask_generation , inference , evaluation ,
810 default_experiment_settings , get_experiment_setting_name
911)
1012
2224)
2325
2426
25- def get_data_paths (dataset , split ):
26- image_paths = sorted (glob (os .path .join (DATA_ROOT , dataset , split , "image_*.tif" )))
27+ def get_generalist_predictor (checkpoint , model_type , return_state = False ):
28+ with warnings .catch_warnings ():
29+ warnings .simplefilter ("ignore" )
30+ return inference .get_predictor (
31+ checkpoint , model_type = model_type , return_state = return_state , is_custom_model = True
32+ )
33+
34+
35+ def get_data_paths (dataset , split , max_num_images = None ):
36+ image_pattern = os .path .join (DATA_ROOT , dataset , split , "image_*.tif" )
37+ image_paths = sorted (glob (image_pattern ))
2738 gt_paths = sorted (glob (os .path .join (DATA_ROOT , dataset , split , "labels_*.tif" )))
2839 assert len (image_paths ) == len (gt_paths )
29- assert len (image_paths ) > 0
40+ assert len (image_paths ) > 0 , image_pattern
41+ if max_num_images is not None :
42+ image_paths , gt_paths = image_paths [:max_num_images ], gt_paths [:max_num_images ]
3043 return image_paths , gt_paths
3144
3245
3346def evaluate_checkpoint_for_dataset (
3447 checkpoint , model_type , dataset , experiment_folder ,
3548 run_default_evaluation , run_amg , predictor = None ,
49+ max_num_val_images = None ,
3650):
37- """Evaluate a generalist checkpoint for a given dataset
51+ """Evaluate a generalist checkpoint for a given dataset.
3852 """
3953 assert run_default_evaluation or run_amg
4054
4155 prompt_dir = os .path .join (PROMPT_ROOT , dataset )
4256
4357 if predictor is None :
44- predictor = inference . get_predictor (checkpoint , model_type )
58+ predictor = get_generalist_predictor (checkpoint , model_type )
4559 test_image_paths , test_gt_paths = get_data_paths (dataset , "test" )
4660
61+ embedding_dir = os .path .join (experiment_folder , "test" , "embeddings" )
62+ os .makedirs (embedding_dir , exist_ok = True )
63+ result_dir = os .path .join (experiment_folder , "results" )
64+
4765 results = []
4866 if run_default_evaluation :
49- embedding_dir = os .path .join (experiment_folder , "test" , "embeddings" )
50- os .makedirs (embedding_dir , exist_ok = True )
51-
52- result_dir = os .path .join (experiment_folder , "results" )
53-
5467 prompt_settings = default_experiment_settings ()
5568 for setting in prompt_settings :
5669
@@ -75,7 +88,36 @@ def evaluate_checkpoint_for_dataset(
7588 results .append (result )
7689
7790 if run_amg :
78- raise NotImplementedError
91+ val_embedding_dir = os .path .join (experiment_folder , "val" , "embeddings" )
92+ val_result_dir = os .path .join (experiment_folder , "val" , "results" )
93+ os .makedirs (val_embedding_dir , exist_ok = True )
94+
95+ val_image_paths , val_gt_paths = get_data_paths (dataset , "val" , max_num_images = max_num_val_images )
96+ automatic_mask_generation .run_amg_grid_search (
97+ predictor , val_image_paths , val_gt_paths , val_embedding_dir ,
98+ val_result_dir , verbose_gs = True ,
99+ )
100+
101+ best_iou_thresh , best_stability_thresh , _ = automatic_mask_generation .evaluate_amg_grid_search (val_result_dir )
102+ best_settings = {"pred_iou_thresh" : best_iou_thresh , "stability_score_thresh" : best_stability_thresh }
103+ gs_result_path = os .path .join (experiment_folder , "best_gs_params.json" )
104+ with open (gs_result_path , "w" ) as f :
105+ json .dump (best_settings , f )
106+
107+ prediction_dir = os .path .join (experiment_folder , "test" , "amg" )
108+ os .makedirs (prediction_dir , exist_ok = True )
109+ automatic_mask_generation .run_amg_inference (
110+ predictor , test_image_paths , embedding_dir , prediction_dir ,
111+ amg_generate_kwargs = best_settings ,
112+ )
113+
114+ pred_paths = sorted (glob (os .path .join (prediction_dir , "*.tif" )))
115+ result_path = os .path .join (result_dir , "amg.csv" )
116+ os .makedirs (Path (result_path ).parent , exist_ok = True )
117+
118+ result = evaluation .run_evaluation (test_gt_paths , pred_paths , result_path )
119+ result .insert (0 , "setting" , ["amg" ])
120+ results .append (result )
79121
80122 results = pd .concat (results )
81123 results .insert (0 , "dataset" , [dataset ] * results .shape [0 ])
@@ -85,9 +127,10 @@ def evaluate_checkpoint_for_dataset(
85127def evaluate_checkpoint_for_datasets (
86128 checkpoint , model_type , experiment_root , datasets ,
87129 run_default_evaluation , run_amg , predictor = None ,
130+ max_num_val_images = None ,
88131):
89132 if predictor is None :
90- predictor = inference . get_predictor (checkpoint , model_type )
133+ predictor = get_generalist_predictor (checkpoint , model_type )
91134
92135 results = []
93136 for dataset in datasets :
@@ -97,6 +140,7 @@ def evaluate_checkpoint_for_datasets(
97140 None , None , dataset , experiment_folder ,
98141 run_default_evaluation = run_default_evaluation ,
99142 run_amg = run_amg , predictor = predictor ,
143+ max_num_val_images = max_num_val_images ,
100144 )
101145 results .append (result )
102146
0 commit comments