Skip to content

Commit b2715a7

Browse files
Update generalist training and fix some issues in embedding generation
1 parent c3c3845 commit b2715a7

File tree

6 files changed

+93
-28
lines changed

6 files changed

+93
-28
lines changed
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
from util import evaluate_checkpoint_for_datasets
2+
3+
4+
# TODO extend this to run the full evaluation protocol for a generalist.
5+
6+
checkpoint = "/scratch-grete/projects/nim00007/sam/LM/generalist/vit_b/epoch-30.pt"
7+
root = "/scratch-grete/projects/nim00007/sam/experiments/generalists/lm/test"
8+
datasets = ["covid-if"]
9+
10+
evaluate_checkpoint_for_datasets(
11+
checkpoint=checkpoint,
12+
model_type="vit_b",
13+
experiment_root=root,
14+
datasets=datasets,
15+
run_default_evaluation=True,
16+
run_amg=True,
17+
max_num_val_images=10,
18+
)

finetuning/generalists/lm/evaluate_training_evolution.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,9 @@
11
import argparse
22
import os
3-
import warnings
43
from glob import glob
54

65
import pandas as pd
7-
from micro_sam.util import get_custom_sam_model
8-
from util import evaluate_checkpoint_for_datasets
6+
from util import evaluate_checkpoint_for_datasets, get_generalist_predictor
97

108
CHECKPOINT_ROOT = "/scratch-grete/projects/nim00007/sam/LM/generalist"
119
EXPERIMENT_ROOT = "/scratch-grete/projects/nim00007/sam/experiments/generalists/lm"
@@ -26,9 +24,7 @@ def evaluate_training_evolution(model_type):
2624
epochs, results = [], []
2725
for checkpoint in checkpoints:
2826

29-
with warnings.catch_warnings():
30-
warnings.simplefilter("ignore")
31-
predictor, state = get_custom_sam_model(checkpoint, model_type=model_type, return_state=True)
27+
predictor, state = get_generalist_predictor(checkpoint, model_type, return_state=True)
3228
epoch = state["epoch"] + 1
3329

3430
if epoch in epochs:

finetuning/generalists/lm/util.py

Lines changed: 57 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
1+
import json
12
import os
3+
import warnings
24
from glob import glob
35
from pathlib import Path
46

57
import pandas as pd
68
from micro_sam.evaluation import (
7-
inference, evaluation,
9+
automatic_mask_generation, inference, evaluation,
810
default_experiment_settings, get_experiment_setting_name
911
)
1012

@@ -22,35 +24,46 @@
2224
)
2325

2426

25-
def get_data_paths(dataset, split):
26-
image_paths = sorted(glob(os.path.join(DATA_ROOT, dataset, split, "image_*.tif")))
27+
def get_generalist_predictor(checkpoint, model_type, return_state=False):
28+
with warnings.catch_warnings():
29+
warnings.simplefilter("ignore")
30+
return inference.get_predictor(
31+
checkpoint, model_type=model_type, return_state=return_state, is_custom_model=True
32+
)
33+
34+
35+
def get_data_paths(dataset, split, max_num_images=None):
36+
image_pattern = os.path.join(DATA_ROOT, dataset, split, "image_*.tif")
37+
image_paths = sorted(glob(image_pattern))
2738
gt_paths = sorted(glob(os.path.join(DATA_ROOT, dataset, split, "labels_*.tif")))
2839
assert len(image_paths) == len(gt_paths)
29-
assert len(image_paths) > 0
40+
assert len(image_paths) > 0, image_pattern
41+
if max_num_images is not None:
42+
image_paths, gt_paths = image_paths[:max_num_images], gt_paths[:max_num_images]
3043
return image_paths, gt_paths
3144

3245

3346
def evaluate_checkpoint_for_dataset(
3447
checkpoint, model_type, dataset, experiment_folder,
3548
run_default_evaluation, run_amg, predictor=None,
49+
max_num_val_images=None,
3650
):
37-
"""Evaluate a generalist checkpoint for a given dataset
51+
"""Evaluate a generalist checkpoint for a given dataset.
3852
"""
3953
assert run_default_evaluation or run_amg
4054

4155
prompt_dir = os.path.join(PROMPT_ROOT, dataset)
4256

4357
if predictor is None:
44-
predictor = inference.get_predictor(checkpoint, model_type)
58+
predictor = get_generalist_predictor(checkpoint, model_type)
4559
test_image_paths, test_gt_paths = get_data_paths(dataset, "test")
4660

61+
embedding_dir = os.path.join(experiment_folder, "test", "embeddings")
62+
os.makedirs(embedding_dir, exist_ok=True)
63+
result_dir = os.path.join(experiment_folder, "results")
64+
4765
results = []
4866
if run_default_evaluation:
49-
embedding_dir = os.path.join(experiment_folder, "test", "embeddings")
50-
os.makedirs(embedding_dir, exist_ok=True)
51-
52-
result_dir = os.path.join(experiment_folder, "results")
53-
5467
prompt_settings = default_experiment_settings()
5568
for setting in prompt_settings:
5669

@@ -75,7 +88,36 @@ def evaluate_checkpoint_for_dataset(
7588
results.append(result)
7689

7790
if run_amg:
78-
raise NotImplementedError
91+
val_embedding_dir = os.path.join(experiment_folder, "val", "embeddings")
92+
val_result_dir = os.path.join(experiment_folder, "val", "results")
93+
os.makedirs(val_embedding_dir, exist_ok=True)
94+
95+
val_image_paths, val_gt_paths = get_data_paths(dataset, "val", max_num_images=max_num_val_images)
96+
automatic_mask_generation.run_amg_grid_search(
97+
predictor, val_image_paths, val_gt_paths, val_embedding_dir,
98+
val_result_dir, verbose_gs=True,
99+
)
100+
101+
best_iou_thresh, best_stability_thresh, _ = automatic_mask_generation.evaluate_amg_grid_search(val_result_dir)
102+
best_settings = {"pred_iou_thresh": best_iou_thresh, "stability_score_thresh": best_stability_thresh}
103+
gs_result_path = os.path.join(experiment_folder, "best_gs_params.json")
104+
with open(gs_result_path, "w") as f:
105+
json.dump(best_settings, f)
106+
107+
prediction_dir = os.path.join(experiment_folder, "test", "amg")
108+
os.makedirs(prediction_dir, exist_ok=True)
109+
automatic_mask_generation.run_amg_inference(
110+
predictor, test_image_paths, embedding_dir, prediction_dir,
111+
amg_generate_kwargs=best_settings,
112+
)
113+
114+
pred_paths = sorted(glob(os.path.join(prediction_dir, "*.tif")))
115+
result_path = os.path.join(result_dir, "amg.csv")
116+
os.makedirs(Path(result_path).parent, exist_ok=True)
117+
118+
result = evaluation.run_evaluation(test_gt_paths, pred_paths, result_path)
119+
result.insert(0, "setting", ["amg"])
120+
results.append(result)
79121

80122
results = pd.concat(results)
81123
results.insert(0, "dataset", [dataset] * results.shape[0])
@@ -85,9 +127,10 @@ def evaluate_checkpoint_for_dataset(
85127
def evaluate_checkpoint_for_datasets(
86128
checkpoint, model_type, experiment_root, datasets,
87129
run_default_evaluation, run_amg, predictor=None,
130+
max_num_val_images=None,
88131
):
89132
if predictor is None:
90-
predictor = inference.get_predictor(checkpoint, model_type)
133+
predictor = get_generalist_predictor(checkpoint, model_type)
91134

92135
results = []
93136
for dataset in datasets:
@@ -97,6 +140,7 @@ def evaluate_checkpoint_for_datasets(
97140
None, None, dataset, experiment_folder,
98141
run_default_evaluation=run_default_evaluation,
99142
run_amg=run_amg, predictor=predictor,
143+
max_num_val_images=max_num_val_images,
100144
)
101145
results.append(result)
102146

micro_sam/evaluation/automatic_mask_generation.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ def run_amg_grid_search(
118118
image = imageio.imread(image_path)
119119
gt = imageio.imread(gt_path)
120120

121-
embedding_path = os.path.join(embedding_dir, f"{image_name[:-4]}.zarr")
121+
embedding_path = os.path.join(embedding_dir, f"{os.path.splitext(image_name)[0]}.zarr")
122122
image_embeddings = util.precompute_image_embeddings(predictor, image, embedding_path)
123123
amg.initialize(image, image_embeddings)
124124

@@ -166,7 +166,7 @@ def run_amg_inference(
166166
assert os.path.exists(image_path), image_path
167167
image = imageio.imread(image_path)
168168

169-
embedding_path = os.path.join(embedding_dir, f"{image_name[:-4]}.zarr")
169+
embedding_path = os.path.join(embedding_dir, f"{os.path.splitext(image_name)[0]}.zarr")
170170
image_embeddings = util.precompute_image_embeddings(predictor, image, embedding_path)
171171

172172
amg.initialize(image, image_embeddings)

micro_sam/evaluation/inference.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -178,19 +178,24 @@ def _run_inference_with_prompts_for_image(
178178
def get_predictor(
179179
checkpoint_path: Union[str, os.PathLike],
180180
model_type: str,
181-
return_state: bool = False
181+
return_state: bool = False,
182+
is_custom_model: Optional[bool] = None,
182183
) -> SamPredictor:
183184
"""Get the segment anything predictor from an exported or custom checkpoint.
184185
185186
Args:
186187
checkpoint_path: The checkpoint filepath.
187188
model_type: The type of the model, either vit_h, vit_b or vit_l.
188189
return_state: Whether to return the complete state of the checkpoint in addition to the predictor.
190+
is_custom_model: Whether this is a custom model or not. If None, this is inferred from whether the checkpoint file is named "best.pt".
189191
Returns:
190192
The segment anything predictor.
191193
"""
192-
# TODO use try-except rather than this construct, so that we don't rely on the checkpoint name
193-
if checkpoint_path.split("/")[-1] == "best.pt": # Finetuned SAM model
194+
# By default we check if the model follows the torch_em checkpoint naming scheme to check whether it is a
195+
# custom model or not. This can be overridden by passing True or False for is_custom_model.
196+
is_custom_model = checkpoint_path.split("/")[-1] == "best.pt" if is_custom_model is None else is_custom_model
197+
198+
if is_custom_model: # Finetuned SAM model
194199
predictor = util.get_custom_sam_model(
195200
checkpoint_path=checkpoint_path, model_type=model_type, return_state=return_state
196201
)
@@ -217,7 +222,7 @@ def precompute_all_embeddings(
217222
for image_path in tqdm(image_paths, desc="Precompute embeddings"):
218223
image_name = os.path.basename(image_path)
219224
im = imageio.imread(image_path)
220-
embedding_path = os.path.join(embedding_dir, f"{image_name[:-4]}.zarr")
225+
embedding_path = os.path.join(embedding_dir, f"{os.path.splitext(image_name)[0]}.zarr")
221226
util.precompute_image_embeddings(predictor, im, embedding_path)
222227

223228

@@ -384,7 +389,7 @@ def run_inference_with_prompts(
384389
gt = imageio.imread(gt_path).astype("uint32")
385390
gt = relabel_sequential(gt)[0]
386391

387-
embedding_path = os.path.join(embedding_dir, f"{image_name[:-4]}.zarr")
392+
embedding_path = os.path.join(embedding_dir, f"{os.path.splitext(image_name)[0]}.zarr")
388393
image_embeddings = util.precompute_image_embeddings(predictor, im, embedding_path)
389394
util.set_precomputed(predictor, image_embeddings)
390395

micro_sam/util.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -496,8 +496,10 @@ def precompute_image_embeddings(
496496
if "input_size" in f.attrs: # we have computed the embeddings already
497497
# key signature does not match or is not in the file
498498
if key not in f.attrs or f.attrs[key] != val:
499-
warnings.warn(f"Embeddings file is invalid due to unmatching {key}. \
500-
Please recompute embeddings in a new file.")
499+
warnings.warn(
500+
f"Embeddings file {save_path} is invalid due to unmatching {key}."
501+
"Please recompute embeddings in a new file."
502+
)
501503
if wrong_file_callback is not None:
502504
save_path = wrong_file_callback(save_path)
503505
f = zarr.open(save_path, "a")

0 commit comments

Comments
 (0)