Skip to content

Commit 4c48e68

Browse files
Merge pull request #149 from computational-cell-analytics/generalist-experiments
Support RGB data in micro_sam.evaluation.model_comparison
2 parents b0881d0 + ca834e1 commit 4c48e68

File tree

2 files changed

+41
-14
lines changed

2 files changed

+41
-14
lines changed

finetuning/generalists/generate_model_comparison.py

Lines changed: 32 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import micro_sam.evaluation.model_comparison as comparison
55
import torch_em
66

7-
from util import get_data_paths, EM_DATASETS
7+
from util import get_data_paths, EM_DATASETS, LM_DATASETS
88

99
OUTPUT_ROOT = "/scratch-grete/projects/nim00007/sam/experiments/model_comparison"
1010

@@ -15,44 +15,67 @@ def _get_patch_shape(path):
1515
return patch_shape
1616

1717

18+
def raw_trafo(raw):
19+
raw = raw.transpose((2, 0, 1))
20+
print(raw.shape)
21+
return raw
22+
23+
1824
def get_loader(dataset):
1925
image_paths, gt_paths = get_data_paths(dataset, split="test")
2026
image_paths, gt_paths = image_paths[:100], gt_paths[:100]
2127

28+
with_channels = dataset in ("hpa", "lizard")
29+
2230
label_transform = torch_em.transform.label.connected_components
2331
loader = torch_em.default_segmentation_loader(
2432
image_paths, None, gt_paths, None,
2533
batch_size=1, patch_shape=_get_patch_shape(image_paths[0]),
2634
shuffle=True, n_samples=25, label_transform=label_transform,
35+
with_channels=with_channels, is_seg_dataset=not with_channels
2736
)
2837
return loader
2938

3039

3140
def generate_comparison_for_dataset(dataset, model1, model2):
3241
output_folder = os.path.join(OUTPUT_ROOT, dataset)
3342
if os.path.exists(output_folder):
34-
return
43+
return output_folder
3544
print("Generate model comparison data for", dataset)
3645
loader = get_loader(dataset)
3746
comparison.generate_data_for_model_comparison(loader, output_folder, model1, model2, n_samples=25)
47+
return output_folder
3848

3949

40-
# TODO
41-
def create_comparison_images():
42-
pass
50+
def create_comparison_images(output_folder, dataset):
51+
plot_folder = os.path.join(OUTPUT_ROOT, "images", dataset)
52+
if os.path.exists(plot_folder):
53+
return
54+
comparison.model_comparison(
55+
output_folder, n_images_per_sample=25, min_size=100,
56+
plot_folder=plot_folder, outline_dilation=1
57+
)
4358

4459

4560
def generate_comparison_em():
4661
model1 = "vit_h"
4762
model2 = "vit_h_em"
4863
for dataset in EM_DATASETS:
49-
generate_comparison_for_dataset(dataset, model1, model2)
50-
create_comparison_images()
64+
folder = generate_comparison_for_dataset(dataset, model1, model2)
65+
create_comparison_images(folder, dataset)
66+
67+
68+
def generate_comparison_lm():
69+
model1 = "vit_h"
70+
model2 = "vit_h_lm"
71+
for dataset in LM_DATASETS:
72+
folder = generate_comparison_for_dataset(dataset, model1, model2)
73+
create_comparison_images(folder, dataset)
5174

5275

5376
def main():
54-
# generate_comparison_lm()
55-
generate_comparison_em()
77+
generate_comparison_lm()
78+
# generate_comparison_em()
5679

5780

5881
if __name__ == "__main__":

micro_sam/evaluation/model_comparison.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,9 @@ def _predict_models_with_loader(loader, n_samples, prompt_generator, predictor1,
3535
out_path = os.path.join(output_folder, f"sample_{i}.h5")
3636

3737
im = x.numpy().squeeze()
38+
if im.ndim == 3 and im.shape[0] == 3:
39+
im = im.transpose((1, 2, 0))
40+
3841
gt = y.numpy().squeeze().astype("uint32")
3942
gt = relabel_sequential(gt)[0]
4043

@@ -152,12 +155,13 @@ def _evaluate_samples(f, prefix, min_size):
152155

153156

154157
def _overlay_mask(image, mask):
155-
# TODO add support for RGB inptus
156-
assert image.ndim == 2
158+
assert image.ndim in (2, 3)
157159
# overlay the mask
158-
overlay = np.stack(
159-
[image, image, image]
160-
).transpose((1, 2, 0))
160+
if image.ndim == 2:
161+
overlay = np.stack([image, image, image]).transpose((1, 2, 0))
162+
else:
163+
overlay = image
164+
assert overlay.shape[-1] == 3
161165
mask_overlay = np.zeros_like(overlay)
162166
mask_overlay[mask == 1] = [255, 0, 0]
163167
alpha = 0.6

0 commit comments

Comments
 (0)