|
| 1 | +import os |
| 2 | +from glob import glob |
| 3 | + |
| 4 | +import h5py |
| 5 | +import imageio.v3 as imageio |
| 6 | +import napari |
| 7 | +import numpy as np |
| 8 | + |
| 9 | +IHC_ROOT = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/croppings/IHC_crop" |
| 10 | +IHC_SEG = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/croppings/IHC_seg" |
| 11 | + |
| 12 | + |
| 13 | +def inspect_all_data(): |
| 14 | + |
| 15 | + images = sorted(glob(os.path.join(IHC_ROOT, "**/*.tif"), recursive=True)) |
| 16 | + segmentations = sorted(glob(os.path.join(IHC_SEG, "**/*.tif"), recursive=True)) |
| 17 | + |
| 18 | + skip_names = ["Calretinin"] |
| 19 | + |
| 20 | + for im_path, seg_path in zip(images, segmentations): |
| 21 | + print("Loading", im_path) |
| 22 | + root, fname = os.path.split(im_path) |
| 23 | + folder = os.path.basename(root) |
| 24 | + if folder in skip_names: |
| 25 | + continue |
| 26 | + |
| 27 | + try: |
| 28 | + im = imageio.imread(im_path) |
| 29 | + seg = imageio.imread(seg_path).astype("uint32") |
| 30 | + |
| 31 | + v = napari.Viewer() |
| 32 | + v.add_image(im) |
| 33 | + v.add_labels(seg) |
| 34 | + v.title = f"{folder}/{fname}" |
| 35 | + napari.run() |
| 36 | + except ValueError: |
| 37 | + continue |
| 38 | + |
| 39 | + |
| 40 | +def _require_prediction(im, image_path, with_mask): |
| 41 | + model_path = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/trained_models/IHC/v2_cochlea_distance_unet_IHC_supervised_2025-05-21" # noqa |
| 42 | + |
| 43 | + root, fname = os.path.split(image_path) |
| 44 | + folder = os.path.basename(root) |
| 45 | + |
| 46 | + cache_path = f"/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/croppings/predictions/{folder}" |
| 47 | + os.makedirs(cache_path, exist_ok=True) |
| 48 | + cache_path = os.path.join(cache_path, fname.replace(".tif", ".h5")) |
| 49 | + |
| 50 | + output_key = "pred_masked" if with_mask else "pred" |
| 51 | + |
| 52 | + if os.path.exists(cache_path): |
| 53 | + with h5py.File(cache_path, "r") as f: |
| 54 | + if output_key in f: |
| 55 | + pred = f[output_key][:] |
| 56 | + return pred |
| 57 | + |
| 58 | + from torch_em.util import load_model |
| 59 | + from torch_em.util.prediction import predict_with_halo |
| 60 | + from torch_em.transform.raw import standardize |
| 61 | + |
| 62 | + block_shape = (128, 128, 128) |
| 63 | + halo = (16, 32, 32) |
| 64 | + if with_mask: |
| 65 | + import nifty.tools as nt |
| 66 | + |
| 67 | + mask = np.zeros(im.shape, dtype=bool) |
| 68 | + blocking = nt.blocking([0, 0, 0], im.shape, block_shape) |
| 69 | + |
| 70 | + for block_id in range(blocking.numberOfBlocks): |
| 71 | + block = blocking.getBlock(block_id) |
| 72 | + bb = tuple(slice(beg, end) for beg, end in zip(block.begin, block.end)) |
| 73 | + data = im[bb] |
| 74 | + max_ = np.percentile(data, 95) |
| 75 | + if max_ > 200: |
| 76 | + mask[bb] = 1 |
| 77 | + else: |
| 78 | + mask = None |
| 79 | + |
| 80 | + im = standardize(im) |
| 81 | + |
| 82 | + model = load_model(model_path) |
| 83 | + |
| 84 | + pred = predict_with_halo( |
| 85 | + im, model, gpu_ids=[0], block_shape=block_shape, halo=halo, preprocess=None, mask=mask |
| 86 | + ) |
| 87 | + |
| 88 | + with h5py.File(cache_path, "a") as f: |
| 89 | + f.create_dataset(output_key, data=pred, compression="lzf") |
| 90 | + |
| 91 | + |
| 92 | +def check_block_artifacts(): |
| 93 | + image_path = os.path.join(IHC_ROOT, "Calretinin/M61L_CR_IHC_forannotations_C1.tif") |
| 94 | + im = imageio.imread(image_path) |
| 95 | + predictions = _require_prediction(im, image_path, with_mask=False) |
| 96 | + |
| 97 | + seg_path = os.path.join(IHC_SEG, "Calretinin/M61L_CR_IHC_forannotations_C1.tif") |
| 98 | + seg_old = imageio.imread(seg_path) |
| 99 | + |
| 100 | + v = napari.Viewer() |
| 101 | + v.add_image(im) |
| 102 | + v.add_image(predictions) |
| 103 | + v.add_labels(seg_old) |
| 104 | + napari.run() |
| 105 | + |
| 106 | + |
| 107 | +def _get_ihc_v_sgn_mask(seg, props, threshold, criterion="ratio"): |
| 108 | + sgn_ids = props.label[props[criterion] < threshold].values |
| 109 | + ihc_ids = props.label[props[criterion] >= threshold].values |
| 110 | + |
| 111 | + ihc_v_sgn = np.zeros_like(seg, dtype="uint32") |
| 112 | + ihc_v_sgn[np.isin(seg, ihc_ids)] = 1 |
| 113 | + ihc_v_sgn[np.isin(seg, sgn_ids)] = 2 |
| 114 | + |
| 115 | + return ihc_v_sgn |
| 116 | + |
| 117 | + |
| 118 | +# Too simple, need to learn this. |
| 119 | +def try_filtering(): |
| 120 | + import pandas as pd |
| 121 | + from skimage.measure import regionprops_table |
| 122 | + from magicgui import magic_factory |
| 123 | + |
| 124 | + seg_path = os.path.join(IHC_SEG, "Myo7a/3.1L_Myo7a_apex_HCAT_reslice_C2.tif") |
| 125 | + seg = imageio.imread(seg_path) |
| 126 | + |
| 127 | + props = regionprops_table( |
| 128 | + seg, properties=["label", "area", "axis_major_length", "axis_minor_length"] |
| 129 | + ) |
| 130 | + props = pd.DataFrame(props) |
| 131 | + props["ratio"] = props.axis_major_length / props.axis_minor_length |
| 132 | + |
| 133 | + ratio_threshold = 1.5 |
| 134 | + size_threshold = 5000 |
| 135 | + ihc_v_sgn = _get_ihc_v_sgn_mask(seg, props, ratio_threshold, criterion="ratio") |
| 136 | + |
| 137 | + @magic_factory( |
| 138 | + call_button="Update ratio threshold", |
| 139 | + threshold={"widget_type": "FloatSlider", "min": 1.0, "max": 5.0, "step": 0.1} |
| 140 | + ) |
| 141 | + def update_ratio_threshold(threshold: float = ratio_threshold): |
| 142 | + ihc_v_sgn = _get_ihc_v_sgn_mask(seg, props, threshold, criterion="ratio") |
| 143 | + v.layers["ihc_v_sgn"].data = ihc_v_sgn |
| 144 | + |
| 145 | + @magic_factory( |
| 146 | + call_button="Update size threshold", |
| 147 | + threshold={"widget_type": "FloatSlider", "min": 1000, "max": 20_000, "step": 100} |
| 148 | + ) |
| 149 | + def update_size_threshold(threshold: float = size_threshold): |
| 150 | + ihc_v_sgn = _get_ihc_v_sgn_mask(seg, props, threshold, criterion="area") |
| 151 | + v.layers["ihc_v_sgn"].data = ihc_v_sgn |
| 152 | + |
| 153 | + image_path = os.path.join(IHC_ROOT, "Myo7a/3.1L_Myo7a_apex_HCAT_reslice_C2.tif") |
| 154 | + im = imageio.imread(image_path) |
| 155 | + |
| 156 | + v = napari.Viewer() |
| 157 | + v.add_image(im) |
| 158 | + v.add_labels(seg) |
| 159 | + v.add_labels(ihc_v_sgn) |
| 160 | + |
| 161 | + ratio_widget = update_ratio_threshold() |
| 162 | + size_widget = update_size_threshold() |
| 163 | + v.window.add_dock_widget(ratio_widget, name="Ratio Threshold Slider") |
| 164 | + v.window.add_dock_widget(size_widget, name="Size Threshold Slider") |
| 165 | + |
| 166 | + napari.run() |
| 167 | + |
| 168 | + |
| 169 | +def run_object_classifier(): |
| 170 | + from flamingo_tools.classification import run_classification_gui |
| 171 | + |
| 172 | + image_path = os.path.join(IHC_ROOT, "Myo7a/3.1L_Myo7a_apex_HCAT_reslice_C2.tif") |
| 173 | + seg_path = os.path.join(IHC_SEG, "Myo7a/3.1L_Myo7a_apex_HCAT_reslice_C2.tif") |
| 174 | + |
| 175 | + run_classification_gui(image_path, seg_path, segmentation_name="IHCs") |
| 176 | + |
| 177 | + |
| 178 | +# From inspection: |
| 179 | +# - CR looks quite good, but also shows the blocking artifacts, and some merges: |
| 180 | +# Calretinin/M61L_CR_IHC_forannotations_C1.tif (blocking artifacts) |
| 181 | +# Calretinin/M63R_CR640_apexIHC_C2.tif (merges, but also weird looking stain) |
| 182 | +# Calretinin/M78L_CR488_apexIHC2_C6.tif (background structures are segmented) |
| 183 | +# Background is the case for some others too; it segments the hairs. |
| 184 | +# - Myo7a, looks good, but as we discussed the stain is not specific |
| 185 | +# Myo7a/3.1L_Myo7a_apex_HCAT_reslice_C2.tif (good candidate for filtering) |
| 186 | +# Myo7a/3.1L_Myo7a_mid_HCAT_reslice_C4.tif (good candidate for filtering) |
| 187 | +# - PV: Stain looks quite different, segmentations don't look so good. |
| 188 | +def main(): |
| 189 | + # inspect_all_data() |
| 190 | + # check_block_artifacts() |
| 191 | + # try_filtering() |
| 192 | + run_object_classifier() |
| 193 | + |
| 194 | + |
| 195 | +if __name__ == "__main__": |
| 196 | + main() |
0 commit comments