-import argparse
 import json
 import os
-import pickle
 import warnings

 from glob import glob
 from pathlib import Path
-from tqdm import tqdm

 import pandas as pd
 from micro_sam.evaluation import (
@@ -67,68 +64,25 @@ def get_data_paths(dataset, split, max_num_images=None):
     return image_paths, gt_paths


-def _check_prompts(dataset, settings, expected_len):
-    prompt_folder = os.path.join(PROMPT_ROOT, dataset)
-
-    def check_prompt_file(prompt_file):
-        assert os.path.exists(prompt_file), prompt_file
-        with open(prompt_file, "rb") as f:
-            prompts = pickle.load(f)
-        assert len(prompts) == expected_len, f"{len(prompts)}, {expected_len}"
-
-    for setting in settings:
-        pos, neg = setting["n_positives"], setting["n_negatives"]
-        prompt_file = os.path.join(prompt_folder, f"points-p{pos}-n{neg}.pkl")
-        if pos == 0 and neg == 0:
-            prompt_file = os.path.join(prompt_folder, "boxes.pkl")
-        check_prompt_file(prompt_file)
-
-    print("All files checked!")
-
-
-def check_all_datasets(check_prompts=False):
-
-    def check_dataset(dataset):
-        try:
-            images, _ = get_data_paths(dataset, "test")
-        except AssertionError as e:
-            print("Checking test split failed for datasset", dataset, "due to", e)
-
-        if dataset not in LM_DATASETS:
-            return len(images)
-
-        try:
-            get_data_paths(dataset, "val")
-        except AssertionError as e:
-            print("Checking val split failed for datasset", dataset, "due to", e)
-
-        return len(images)
-
-    settings = default_experiment_settings()
-    for ds in tqdm(ALL_DATASETS, desc="Checking datasets"):
-        n_images = check_dataset(ds)
-        if check_prompts:
-            _check_prompts(ds, settings, n_images)
-    print("All checks done!")
-
-
 ###
 # Evaluation functionality
 ###


-def get_generalist_predictor(checkpoint, model_type, return_state=False):
+def get_generalist_predictor(checkpoint, model_type, is_custom_model, return_state=False):
     with warnings.catch_warnings():
         warnings.simplefilter("ignore")
         return inference.get_predictor(
-            checkpoint, model_type=model_type, return_state=return_state, is_custom_model=True
+            checkpoint, model_type=model_type,
+            return_state=return_state, is_custom_model=is_custom_model
         )


+# TODO use model comparison func to generate the image data for qualitative comp
 def evaluate_checkpoint_for_dataset(
     checkpoint, model_type, dataset, experiment_folder,
-    run_default_evaluation, run_amg, predictor=None,
-    max_num_val_images=None,
+    run_default_evaluation, run_amg, is_custom_model,
+    predictor=None, max_num_val_images=None,
 ):
     """Evaluate a generalist checkpoint for a given dataset.
     """
@@ -137,7 +91,7 @@ def evaluate_checkpoint_for_dataset(
     prompt_dir = os.path.join(PROMPT_ROOT, dataset)

     if predictor is None:
-        predictor = get_generalist_predictor(checkpoint, model_type)
+        predictor = get_generalist_predictor(checkpoint, model_type, is_custom_model)
     test_image_paths, test_gt_paths = get_data_paths(dataset, "test")

     embedding_dir = os.path.join(experiment_folder, "test", "embeddings")
@@ -208,11 +162,11 @@ def evaluate_checkpoint_for_dataset(

 def evaluate_checkpoint_for_datasets(
     checkpoint, model_type, experiment_root, datasets,
-    run_default_evaluation, run_amg, predictor=None,
-    max_num_val_images=None,
+    run_default_evaluation, run_amg, is_custom_model,
+    predictor=None, max_num_val_images=None,
 ):
     if predictor is None:
-        predictor = get_generalist_predictor(checkpoint, model_type)
+        predictor = get_generalist_predictor(checkpoint, model_type, is_custom_model)

     results = []
     for dataset in datasets:
@@ -221,23 +175,9 @@ def evaluate_checkpoint_for_datasets(
         result = evaluate_checkpoint_for_dataset(
             None, None, dataset, experiment_folder,
             run_default_evaluation=run_default_evaluation,
-            run_amg=run_amg, predictor=predictor,
-            max_num_val_images=max_num_val_images,
+            run_amg=run_amg, is_custom_model=is_custom_model,
+            predictor=predictor, max_num_val_images=max_num_val_images,
         )
         results.append(result)

     return pd.concat(results)
-
-
-def evaluate_checkpoint_for_datasets_slurm(
-    checkpoint, model_type, experiment_root, datasets,
-    run_default_evaluation, run_amg,
-):
-    raise NotImplementedError
-
-
-if __name__ == "__main__":
-    parser = argparse.ArgumentParser()
-    parser.add_argument("--check_prompts", "-c", action="store_true")
-    args = parser.parse_args()
-    check_all_datasets(args.check_prompts)
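With this change `is_custom_model` is threaded through the whole call chain instead of being hard-coded to `True`, which presumably lets the same evaluation code also score the vanilla SAM weights as a baseline. A minimal usage sketch of the refactored entry point follows; the module name, checkpoint path, and dataset names are illustrative assumptions, not taken from this commit.

# Usage sketch: evaluate a fine-tuned generalist checkpoint on two datasets
# and collect the per-dataset results in a single DataFrame.
from evaluate_generalist import evaluate_checkpoint_for_datasets  # assumed module name

results = evaluate_checkpoint_for_datasets(
    checkpoint="./checkpoints/lm_generalist/best.pt",  # hypothetical path
    model_type="vit_b",
    experiment_root="./experiments/lm_generalist",     # hypothetical path
    datasets=["dsb", "livecell"],                       # hypothetical dataset names
    run_default_evaluation=True,
    run_amg=False,
    is_custom_model=True,  # False would load the vanilla SAM weights instead
)
print(results)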