|
| 1 | +import argparse |
| 2 | + |
| 3 | +import imageio.v3 as imageio |
| 4 | +import napari |
| 5 | +import numpy as np |
| 6 | + |
| 7 | +import matplotlib.pyplot as plt |
| 8 | +import seaborn as sns |
| 9 | +from qtpy.QtWidgets import QWidget, QVBoxLayout, QPushButton, QLabel, QComboBox |
| 10 | +from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg |
| 11 | +from magicgui import magicgui |
| 12 | + |
| 13 | +from elf.parallel.distance_transform import distance_transform |
| 14 | +from elf.parallel.seeded_watershed import seeded_watershed |
| 15 | + |
| 16 | +from flamingo_tools.measurements import compute_object_measures_impl |
| 17 | + |
| 18 | + |
| 19 | +class HistogramWidget(QWidget): |
| 20 | + """Qt widget that draws/updates a histogram for one napari layer.""" |
| 21 | + def __init__(self, statistics, default_stat, bins: int = 32, parent=None): |
| 22 | + super().__init__(parent) |
| 23 | + self.bins = bins |
| 24 | + |
| 25 | + # --- layout ------------------------------------------------------ |
| 26 | + self.fig, self.ax = plt.subplots(figsize=(4, 3), tight_layout=True) |
| 27 | + self.canvas = FigureCanvasQTAgg(self.fig) |
| 28 | + |
| 29 | + # We exclude the label id and the volume / surface measurements. |
| 30 | + self.stat_names = statistics.columns[1:-2] |
| 31 | + self.param_choices = self.stat_names |
| 32 | + |
| 33 | + self.param_box = QComboBox() |
| 34 | + self.param_box.addItems(self.param_choices) |
| 35 | + self.param_box.setCurrentText(default_stat) |
| 36 | + |
| 37 | + self.refresh_btn = QPushButton("Refresh") |
| 38 | + self.refresh_btn.clicked.connect(self.update_hist) |
| 39 | + |
| 40 | + layout = QVBoxLayout() |
| 41 | + layout.addWidget(QLabel("Choose statistic:")) |
| 42 | + layout.addWidget(self.param_box) |
| 43 | + layout.addWidget(self.canvas) |
| 44 | + layout.addWidget(self.refresh_btn) |
| 45 | + self.setLayout(layout) |
| 46 | + |
| 47 | + self.statistics = statistics |
| 48 | + self.update_hist() # initial draw |
| 49 | + |
| 50 | + def update_hist(self): |
| 51 | + """Redraw the histogram.""" |
| 52 | + self.ax.clear() |
| 53 | + |
| 54 | + stat_name = self.param_box.currentText() |
| 55 | + |
| 56 | + data = self.statistics[stat_name] |
| 57 | + # Seaborn version (nicer aesthetics) |
| 58 | + sns.histplot(data, bins=self.bins, ax=self.ax, kde=False) |
| 59 | + |
| 60 | + self.ax.set_xlabel(f"{stat_name} GFP Intensity") |
| 61 | + self.ax.set_ylabel("Count") |
| 62 | + self.canvas.draw_idle() |
| 63 | + |
| 64 | + |
| 65 | +def _create_stat_widget(statistics, default_stat): |
| 66 | + widget = HistogramWidget(statistics, default_stat) |
| 67 | + return widget |
| 68 | + |
| 69 | + |
| 70 | +# Extend via watershed, this could work for a better alignment. |
| 71 | +def _extend_sgns_complex(gfp, sgns): |
| 72 | + # 1.) compute distance to the SGNs to create the background seed. |
| 73 | + print("Compute edt") |
| 74 | + |
| 75 | + # Could use parallel impl |
| 76 | + distance_threshol = 7 |
| 77 | + distances = distance_transform(sgns == 0) |
| 78 | + |
| 79 | + # Erode seeds? |
| 80 | + seeds = sgns.copy() |
| 81 | + bg_seed_id = int(seeds.max()) + 1 |
| 82 | + seeds[distances > distance_threshol] = bg_seed_id |
| 83 | + |
| 84 | + # Dilate to cover everything on the boundary? |
| 85 | + print("Run watershed") |
| 86 | + sgns_extended = seeded_watershed(gfp, markers=seeds) |
| 87 | + sgns_extended[sgns_extended == bg_seed_id] = 0 |
| 88 | + |
| 89 | + v = napari.Viewer() |
| 90 | + v.add_image(gfp) |
| 91 | + v.add_image(distances) |
| 92 | + v.add_labels(seeds) |
| 93 | + v.add_labels(sgns_extended) |
| 94 | + napari.run() |
| 95 | + |
| 96 | + |
| 97 | +# Just dilate by 3 pixels. |
| 98 | +def _extend_sgns_simple(gfp, sgns, dilation): |
| 99 | + block_shape = (128,) * 3 |
| 100 | + halo = (dilation + 2,) * 3 |
| 101 | + |
| 102 | + distances = distance_transform(sgns == 0, block_shape=block_shape, halo=halo, n_threads=8) |
| 103 | + mask = distances < dilation |
| 104 | + |
| 105 | + sgns_extended = np.zeros_like(sgns) |
| 106 | + sgns_extended = seeded_watershed( |
| 107 | + distances, sgns, sgns_extended, block_shape=block_shape, halo=halo, n_threads=8, mask=mask |
| 108 | + ) |
| 109 | + |
| 110 | + return sgns_extended |
| 111 | + |
| 112 | + |
| 113 | +def gfp_annotation(prefix, default_stat="mean"): |
| 114 | + gfp = imageio.imread(f"{prefix}_GFP_resized.tif") |
| 115 | + sgns = imageio.imread(f"{prefix}_SGN_resized_v2.tif") |
| 116 | + pv = imageio.imread(f"{prefix}_PV_resized.tif") |
| 117 | + |
| 118 | + # bb = np.s_[128:-128, 128:-128, 128:-128] |
| 119 | + # gfp, sgns, pv = gfp[bb], sgns[bb], pv[bb] |
| 120 | + # print(gfp.shape) |
| 121 | + |
| 122 | + # Extend the sgns so that they cover the SGN boundaries. |
| 123 | + # sgns_extended = _extend_sgns(gfp, sgns) |
| 124 | + # TODO we need to integrate this directly in the object measurement to efficiently do it at scale. |
| 125 | + sgns_extended = _extend_sgns_simple(gfp, sgns, dilation=4) |
| 126 | + |
| 127 | + # Compute the intensity statistics. |
| 128 | + statistics = compute_object_measures_impl(gfp, sgns_extended) |
| 129 | + |
| 130 | + # Open the napari viewer. |
| 131 | + v = napari.Viewer() |
| 132 | + |
| 133 | + # Add the base layers. |
| 134 | + v.add_image(gfp, name="GFP") |
| 135 | + v.add_image(pv, visible=False, name="PV") |
| 136 | + v.add_labels(sgns, visible=False, name="SGNs") |
| 137 | + v.add_labels(sgns_extended, name="SGNs-extended") |
| 138 | + |
| 139 | + # Add additional layers for intensity coloring and classification |
| 140 | + # data_numerical = np.zeros(gfp.shape, dtype="float32") |
| 141 | + data_labels = np.zeros(gfp.shape, dtype="uint8") |
| 142 | + |
| 143 | + # v.add_image(data_numerical, name="gfp-intensity") |
| 144 | + v.add_labels(data_labels, name="positive-negative") |
| 145 | + |
| 146 | + # Add widgets: |
| 147 | + |
| 148 | + # 1.) The widget for selcting the statistics to be used and displaying the histogram. |
| 149 | + stat_widget = _create_stat_widget(statistics, default_stat) |
| 150 | + |
| 151 | + # 2.) The widget for setting the threshold and updating the positive / negative classification based on it. |
| 152 | + stat_names = stat_widget.stat_names |
| 153 | + step = 1 |
| 154 | + all_values = statistics[stat_names].values |
| 155 | + min_val = all_values.min() |
| 156 | + max_val = all_values.max() |
| 157 | + |
| 158 | + @magicgui( |
| 159 | + threshold={ |
| 160 | + "widget_type": "FloatSlider", |
| 161 | + "label": "Threshold", |
| 162 | + "min": min_val, |
| 163 | + "max": max_val, |
| 164 | + "step": step, |
| 165 | + }, |
| 166 | + call_button="Apply", |
| 167 | + ) |
| 168 | + def threshold_widget(viewer: napari.Viewer, threshold: float = (max_val - min_val) / 2): |
| 169 | + label_ids = statistics.label_id.values |
| 170 | + stat_name = stat_widget.param_box.currentText() |
| 171 | + vals = statistics[stat_name].values |
| 172 | + pos_ids = label_ids[vals >= threshold] |
| 173 | + neg_ids = label_ids[vals <= threshold] |
| 174 | + data_labels = np.zeros(gfp.shape, dtype="uint8") |
| 175 | + data_labels[np.isin(sgns_extended, pos_ids)] = 2 |
| 176 | + data_labels[np.isin(sgns_extended, neg_ids)] = 1 |
| 177 | + viewer.layers["positive-negative"].data = data_labels |
| 178 | + |
| 179 | + threshold_widget.viewer.value = v |
| 180 | + |
| 181 | + # Bind the widgets. |
| 182 | + v.window.add_dock_widget(stat_widget, area="right") |
| 183 | + v.window.add_dock_widget(threshold_widget, area="right") |
| 184 | + stat_widget.setWindowTitle("GFP Histogram") |
| 185 | + |
| 186 | + napari.run() |
| 187 | + |
| 188 | + |
| 189 | +# Cochlea chanel registration quality: |
| 190 | +# - M_LR_000144_L: rough alignment is ok, but specific alignment is a bit poor. |
| 191 | +# - M_LR_000145_L: rough alignment is ok, detailed alignment also ok. |
| 192 | +# - M_LR_000151_R: rough alignment is ok, detailed alignment also ok. |
| 193 | +def main(): |
| 194 | + parser = argparse.ArgumentParser() |
| 195 | + parser.add_argument("prefix") |
| 196 | + args = parser.parse_args() |
| 197 | + |
| 198 | + gfp_annotation(args.prefix) |
| 199 | + |
| 200 | + |
| 201 | +if __name__ == "__main__": |
| 202 | + main() |
0 commit comments