44import micro_sam .evaluation .model_comparison as comparison
55import torch_em
66
7- from util import get_data_paths , EM_DATASETS
7+ from util import get_data_paths , EM_DATASETS , LM_DATASETS
88
99OUTPUT_ROOT = "/scratch-grete/projects/nim00007/sam/experiments/model_comparison"
1010
@@ -15,44 +15,67 @@ def _get_patch_shape(path):
1515 return patch_shape
1616
1717
18+ def raw_trafo (raw ):
19+ raw = raw .transpose ((2 , 0 , 1 ))
20+ print (raw .shape )
21+ return raw
22+
23+
1824def get_loader (dataset ):
1925 image_paths , gt_paths = get_data_paths (dataset , split = "test" )
2026 image_paths , gt_paths = image_paths [:100 ], gt_paths [:100 ]
2127
28+ with_channels = dataset in ("hpa" , "lizard" )
29+
2230 label_transform = torch_em .transform .label .connected_components
2331 loader = torch_em .default_segmentation_loader (
2432 image_paths , None , gt_paths , None ,
2533 batch_size = 1 , patch_shape = _get_patch_shape (image_paths [0 ]),
2634 shuffle = True , n_samples = 25 , label_transform = label_transform ,
35+ with_channels = with_channels , is_seg_dataset = not with_channels
2736 )
2837 return loader
2938
3039
3140def generate_comparison_for_dataset (dataset , model1 , model2 ):
3241 output_folder = os .path .join (OUTPUT_ROOT , dataset )
3342 if os .path .exists (output_folder ):
34- return
43+ return output_folder
3544 print ("Generate model comparison data for" , dataset )
3645 loader = get_loader (dataset )
3746 comparison .generate_data_for_model_comparison (loader , output_folder , model1 , model2 , n_samples = 25 )
47+ return output_folder
3848
3949
40- # TODO
41- def create_comparison_images ():
42- pass
50+ def create_comparison_images (output_folder , dataset ):
51+ plot_folder = os .path .join (OUTPUT_ROOT , "images" , dataset )
52+ if os .path .exists (plot_folder ):
53+ return
54+ comparison .model_comparison (
55+ output_folder , n_images_per_sample = 25 , min_size = 100 ,
56+ plot_folder = plot_folder , outline_dilation = 1
57+ )
4358
4459
4560def generate_comparison_em ():
4661 model1 = "vit_h"
4762 model2 = "vit_h_em"
4863 for dataset in EM_DATASETS :
49- generate_comparison_for_dataset (dataset , model1 , model2 )
50- create_comparison_images ()
64+ folder = generate_comparison_for_dataset (dataset , model1 , model2 )
65+ create_comparison_images (folder , dataset )
66+
67+
68+ def generate_comparison_lm ():
69+ model1 = "vit_h"
70+ model2 = "vit_h_lm"
71+ for dataset in LM_DATASETS :
72+ folder = generate_comparison_for_dataset (dataset , model1 , model2 )
73+ create_comparison_images (folder , dataset )
5174
5275
5376def main ():
54- # generate_comparison_lm()
55- generate_comparison_em ()
77+ generate_comparison_lm ()
78+ # generate_comparison_em()
5679
5780
5881if __name__ == "__main__" :
0 commit comments