1212ROOT = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/training_data/IHC/2025-07-for-grid-search"
1313MODEL_PATH = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/trained_models/IHC/v4_cochlea_distance_unet_IHC_supervised_2025-07-14" # noqa
1414
15+ GRID_SEARCH_VALUES = {
16+ "center_distance_threshold" : [0.3 , 0.4 , 0.5 , 0.6 , 0.7 ],
17+ "boundary_distance_threshold" : [0.3 , 0.4 , 0.5 , 0.6 , 0.7 ],
18+ "distance_smoothing" : [0.0 , 0.6 , 1.0 , 1.6 ],
19+ "min_size" : [0 ],
20+ }
21+
1522
1623def preprocess_gt ():
1724 label_paths = sorted (glob (os .path .join (ROOT , "**/*_corrected.tif" ), recursive = True ))
@@ -40,14 +47,8 @@ def run_grid_search():
4047 model = load_model (MODEL_PATH )
4148 segmenter = DistanceBasedInstanceSegmentation (model , block_shape = block_shape , halo = halo )
4249
43- grid_search_values = {
44- "center_distance_threshold" : [0.3 , 0.4 , 0.5 , 0.6 , 0.7 ],
45- "boundary_distance_threshold" : [0.3 , 0.4 , 0.5 , 0.6 , 0.7 ],
46- "distance_smoothing" : [0.0 , 0.6 , 1.0 , 1.6 ],
47- "min_size" : [0 ],
48- }
4950 best_kwargs , best_score = instance_segmentation_grid_search (
50- segmenter , image_paths , label_paths , result_dir , grid_search_values = grid_search_values
51+ segmenter , image_paths , label_paths , result_dir , grid_search_values = GRID_SEARCH_VALUES
5152 )
5253 print ("Grid-search result:" )
5354 print (best_kwargs )
@@ -56,12 +57,50 @@ def run_grid_search():
5657
5758# TODO plot the grid search results
5859def evaluate_grid_search ():
59- pass
60+ import matplotlib .pyplot as plt
61+ import pandas as pd
62+ import seaborn as sns
63+
64+ result_dir = "ihc-v4-gs"
65+ criterion = "SA50"
66+
67+ gs_files = glob (os .path .join (result_dir , "*.csv" ))
68+ gs_result = pd .concat ([pd .read_csv (gs_file ) for gs_file in gs_files ])
69+
70+ grid_search_parameters = list (GRID_SEARCH_VALUES .keys ())
71+
72+ # Retrieve only the relevant columns and group by the gridsearch columns
73+ # and compute the mean value
74+ gs_result = gs_result [grid_search_parameters + [criterion ]].reset_index ()
75+ gs_result = gs_result .groupby (grid_search_parameters ).mean ().reset_index ()
76+
77+ # Find the best score and best result
78+ best_score , best_idx = gs_result [criterion ].max (), gs_result [criterion ].idxmax ()
79+ best_params = gs_result .iloc [best_idx ]
80+ best_kwargs = {k : v for k , v in zip (grid_search_parameters , best_params )}
81+
82+ print ("Best parameters:" )
83+ print (best_kwargs )
84+ print ("With score:" , best_score )
85+
86+ fig , axes = plt .subplots (3 )
87+ for i , (idx , col ) in enumerate ([
88+ ("center_distance_threshold" , "boundary_distance_threshold" ),
89+ ("distance_smoothing" , "boundary_distance_threshold" ),
90+ ("distance_smoothing" , "center_distance_threshold" ),
91+ ]):
92+ res = gs_result .groupby ([idx , col ]).mean ().reset_index ()
93+ res = res [[idx , col , criterion ]]
94+ res = res .pivot (index = idx , columns = col , values = criterion )
95+ sns .heatmap (res , cmap = "viridis" , annot = True , fmt = ".2g" , cbar_kws = {"label" : criterion }, ax = axes [i ])
96+ axes [i ].set_xlabel (idx )
97+ axes [i ].set_xlabel (col )
98+ plt .show ()
6099
61100
62101def main ():
63102 # preprocess_gt()
64- run_grid_search ()
103+ # run_grid_search()
65104 evaluate_grid_search ()
66105
67106
0 commit comments