
Commit 930eca7

Update generalist evaluation scripts
1 parent fc6d8db commit 930eca7

5 files changed: +149 additions, -90 deletions
finetuning/generalists/evaluate_generalist.py
Lines changed: 75 additions & 14 deletions

@@ -1,18 +1,79 @@
-from util import evaluate_checkpoint_for_datasets
+import argparse
+import os
+from subprocess import run

+from util import evaluate_checkpoint_for_dataset, ALL_DATASETS, EM_DATASETS, LM_DATASETS

-# TODO extend this to run the full evaluation protocol for a generalist.
+EXPERIMENT_ROOT = "/scratch/projects/nim00007/sam/experiments/generalists"
+CHECKPOINTS = {
+    "vit_b": "/home/nimcpape/.sam_models/sam_vit_b_01ec64.pth",
+    "vit_h": "/home/nimcpape/.sam_models/sam_vit_h_4b8939.pth",
+}

-checkpoint = "/scratch-grete/projects/nim00007/sam/LM/generalist/vit_b/epoch-30.pt"
-root = "/scratch-grete/projects/nim00007/sam/experiments/generalists/lm/test"
-datasets = ["covid-if"]

-evaluate_checkpoint_for_datasets(
-    checkpoint=checkpoint,
-    model_type="vit_b",
-    experiment_root=root,
-    datasets=datasets,
-    run_default_evaluation=True,
-    run_amg=True,
-    max_num_val_images=10,
-)
+def submit_array_job(model_name, datasets, amg):
+    n_datasets = len(datasets)
+    cmd = ["sbatch", "-a", f"0-{n_datasets-1}", "evaluate_generalist.sbatch", model_name, "--datasets"]
+    cmd.extend(datasets)
+    if amg:
+        cmd.append("--amg")
+    run(cmd)
+
+
+def evaluate_dataset_slurm(model_name, dataset, run_amg):
+    max_num_val_images = None
+    if run_amg:
+        if dataset in EM_DATASETS:
+            run_amg = False
+        else:
+            run_amg = True
+            max_num_val_images = 100
+
+    is_custom_model = model_name not in ("vit_h", "vit_b")
+    checkpoint = CHECKPOINTS[model_name]
+    model_type = model_name[:5]
+
+    experiment_folder = os.path.join(EXPERIMENT_ROOT, model_name, dataset)
+    evaluate_checkpoint_for_dataset(
+        checkpoint, model_type, dataset, experiment_folder,
+        run_default_evaluation=True, run_amg=run_amg,
+        is_custom_model=is_custom_model,
+        max_num_val_images=max_num_val_images,
+    )
+
+
+def _get_datasets(lm, em):
+    assert lm or em
+    datasets = []
+    if lm:
+        datasets.extend(LM_DATASETS)
+    if em:
+        datasets.extend(EM_DATASETS)
+    return datasets
+
+
+# evaluation on slurm
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("model_name")
+    parser.add_argument("--lm", action="store_true")
+    parser.add_argument("--em", action="store_true")
+    parser.add_argument("--amg", action="store_true")
+    parser.add_argument("--datasets", nargs="+")
+    args = parser.parse_args()
+
+    datasets = args.datasets
+    if datasets is None or len(datasets) == 0:
+        datasets = _get_datasets(args.lm, args.em)
+    assert all(ds in ALL_DATASETS for ds in datasets)
+
+    job_id = os.environ.get("SLURM_ARRAY_TASK_ID", None)
+    if job_id is None:  # this is the main script that submits slurm jobs
+        submit_array_job(args.model_name, datasets, args.amg)
+    else:  # we're in a slurm job and precompute a setting
+        job_id = int(job_id)
+        evaluate_dataset_slurm(args.model_name, datasets[job_id], args.amg)
+
+
+if __name__ == "__main__":
+    main()
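
For orientation, a minimal usage sketch of the new command-line interface; the invocations are only illustrative and are inferred from the argparse flags defined above, and the dataset name "covid-if" is taken from the old hard-coded example.

# Minimal usage sketch (illustrative, not part of the commit); assumes the "sam" conda
# environment and the cluster paths hard-coded above are available.
from subprocess import run

# submit one slurm array task per light-microscopy dataset, with AMG evaluation enabled
run(["python", "evaluate_generalist.py", "vit_b", "--lm", "--amg"], check=True)

# evaluate the vit_h checkpoint on a hand-picked dataset only
run(["python", "evaluate_generalist.py", "vit_h", "--datasets", "covid-if"], check=True)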
finetuning/generalists/evaluate_generalist.sbatch
Lines changed: 10 additions & 0 deletions

@@ -0,0 +1,10 @@
+#! /bin/bash
+#SBATCH -c 4
+#SBATCH --mem 48G
+#SBATCH -t 720
+#SBATCH -p grete:shared
+#SBATCH -G A100:1
+#SBATCH -A nim00007
+
+source activate sam
+python evaluate_generalist.py $@

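Together, evaluate_generalist.py and this batch file implement the usual Slurm array-job dispatch: the same Python entry point first runs outside of Slurm, where it submits itself via sbatch -a 0-{N-1}, and then once per array task, where it picks its work item from SLURM_ARRAY_TASK_ID. A self-contained sketch of that pattern, with a made-up work list and batch file name:

# Minimal sketch of the Slurm array-job dispatch pattern used above;
# the work list and the batch file name are hypothetical.
import os
from subprocess import run

WORK_ITEMS = ["dataset-a", "dataset-b", "dataset-c"]  # hypothetical work items

task_id = os.environ.get("SLURM_ARRAY_TASK_ID")
if task_id is None:
    # not running under Slurm yet: submit one array task per work item (indices 0..N-1)
    run(["sbatch", "-a", f"0-{len(WORK_ITEMS) - 1}", "my_job.sbatch"], check=True)
else:
    # inside array task i: process exactly WORK_ITEMS[i]
    print("processing", WORK_ITEMS[int(task_id)])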
finetuning/generalists/precompute_prompts.py
Lines changed: 51 additions & 3 deletions

@@ -1,10 +1,12 @@
 import argparse
 import os
+import pickle

 from subprocess import run

 import micro_sam.evaluation as evaluation
-from util import get_data_paths, ALL_DATASETS
+from util import get_data_paths, ALL_DATASETS, LM_DATASETS
+from tqdm import tqdm

 PROMPT_ROOT = "/scratch/projects/nim00007/sam/experiments/prompts"

@@ -32,17 +34,63 @@ def submit_array_job():
     run(cmd)


+def _check_prompts(dataset, settings, expected_len):
+    prompt_folder = os.path.join(PROMPT_ROOT, dataset)
+
+    def check_prompt_file(prompt_file):
+        assert os.path.exists(prompt_file), prompt_file
+        with open(prompt_file, "rb") as f:
+            prompts = pickle.load(f)
+        assert len(prompts) == expected_len, f"{len(prompts)}, {expected_len}"
+
+    for setting in settings:
+        pos, neg = setting["n_positives"], setting["n_negatives"]
+        prompt_file = os.path.join(prompt_folder, f"points-p{pos}-n{neg}.pkl")
+        if pos == 0 and neg == 0:
+            prompt_file = os.path.join(prompt_folder, "boxes.pkl")
+        check_prompt_file(prompt_file)
+
+
+def check_prompts_and_datasets():
+
+    def check_dataset(dataset):
+        try:
+            images, _ = get_data_paths(dataset, "test")
+        except AssertionError as e:
+            print("Checking test split failed for datasset", dataset, "due to", e)
+
+        if dataset not in LM_DATASETS:
+            return len(images)
+
+        try:
+            get_data_paths(dataset, "val")
+        except AssertionError as e:
+            print("Checking val split failed for datasset", dataset, "due to", e)
+
+        return len(images)
+
+    settings = evaluation.default_experiment_settings()
+    for ds in tqdm(ALL_DATASETS, desc="Checking datasets"):
+        n_images = check_dataset(ds)
+        _check_prompts(ds, settings, n_images)
+    print("All checks done!")
+
+
 def main():
     parser = argparse.ArgumentParser()
     parser.add_argument("-d", "--dataset")
+    parser.add_argument("--check", "-c", action="store_true")
     args = parser.parse_args()
+
+    if args.check:
+        check_prompts_and_datasets()
+        return
+
     if args.dataset is not None:
         precompute_prompts(args.dataset)
         return

-    # this will fail if the dataset is invalid
     job_id = os.environ.get("SLURM_ARRAY_TASK_ID", None)
-
     if job_id is None:  # this is the main script that submits slurm jobs
         submit_array_job()
     else:  # we're in a slurm job and precompute a setting

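The new --check flag relies on the file layout that precompute_prompts.py writes below PROMPT_ROOT: one pickled list per dataset and prompt setting, named points-p{n_positives}-n{n_negatives}.pkl for point prompts and boxes.pkl when both counts are zero, with one entry per test image. A small sketch of inspecting such a file; the folder and the p1-n0 setting are placeholders:

# Sketch for inspecting a single precomputed prompt file; paths and setting are placeholders.
import os
import pickle

prompt_folder = "/path/to/prompts/some-dataset"  # in the repo this would be PROMPT_ROOT/<dataset>
prompt_file = os.path.join(prompt_folder, "points-p1-n0.pkl")  # one positive point, no negatives

with open(prompt_file, "rb") as f:
    prompts = pickle.load(f)

# _check_prompts asserts that this matches the number of test images of the dataset
print(len(prompts), "prompt entries")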
finetuning/generalists/precompute_prompts.sbatch
Lines changed: 1 addition & 1 deletion

@@ -1,7 +1,7 @@
 #! /bin/bash
 #SBATCH -c 4
 #SBATCH --mem 48G
-#SBATCH -t 720
+#SBATCH -t 2000
 #SBATCH -p grete:shared
 #SBATCH -G A100:1
 #SBATCH -A nim00007

finetuning/generalists/util.py
Lines changed: 12 additions & 72 deletions

@@ -1,12 +1,9 @@
-import argparse
 import json
 import os
-import pickle
 import warnings

 from glob import glob
 from pathlib import Path
-from tqdm import tqdm

 import pandas as pd
 from micro_sam.evaluation import (
@@ -67,68 +64,25 @@ def get_data_paths(dataset, split, max_num_images=None):
     return image_paths, gt_paths


-def _check_prompts(dataset, settings, expected_len):
-    prompt_folder = os.path.join(PROMPT_ROOT, dataset)
-
-    def check_prompt_file(prompt_file):
-        assert os.path.exists(prompt_file), prompt_file
-        with open(prompt_file, "rb") as f:
-            prompts = pickle.load(f)
-        assert len(prompts) == expected_len, f"{len(prompts)}, {expected_len}"
-
-    for setting in settings:
-        pos, neg = setting["n_positives"], setting["n_negatives"]
-        prompt_file = os.path.join(prompt_folder, f"points-p{pos}-n{neg}.pkl")
-        if pos == 0 and neg == 0:
-            prompt_file = os.path.join(prompt_folder, "boxes.pkl")
-        check_prompt_file(prompt_file)
-
-    print("All files checked!")
-
-
-def check_all_datasets(check_prompts=False):
-
-    def check_dataset(dataset):
-        try:
-            images, _ = get_data_paths(dataset, "test")
-        except AssertionError as e:
-            print("Checking test split failed for datasset", dataset, "due to", e)
-
-        if dataset not in LM_DATASETS:
-            return len(images)
-
-        try:
-            get_data_paths(dataset, "val")
-        except AssertionError as e:
-            print("Checking val split failed for datasset", dataset, "due to", e)
-
-        return len(images)
-
-    settings = default_experiment_settings()
-    for ds in tqdm(ALL_DATASETS, desc="Checking datasets"):
-        n_images = check_dataset(ds)
-        if check_prompts:
-            _check_prompts(ds, settings, n_images)
-    print("All checks done!")
-
-
 ###
 # Evaluation functionality
 ###


-def get_generalist_predictor(checkpoint, model_type, return_state=False):
+def get_generalist_predictor(checkpoint, model_type, is_custom_model, return_state=False):
     with warnings.catch_warnings():
         warnings.simplefilter("ignore")
         return inference.get_predictor(
-            checkpoint, model_type=model_type, return_state=return_state, is_custom_model=True
+            checkpoint, model_type=model_type,
+            return_state=return_state, is_custom_model=is_custom_model
         )


+# TODO use model comparison func to generate the image data for qualitative comp
 def evaluate_checkpoint_for_dataset(
     checkpoint, model_type, dataset, experiment_folder,
-    run_default_evaluation, run_amg, predictor=None,
-    max_num_val_images=None,
+    run_default_evaluation, run_amg, is_custom_model,
+    predictor=None, max_num_val_images=None,
 ):
     """Evaluate a generalist checkpoint for a given dataset.
     """
@@ -137,7 +91,7 @@ def evaluate_checkpoint_for_dataset(
     prompt_dir = os.path.join(PROMPT_ROOT, dataset)

     if predictor is None:
-        predictor = get_generalist_predictor(checkpoint, model_type)
+        predictor = get_generalist_predictor(checkpoint, model_type, is_custom_model)
     test_image_paths, test_gt_paths = get_data_paths(dataset, "test")

     embedding_dir = os.path.join(experiment_folder, "test", "embeddings")
@@ -208,11 +162,11 @@ def evaluate_checkpoint_for_dataset(

 def evaluate_checkpoint_for_datasets(
     checkpoint, model_type, experiment_root, datasets,
-    run_default_evaluation, run_amg, predictor=None,
-    max_num_val_images=None,
+    run_default_evaluation, run_amg, is_custom_model,
+    predictor=None, max_num_val_images=None,
 ):
     if predictor is None:
-        predictor = get_generalist_predictor(checkpoint, model_type)
+        predictor = get_generalist_predictor(checkpoint, model_type, is_custom_model)

     results = []
     for dataset in datasets:
@@ -221,23 +175,9 @@ def evaluate_checkpoint_for_datasets(
         result = evaluate_checkpoint_for_dataset(
             None, None, dataset, experiment_folder,
             run_default_evaluation=run_default_evaluation,
-            run_amg=run_amg, predictor=predictor,
-            max_num_val_images=max_num_val_images,
+            run_amg=run_amg, is_custom_model=is_custom_model,
+            predictor=predictor, max_num_val_images=max_num_val_images,
         )
         results.append(result)

     return pd.concat(results)
-
-
-def evaluate_checkpoint_for_datasets_slurm(
-    checkpoint, model_type, experiment_root, datasets,
-    run_default_evaluation, run_amg,
-):
-    raise NotImplementedError
-
-
-if __name__ == "__main__":
-    parser = argparse.ArgumentParser()
-    parser.add_argument("--check_prompts", "-c", action="store_true")
-    args = parser.parse_args()
-    check_all_datasets(args.check_prompts)

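Since is_custom_model is now a required argument of get_generalist_predictor and both evaluate_* helpers, callers have to pass it explicitly. A hypothetical call of the updated evaluate_checkpoint_for_datasets; the checkpoint and experiment paths are placeholders:

# Hypothetical caller of the updated API; the paths below are placeholders.
from util import evaluate_checkpoint_for_datasets

results = evaluate_checkpoint_for_datasets(
    checkpoint="/path/to/generalist_checkpoint.pt",  # placeholder
    model_type="vit_b",
    experiment_root="/path/to/experiments",          # placeholder
    datasets=["covid-if"],
    run_default_evaluation=True,
    run_amg=False,
    is_custom_model=True,  # False when evaluating the vanilla SAM checkpoints
    max_num_val_images=100,
)
print(results)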