from subprocess import run

from util import evaluate_checkpoint_for_dataset, ALL_DATASETS, EM_DATASETS, LM_DATASETS
+from micro_sam.evaluation import default_experiment_settings, get_experiment_setting_name

EXPERIMENT_ROOT = "/scratch/projects/nim00007/sam/experiments/generalists"
CHECKPOINTS = {
+    # Vanilla models
    "vit_b": "/home/nimcpape/.sam_models/sam_vit_b_01ec64.pth",
    "vit_h": "/home/nimcpape/.sam_models/sam_vit_h_4b8939.pth",
+    # Generalist LM models
+    "vit_b_lm": "/scratch-grete/projects/nim00007/sam/models/LM/generalist/v2/vit_b/best.pt",
+    "vit_h_lm": "/scratch-grete/projects/nim00007/sam/models/LM/generalist/v2/vit_h/best.pt",
+    # Generalist EM models
+    "vit_b_em": "/scratch-grete/projects/nim00007/sam/models/EM/generalist/v2/vit_b/best.pt",
+    "vit_h_em": "/scratch-grete/projects/nim00007/sam/models/EM/generalist/v2/vit_h/best.pt",
}


-def submit_array_job(model_name, datasets, amg):
+def submit_array_job(model_name, datasets):
    n_datasets = len(datasets)
    cmd = ["sbatch", "-a", f"0-{n_datasets-1}", "evaluate_generalist.sbatch", model_name, "--datasets"]
    cmd.extend(datasets)
-    if amg:
-        cmd.append("--amg")
    run(cmd)


-def evaluate_dataset_slurm(model_name, dataset, run_amg):
-    max_num_val_images = None
-    if run_amg:
-        if dataset in EM_DATASETS:
-            run_amg = False
-        else:
-            run_amg = True
-            max_num_val_images = 100
+def evaluate_dataset_slurm(model_name, dataset):
+    if dataset in EM_DATASETS:
+        run_amg = False
+        max_num_val_images = None
+    else:
+        run_amg = True
+        max_num_val_images = 64

    is_custom_model = model_name not in ("vit_h", "vit_b")
    checkpoint = CHECKPOINTS[model_name]
@@ -52,13 +57,29 @@ def _get_datasets(lm, em):
    return datasets


+def check_computation(model_name, datasets):
+    prompt_settings = default_experiment_settings()
+    for ds in datasets:
+        experiment_folder = os.path.join(EXPERIMENT_ROOT, model_name, ds)
+        for setting in prompt_settings:
+            setting_name = get_experiment_setting_name(setting)
+            expected_path = os.path.join(experiment_folder, "results", f"{setting_name}.csv")
+            if not os.path.exists(expected_path):
+                print("Missing results for:", expected_path)
+        if ds in LM_DATASETS:
+            expected_path = os.path.join(experiment_folder, "results", "amg.csv")
+            if not os.path.exists(expected_path):
+                print("Missing results for:", expected_path)
+    print("All checks run")
+
+
# evaluation on slurm
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("model_name")
+    parser.add_argument("--check", "-c", action="store_true")
    parser.add_argument("--lm", action="store_true")
    parser.add_argument("--em", action="store_true")
-    parser.add_argument("--amg", action="store_true")
    parser.add_argument("--datasets", nargs="+")
    args = parser.parse_args()

@@ -67,12 +88,16 @@ def main():
    datasets = _get_datasets(args.lm, args.em)
    assert all(ds in ALL_DATASETS for ds in datasets)

+    if args.check:
+        check_computation(args.model_name, datasets)
+        return
+
    job_id = os.environ.get("SLURM_ARRAY_TASK_ID", None)
    if job_id is None:  # this is the main script that submits slurm jobs
-        submit_array_job(args.model_name, datasets, args.amg)
-    else:  # we're in a slurm job and precompute a setting
+        submit_array_job(args.model_name, datasets)
+    else:  # we're in a slurm job
        job_id = int(job_id)
-        evaluate_dataset_slurm(args.model_name, datasets[job_id], args.amg)
+        evaluate_dataset_slurm(args.model_name, datasets[job_id])


if __name__ == "__main__":
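
Usage sketch (hedged): assuming this script is saved next to the sbatch file as evaluate_generalist.py (the Python file name is not shown in this diff; only the companion evaluate_generalist.sbatch is), the reworked workflow would be launched and then verified roughly like this:

    # submit one slurm array task per LM dataset for the vit_b_lm generalist;
    # internally this builds: sbatch -a 0-<n_datasets-1> evaluate_generalist.sbatch vit_b_lm --datasets <dataset_0> ...
    python evaluate_generalist.py vit_b_lm --lm

    # after the array jobs finish, check that every expected results CSV exists,
    # without submitting anything (the new --check/-c flag returns before job submission)
    python evaluate_generalist.py vit_b_lm --lm --check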