|
| 1 | +import os |
| 2 | +from glob import glob |
| 3 | + |
| 4 | +import imageio.v3 as imageio |
| 5 | +import numpy as np |
| 6 | +from skimage.measure import label |
| 7 | +from skimage.segmentation import relabel_sequential |
| 8 | + |
| 9 | +from torch_em.util import load_model |
| 10 | +from torch_em.util.grid_search import DistanceBasedInstanceSegmentation, instance_segmentation_grid_search |
| 11 | + |
# Folder that is searched recursively for the training images and the label files.
ROOT = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/training_data/"  # noqa
# Location that is handed to `load_model` to restore the trained network.
MODEL_PATH = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/trained_models/"  # noqa

# Candidate values for each segmentation parameter. The dict is passed as
# `grid_search_values` to the grid search, and its keys are also used as the
# column names when the saved results are evaluated.
GRID_SEARCH_VALUES = dict(
    center_distance_threshold=[0.3, 0.4, 0.5, 0.6, 0.7],
    boundary_distance_threshold=[0.3, 0.4, 0.5, 0.6, 0.7],
    distance_smoothing=[0.0, 0.6, 1.0, 1.6],
    min_size=[0, 100, 250],
)
| 21 | + |
| 22 | + |
def preprocess_gt(min_size=750):
    """Derive the ground-truth label files from the corrected annotations.

    Every file matching `*_corrected.tif` below ROOT (searched recursively) is
    processed as follows: the data is labeled via `skimage.measure.label`,
    all labels with fewer than `min_size` elements are set to 0, the remaining
    ids are relabeled consecutively, and the result is written next to the
    input with the suffix `_corrected.tif` replaced by `_gt.tif`.

    Args:
        min_size: Labels with fewer elements than this value are removed.
    """
    in_suffix, out_suffix = "_corrected.tif", "_gt.tif"

    label_paths = sorted(glob(os.path.join(ROOT, "**/*" + in_suffix), recursive=True))
    for label_path in label_paths:
        seg = label(imageio.imread(label_path))

        # Remove all objects that are smaller than the size threshold.
        ids, sizes = np.unique(seg, return_counts=True)
        filter_ids = ids[sizes < min_size]
        seg[np.isin(seg, filter_ids)] = 0
        seg, _, _ = relabel_sequential(seg)

        # Only exchange the trailing suffix: a plain `str.replace` on the full
        # path would also rewrite "_corrected" inside directory names or
        # earlier in the file name and thereby change the output location.
        out_path = label_path[:-len(in_suffix)] + out_suffix
        imageio.imwrite(out_path, seg)
| 37 | + |
| 38 | + |
def run_grid_search():
    """Run the segmentation grid search on all image / ground-truth pairs below ROOT.

    Images are the files matching `*lut3*.tif` and labels the files matching
    `*_gt.tif`. Both lists are sorted and only their lengths are compared, so
    the pairing relies on the two sorted lists lining up. The grid search is
    given the folder "ihc-v4-gs" as result directory, and the best parameters
    and the best score it returns are printed.
    """
    def _collect(pattern):
        # Recursive, sorted lookup of all files below ROOT that match the pattern.
        return sorted(glob(os.path.join(ROOT, pattern), recursive=True))

    image_paths = _collect("**/*lut3*.tif")
    label_paths = _collect("**/*_gt.tif")
    n_images, n_labels = len(image_paths), len(label_paths)
    assert n_images == n_labels, f"{n_images}, {n_labels}"

    # Segmentation wrapper around the trained model, with fixed block shape and halo.
    segmenter = DistanceBasedInstanceSegmentation(
        load_model(MODEL_PATH), block_shape=(96, 256, 256), halo=(8, 64, 64)
    )

    best_kwargs, best_score = instance_segmentation_grid_search(
        segmenter, image_paths, label_paths, "ihc-v4-gs", grid_search_values=GRID_SEARCH_VALUES
    )
    print("Grid-search result:", best_kwargs, best_score, sep="\n")
| 56 | + |
| 57 | + |
def evaluate_grid_search(result_dir="ihc-v4-gs", criterion="SA50"):
    """Summarize and plot the grid search results stored as csv files.

    All csv files in `result_dir` are concatenated and the `criterion` column
    is averaged per combination of the parameters in GRID_SEARCH_VALUES. The
    best combination and its score are printed, and heatmaps of the mean
    criterion are shown for three pairs of parameters.

    Args:
        result_dir: Folder that contains the csv files with the grid search results.
        criterion: Name of the csv column that holds the score to maximize.

    Raises:
        ValueError: If `result_dir` does not contain any csv file.
    """
    import matplotlib.pyplot as plt
    import pandas as pd
    import seaborn as sns

    gs_files = sorted(glob(os.path.join(result_dir, "*.csv")))
    if not gs_files:
        raise ValueError(f"No grid search results (*.csv) found in '{result_dir}'.")
    gs_result = pd.concat([pd.read_csv(gs_file) for gs_file in gs_files], ignore_index=True)

    grid_search_parameters = list(GRID_SEARCH_VALUES.keys())

    # Keep only the parameter columns and the criterion, then average the
    # criterion over all rows that share the same parameter combination.
    gs_result = gs_result[grid_search_parameters + [criterion]]
    gs_result = gs_result.groupby(grid_search_parameters).mean().reset_index()

    # Find the best score and the parameters that achieved it.
    # `idxmax` returns an index label, hence the row is looked up with `loc`.
    best_score, best_idx = gs_result[criterion].max(), gs_result[criterion].idxmax()
    best_params = gs_result.loc[best_idx]
    best_kwargs = {name: best_params[name] for name in grid_search_parameters}

    print("Best parameters:")
    print(best_kwargs)
    print("With score:", best_score)

    parameter_pairs = [
        ("center_distance_threshold", "boundary_distance_threshold"),
        ("distance_smoothing", "boundary_distance_threshold"),
        ("distance_smoothing", "center_distance_threshold"),
    ]
    fig, axes = plt.subplots(len(parameter_pairs))
    for ax, (idx, col) in zip(axes, parameter_pairs):
        # Average the criterion over all parameters that are not part of this pair.
        res = gs_result.groupby([idx, col])[criterion].mean().reset_index()
        res = res.pivot(index=idx, columns=col, values=criterion)
        sns.heatmap(res, cmap="viridis", annot=True, fmt=".2g", cbar_kws={"label": criterion}, ax=ax)
        # The pivot index is laid out along the rows (y-axis), its columns along the x-axis.
        ax.set_ylabel(idx)
        ax.set_xlabel(col)
    # Keep the labels of the stacked heatmaps from overlapping.
    fig.tight_layout()
    plt.show()
| 99 | + |
| 100 | + |
def main():
    """Entry point of the script; currently only evaluates existing grid search results."""
    # The two earlier stages are disabled here. Uncomment them to recreate the
    # ground-truth files and to re-run the grid search before the evaluation.
    # preprocess_gt()
    # run_grid_search()
    evaluate_grid_search()


if __name__ == "__main__":
    main()
0 commit comments