Skip to content

Commit 4c1c8a4

Browse files
Update grid-search plot
1 parent 868bbe4 commit 4c1c8a4

File tree

1 file changed

+48
-9
lines changed

1 file changed

+48
-9
lines changed

scripts/validation/IHCs/grid_search.py

Lines changed: 48 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,13 @@
1212
ROOT = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/training_data/IHC/2025-07-for-grid-search"
1313
MODEL_PATH = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/trained_models/IHC/v4_cochlea_distance_unet_IHC_supervised_2025-07-14" # noqa
1414

# Hyper-parameter grid for the distance-based instance segmentation.
# Shared by run_grid_search (which sweeps these combinations) and by
# evaluate_grid_search (which uses the keys to aggregate the result CSVs).
GRID_SEARCH_VALUES = {
    "center_distance_threshold": [0.3, 0.4, 0.5, 0.6, 0.7],
    "boundary_distance_threshold": [0.3, 0.4, 0.5, 0.6, 0.7],
    "distance_smoothing": [0.0, 0.6, 1.0, 1.6],
    "min_size": [0],
}
21+
1522

1623
def preprocess_gt():
1724
label_paths = sorted(glob(os.path.join(ROOT, "**/*_corrected.tif"), recursive=True))
@@ -40,14 +47,8 @@ def run_grid_search():
4047
model = load_model(MODEL_PATH)
4148
segmenter = DistanceBasedInstanceSegmentation(model, block_shape=block_shape, halo=halo)
4249

43-
grid_search_values = {
44-
"center_distance_threshold": [0.3, 0.4, 0.5, 0.6, 0.7],
45-
"boundary_distance_threshold": [0.3, 0.4, 0.5, 0.6, 0.7],
46-
"distance_smoothing": [0.0, 0.6, 1.0, 1.6],
47-
"min_size": [0],
48-
}
4950
best_kwargs, best_score = instance_segmentation_grid_search(
50-
segmenter, image_paths, label_paths, result_dir, grid_search_values=grid_search_values
51+
segmenter, image_paths, label_paths, result_dir, grid_search_values=GRID_SEARCH_VALUES
5152
)
5253
print("Grid-search result:")
5354
print(best_kwargs)
@@ -56,12 +57,50 @@ def run_grid_search():
5657

5758
# TODO plot the grid search results
5859
def evaluate_grid_search():
    """Aggregate the per-image grid-search CSVs and visualize the results.

    Reads all ``*.csv`` files from the grid-search result directory, averages
    the evaluation criterion over images for each parameter combination,
    prints the best parameter set and its score, and shows pairwise heatmaps
    of the criterion over the grid-search parameters.
    """
    import matplotlib.pyplot as plt
    import pandas as pd
    import seaborn as sns

    result_dir = "ihc-v4-gs"
    criterion = "SA50"

    # One CSV per evaluated image; concatenate them into a single table.
    gs_files = glob(os.path.join(result_dir, "*.csv"))
    gs_result = pd.concat([pd.read_csv(gs_file) for gs_file in gs_files])

    grid_search_parameters = list(GRID_SEARCH_VALUES.keys())

    # Retrieve only the relevant columns and group by the gridsearch columns
    # and compute the mean value
    gs_result = gs_result[grid_search_parameters + [criterion]].reset_index()
    gs_result = gs_result.groupby(grid_search_parameters).mean().reset_index()

    # Find the best score and best result
    best_score, best_idx = gs_result[criterion].max(), gs_result[criterion].idxmax()
    best_params = gs_result.iloc[best_idx]
    # zip stops at the shorter sequence, so only the parameter columns are kept.
    best_kwargs = {k: v for k, v in zip(grid_search_parameters, best_params)}

    print("Best parameters:")
    print(best_kwargs)
    print("With score:", best_score)

    fig, axes = plt.subplots(3)
    for i, (idx, col) in enumerate([
        ("center_distance_threshold", "boundary_distance_threshold"),
        ("distance_smoothing", "boundary_distance_threshold"),
        ("distance_smoothing", "center_distance_threshold"),
    ]):
        # Average over the remaining parameters, then pivot so that `idx`
        # forms the heatmap rows (y-axis) and `col` the columns (x-axis).
        res = gs_result.groupby([idx, col]).mean().reset_index()
        res = res[[idx, col, criterion]]
        res = res.pivot(index=idx, columns=col, values=criterion)
        sns.heatmap(res, cmap="viridis", annot=True, fmt=".2g", cbar_kws={"label": criterion}, ax=axes[i])
        # BUGFIX: the first of the two calls was `set_xlabel(idx)`, which the
        # second `set_xlabel(col)` immediately overwrote; the pivot index is
        # plotted on the y-axis, so it must be labeled via set_ylabel.
        axes[i].set_ylabel(idx)
        axes[i].set_xlabel(col)
    plt.show()
6099

61100

62101
def main():
    """Entry point: run the currently active stage of the grid-search pipeline.

    The earlier stages are commented out once they have been executed;
    re-enable them to rerun preprocessing or the grid search itself.
    """
    # preprocess_gt()
    # run_grid_search()
    evaluate_grid_search()
66105

67106

0 commit comments

Comments
 (0)