Skip to content

Commit c44d33d

Browse files
Update generalist evaluation
1 parent 09b21bd commit c44d33d

File tree

3 files changed

+41
-17
lines changed

3 files changed

+41
-17
lines changed

finetuning/generalists/evaluate_generalist.py

Lines changed: 40 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -3,31 +3,36 @@
33
from subprocess import run
44

55
from util import evaluate_checkpoint_for_dataset, ALL_DATASETS, EM_DATASETS, LM_DATASETS
6+
from micro_sam.evaluation import default_experiment_settings, get_experiment_setting_name
67

78
# Root directory under which all generalist evaluation results are written
# (one sub-folder per model name, then per dataset).
EXPERIMENT_ROOT = "/scratch/projects/nim00007/sam/experiments/generalists"

# Maps a model identifier (as passed on the command line) to the checkpoint
# path that is evaluated for it.  "vit_b"/"vit_h" are the unmodified SAM
# checkpoints; the "_lm"/"_em" entries are finetuned generalist checkpoints
# for light / electron microscopy.
CHECKPOINTS = {
    # Vanilla models
    "vit_b": "/home/nimcpape/.sam_models/sam_vit_b_01ec64.pth",
    "vit_h": "/home/nimcpape/.sam_models/sam_vit_h_4b8939.pth",
    # Generalist LM models
    "vit_b_lm": "/scratch-grete/projects/nim00007/sam/models/LM/generalist/v2/vit_b/best.pt",
    "vit_h_lm": "/scratch-grete/projects/nim00007/sam/models/LM/generalist/v2/vit_h/best.pt",
    # Generalist EM models
    "vit_b_em": "/scratch-grete/projects/nim00007/sam/models/EM/generalist/v2/vit_b/best.pt",
    "vit_h_em": "/scratch-grete/projects/nim00007/sam/models/EM/generalist/v2/vit_h/best.pt",
}
1220

1321

14-
def submit_array_job(model_name, datasets):
    """Submit a slurm array job evaluating `model_name` on all `datasets`.

    One array task per dataset is launched through
    `evaluate_generalist.sbatch`; the dataset list is forwarded via
    `--datasets` and indexed by SLURM_ARRAY_TASK_ID in the worker.
    """
    array_range = f"0-{len(datasets) - 1}"
    command = [
        "sbatch", "-a", array_range,
        "evaluate_generalist.sbatch", model_name, "--datasets",
        *datasets,
    ]
    run(command)
2127

2228

23-
def evaluate_dataset_slurm(model_name, dataset, run_amg):
24-
max_num_val_images = None
25-
if run_amg:
26-
if dataset in EM_DATASETS:
27-
run_amg = False
28-
else:
29-
run_amg = True
30-
max_num_val_images = 100
29+
def evaluate_dataset_slurm(model_name, dataset):
30+
if dataset in EM_DATASETS:
31+
run_amg = False
32+
max_num_val_images = None
33+
else:
34+
run_amg = True
35+
max_num_val_images = 64
3136

3237
is_custom_model = model_name not in ("vit_h", "vit_b")
3338
checkpoint = CHECKPOINTS[model_name]
@@ -52,13 +57,29 @@ def _get_datasets(lm, em):
5257
return datasets
5358

5459

60+
def check_computation(model_name, datasets):
    """Check that all expected result files exist for `model_name`.

    For every dataset one result csv per default prompt setting is
    expected under `<EXPERIMENT_ROOT>/<model_name>/<dataset>/results/`;
    for LM datasets an additional `amg.csv` is expected (AMG is skipped
    for EM datasets, see `evaluate_dataset_slurm`).  Every missing file
    is reported on stdout, followed by a summary line.
    """
    prompt_settings = default_experiment_settings()
    n_missing = 0
    for ds in datasets:
        experiment_folder = os.path.join(EXPERIMENT_ROOT, model_name, ds)
        # One result file per prompt-based experiment setting.
        for setting in prompt_settings:
            setting_name = get_experiment_setting_name(setting)
            expected_path = os.path.join(experiment_folder, "results", f"{setting_name}.csv")
            if not os.path.exists(expected_path):
                print("Missing results for:", expected_path)
                n_missing += 1
        # AMG results are only computed for LM datasets.
        if ds in LM_DATASETS:
            expected_path = os.path.join(experiment_folder, "results", "amg.csv")
            if not os.path.exists(expected_path):
                print("Missing results for:", expected_path)
                n_missing += 1
    # Fix: the original unconditionally printed the garbled message
    # "All checks_run", suggesting success even when files were missing.
    # Report an accurate summary instead.
    if n_missing == 0:
        print("All checks passed")
    else:
        print(f"Checks done: {n_missing} result file(s) missing")
74+
75+
5576
# evaluation on slurm
5677
def main():
5778
parser = argparse.ArgumentParser()
5879
parser.add_argument("model_name")
80+
parser.add_argument("--check", "-c", action="store_true")
5981
parser.add_argument("--lm", action="store_true")
6082
parser.add_argument("--em", action="store_true")
61-
parser.add_argument("--amg", action="store_true")
6283
parser.add_argument("--datasets", nargs="+")
6384
args = parser.parse_args()
6485

@@ -67,12 +88,16 @@ def main():
6788
datasets = _get_datasets(args.lm, args.em)
6889
assert all(ds in ALL_DATASETS for ds in datasets)
6990

91+
if args.check:
92+
check_computation(args.model_name, datasets)
93+
return
94+
7095
job_id = os.environ.get("SLURM_ARRAY_TASK_ID", None)
7196
if job_id is None: # this is the main script that submits slurm jobs
72-
submit_array_job(args.model_name, datasets, args.amg)
73-
else: # we're in a slurm job and precompute a setting
97+
submit_array_job(args.model_name, datasets)
98+
else: # we're in a slurm job
7499
job_id = int(job_id)
75-
evaluate_dataset_slurm(args.model_name, datasets[job_id], args.amg)
100+
evaluate_dataset_slurm(args.model_name, datasets[job_id])
76101

77102

78103
if __name__ == "__main__":

finetuning/generalists/evaluate_generalist.sbatch

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
#! /bin/bash
22
#SBATCH -c 4
33
#SBATCH --mem 48G
4-
#SBATCH -t 720
4+
#SBATCH -t 2800
55
#SBATCH -p grete:shared
66
#SBATCH -G A100:1
77
#SBATCH -A nim00007

finetuning/generalists/util.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,6 @@ def get_generalist_predictor(checkpoint, model_type, is_custom_model, return_sta
7878
)
7979

8080

81-
# TODO use model comparison func to generate the image data for qualitative comp
8281
def evaluate_checkpoint_for_dataset(
8382
checkpoint, model_type, dataset, experiment_folder,
8483
run_default_evaluation, run_amg, is_custom_model,

0 commit comments

Comments
 (0)