Skip to content

Commit 2c3dfd7

Browse files
Extend generalist evaluation and fix some issues in evaluation logic
1 parent b70f0ae commit 2c3dfd7

File tree

5 files changed

+103
-6
lines changed

5 files changed

+103
-6
lines changed
Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
import os
2+
from glob import glob
3+
4+
import pandas as pd
5+
6+
from evaluate_generalist import EXPERIMENT_ROOT
7+
from util import EM_DATASETS, LM_DATASETS
8+
9+
10+
def get_results(model, ds):
11+
res_folder = os.path.join(EXPERIMENT_ROOT, model, ds, "results")
12+
res_paths = sorted(glob(os.path.join(res_folder, "box", "*.csv"))) +\
13+
sorted(glob(os.path.join(res_folder, "points", "*.csv")))
14+
15+
amg_res = os.path.join(res_folder, "amg.csv")
16+
if os.path.exists(amg_res):
17+
res_paths.append(amg_res)
18+
19+
results = []
20+
for path in res_paths:
21+
prompt_res = pd.read_csv(path)
22+
prompt_name = os.path.splitext(os.path.relpath(path, res_folder))[0]
23+
prompt_res.insert(0, "prompt", [prompt_name])
24+
results.append(prompt_res)
25+
results = pd.concat(results)
26+
results.insert(0, "dataset", results.shape[0] * [ds])
27+
28+
return results
29+
30+
31+
def compile_results(models, datasets, out_path):
32+
results = []
33+
34+
for model in models:
35+
model_results = []
36+
37+
for ds in datasets:
38+
ds_results = get_results(model, ds)
39+
model_results.append(ds_results)
40+
41+
model_results = pd.concat(model_results)
42+
model_results.insert(0, "model", [model] * model_results.shape[0])
43+
results.append(model_results)
44+
45+
results = pd.concat(results)
46+
results.to_csv(out_path, index=False)
47+
48+
49+
def compile_em():
50+
compile_results(
51+
["vit_h", "vit_h_em", "vit_b", "vit_b_em"],
52+
EM_DATASETS,
53+
os.path.join(EXPERIMENT_ROOT, "evaluation-em.csv")
54+
)
55+
56+
57+
# TODO
58+
def compile_lm():
59+
compile_results(
60+
["vit_h", "vit_h_lm", "vit_b", "vit_b_lm"],
61+
LM_DATASETS,
62+
os.path.join(EXPERIMENT_ROOT, "evaluation-lm.csv")
63+
)
64+
65+
66+
def main():
67+
compile_em()
68+
69+
70+
if __name__ == "__main__":
71+
main()

finetuning/generalists/evaluate_generalist.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,11 @@
1111
"vit_b": "/home/nimcpape/.sam_models/sam_vit_b_01ec64.pth",
1212
"vit_h": "/home/nimcpape/.sam_models/sam_vit_h_4b8939.pth",
1313
# Generalist LM models
14-
"vit_b_lm": "/scratch-grete/projects/nim00007/sam/models/LM/generalist/v2/vit_b/best.pt",
15-
"vit_h_lm": "/scratch-grete/projects/nim00007/sam/models/LM/generalist/v2/vit_h/best.pt",
14+
"vit_b_lm": "/scratch/projects/nim00007/sam/models/LM/generalist/v2/vit_b/best.pt",
15+
"vit_h_lm": "/scratch/projects/nim00007/sam/models/LM/generalist/v2/vit_h/best.pt",
1616
# Generalist EM models
17-
"vit_b_em": "/scratch-grete/projects/nim00007/sam/models/EM/generalist/v2/vit_b/best.pt",
18-
"vit_h_em": "/scratch-grete/projects/nim00007/sam/models/EM/generalist/v2/vit_h/best.pt",
17+
"vit_b_em": "/scratch/projects/nim00007/sam/models/EM/generalist/v2/vit_b/best.pt",
18+
"vit_h_em": "/scratch/projects/nim00007/sam/models/EM/generalist/v2/vit_h/best.pt",
1919
}
2020

2121

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
import os
2+
from micro_sam.util import export_custom_sam_model
3+
from evaluate_generalist import CHECKPOINTS, EXPERIMENT_ROOT
4+
5+
OUT_ROOT = os.path.join(EXPERIMENT_ROOT, "exported")
6+
os.makedirs(OUT_ROOT, exist_ok=True)
7+
8+
9+
def export_generalist(model):
10+
checkpoint_path = CHECKPOINTS[model]
11+
model_type = model[:5]
12+
save_path = os.path.join(OUT_ROOT, f"{model}.pth")
13+
export_custom_sam_model(checkpoint_path, model_type, save_path)
14+
15+
16+
def main():
17+
export_generalist("vit_b_em")
18+
export_generalist("vit_h_em")
19+
export_generalist("vit_b_lm")
20+
export_generalist("vit_h_lm")
21+
22+
23+
if __name__ == "__main__":
24+
main()

micro_sam/evaluation/evaluation.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import pandas as pd
88

99
from elf.evaluation import mean_segmentation_accuracy
10+
from skimage.measure import label
1011
from tqdm import tqdm
1112

1213

@@ -21,6 +22,7 @@ def _run_evaluation(gt_paths, prediction_paths, verbose=True):
2122
assert os.path.exists(pred_path), pred_path
2223

2324
gt = imageio.imread(gt_path)
25+
gt = label(gt)
2426
pred = imageio.imread(pred_path)
2527

2628
msa, scores = mean_segmentation_accuracy(pred, gt, return_accuracies=True)

micro_sam/util.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -501,8 +501,8 @@ def precompute_image_embeddings(
501501
# key signature does not match or is not in the file
502502
if key not in f.attrs or f.attrs[key] != val:
503503
warnings.warn(
504-
f"Embeddings file {save_path} is invalid due to unmatching {key}. "
505-
"Please recompute embeddings in a new file."
504+
f"Embeddings file {save_path} is invalid due to unmatching {key}: "
505+
f"{f.atrs[key]} != {val}.Please recompute embeddings in a new file."
506506
)
507507
if wrong_file_callback is not None:
508508
save_path = wrong_file_callback(save_path)

0 commit comments

Comments
 (0)