Skip to content

Commit 20ac761

Browse files
Implement RF based classificaiton WIP
1 parent 534ab38 commit 20ac761

File tree

5 files changed

+338
-0
lines changed

5 files changed

+338
-0
lines changed
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from .classification_gui import run_classification_gui
Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
import os
2+
from multiprocessing import cpu_count
3+
from pathlib import Path
4+
from typing import Optional
5+
6+
import h5py
7+
import imageio.v3 as imageio
8+
import napari
9+
import numpy as np
10+
import pandas as pd
11+
12+
from joblib import dump
13+
from magicgui import magic_factory
14+
from skimage.measure import regionprops_table
15+
16+
import micro_sam.sam_annotator.object_classifier as classifier_util
17+
from micro_sam.object_classification import project_prediction_to_segmentation
18+
from micro_sam.sam_annotator._widgets import _generate_message
19+
20+
IMAGE_LAYER_NAME = None
21+
SEGMENTATION_LAYER_NAME = None
22+
FEATURES = None
23+
SEG_IDS = None
24+
CLASSIFIER = None
25+
LABELS = None
26+
27+
28+
# TODO refactor
29+
def _compute_features(segmentation, image):
30+
features = pd.DataFrame(regionprops_table(
31+
segmentation, image, properties=[
32+
"label", "area", "axis_major_length", "axis_minor_length",
33+
"equivalent_diameter_area", "euler_number", "extent",
34+
"feret_diameter_max", "inertia_tensor_eigvals",
35+
"intensity_max", "intensity_mean", "intensity_min",
36+
"intensity_std", "moments_central",
37+
"moments_weighted", "solidity",
38+
]
39+
))
40+
seg_ids = features.label.values.astype(int)
41+
features = features.drop(columns="label").values
42+
return features, seg_ids
43+
44+
45+
@magic_factory(call_button="Train and predict")
46+
def _train_and_predict_rf_widget(viewer: "napari.viewer.Viewer") -> None:
47+
global FEATURES, SEG_IDS, CLASSIFIER, LABELS
48+
49+
annotations = viewer.layers["annotations"].data
50+
segmentation = viewer.layers[SEGMENTATION_LAYER_NAME].data
51+
labels = classifier_util._accumulate_labels(segmentation, annotations)
52+
LABELS = labels
53+
54+
if FEATURES is None:
55+
print("Computing features ...")
56+
image = viewer.layers[IMAGE_LAYER_NAME].data
57+
FEATURES, SEG_IDS = _compute_features(segmentation, image)
58+
59+
print("Training random forest ...")
60+
rf = classifier_util._train_rf(FEATURES, labels, n_estimators=200, max_depth=10, n_jobs=cpu_count())
61+
CLASSIFIER = rf
62+
63+
# Run and set the prediction.
64+
print("Run prediction ...")
65+
pred = rf.predict(FEATURES)
66+
prediction_data = project_prediction_to_segmentation(segmentation, pred, SEG_IDS)
67+
viewer.layers["prediction"].data = prediction_data
68+
69+
70+
@magic_factory(call_button="Export Classifier")
71+
def _create_export_rf_widget(export_path: Optional[Path] = None) -> None:
72+
rf = CLASSIFIER
73+
if rf is None:
74+
return _generate_message("error", "You have not run training yet.")
75+
if export_path is None or export_path == "":
76+
return _generate_message("error", "You have to provide an export path.")
77+
# Do we add an extension? .joblib?
78+
dump(rf, export_path)
79+
80+
81+
@magic_factory(call_button="Export Features")
82+
def _create_export_feature_widget(export_path: Optional[Path] = None) -> None:
83+
84+
if FEATURES is None or LABELS is None:
85+
return _generate_message("error", "You have not run training yet.")
86+
if export_path is None or export_path == "":
87+
return _generate_message("error", "You have to provide an export path.")
88+
89+
valid = LABELS != 0
90+
features, labels = FEATURES[valid], LABELS[valid]
91+
92+
export_path = Path(export_path).with_suffix(".h5")
93+
with h5py.File(export_path, "a") as f:
94+
g = f.create_group(IMAGE_LAYER_NAME)
95+
g.create_dataset("features", data=features, compression="lzf")
96+
g.create_dataset("labels", data=labels, compression="lzf")
97+
98+
99+
def run_classification_gui(image_path, segmentation_path, image_name=None, segmentation_name=None):
100+
global IMAGE_LAYER_NAME, SEGMENTATION_LAYER_NAME
101+
102+
image = imageio.imread(image_path)
103+
segmentation = imageio.imread(segmentation_path)
104+
105+
image_name = os.path.basename(image_path) if image_name is None else image_name
106+
segmentation_name = os.path.basename(segmentation_path) if segmentation_name is None else segmentation_name
107+
108+
IMAGE_LAYER_NAME = image_name
109+
SEGMENTATION_LAYER_NAME = segmentation_name
110+
111+
viewer = napari.Viewer()
112+
viewer.add_image(image, name=image_name)
113+
viewer.add_labels(segmentation, name=segmentation_name)
114+
115+
shape = image.shape
116+
viewer.add_labels(name="prediction", data=np.zeros(shape, dtype="uint8"))
117+
viewer.add_labels(name="annotations", data=np.zeros(shape, dtype="uint8"))
118+
119+
# Add the gui elements.
120+
train_widget = _train_and_predict_rf_widget()
121+
rf_export_widget = _create_export_rf_widget()
122+
feature_export_widget = _create_export_feature_widget()
123+
124+
viewer.window.add_dock_widget(train_widget)
125+
viewer.window.add_dock_widget(feature_export_widget)
126+
viewer.window.add_dock_widget(rf_export_widget)
127+
128+
napari.run()
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
2+
3+
# TODO train a classifier on all features and labels stored in h5
4+
def train_classifier(feature_paths):
5+
pass
6+
7+
8+
# TODO run prediction on a full cochlea
9+
def predict_classifier():
10+
pass

flamingo_tools/measurements.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,9 @@ def _measure_volume_and_surface(mask, resolution):
2626
return volume, surface
2727

2828

29+
# TODO extend this to also support regionprops featurs,
30+
# maybe spherical harmonics, line profiles, and nucleus (in SGNs, based on thresholding).
31+
# For this, refactor the feature function.
2932
def compute_object_measures_impl(
3033
image: np.typing.ArrayLike,
3134
segmentation: np.typing.ArrayLike,

scripts/check_ihc_seg.py

Lines changed: 196 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,196 @@
1+
import os
2+
from glob import glob
3+
4+
import h5py
5+
import imageio.v3 as imageio
6+
import napari
7+
import numpy as np
8+
9+
IHC_ROOT = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/croppings/IHC_crop"
10+
IHC_SEG = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/croppings/IHC_seg"
11+
12+
13+
def inspect_all_data():
14+
15+
images = sorted(glob(os.path.join(IHC_ROOT, "**/*.tif"), recursive=True))
16+
segmentations = sorted(glob(os.path.join(IHC_SEG, "**/*.tif"), recursive=True))
17+
18+
skip_names = ["Calretinin"]
19+
20+
for im_path, seg_path in zip(images, segmentations):
21+
print("Loading", im_path)
22+
root, fname = os.path.split(im_path)
23+
folder = os.path.basename(root)
24+
if folder in skip_names:
25+
continue
26+
27+
try:
28+
im = imageio.imread(im_path)
29+
seg = imageio.imread(seg_path).astype("uint32")
30+
31+
v = napari.Viewer()
32+
v.add_image(im)
33+
v.add_labels(seg)
34+
v.title = f"{folder}/{fname}"
35+
napari.run()
36+
except ValueError:
37+
continue
38+
39+
40+
def _require_prediction(im, image_path, with_mask):
41+
model_path = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/trained_models/IHC/v2_cochlea_distance_unet_IHC_supervised_2025-05-21" # noqa
42+
43+
root, fname = os.path.split(image_path)
44+
folder = os.path.basename(root)
45+
46+
cache_path = f"/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/croppings/predictions/{folder}"
47+
os.makedirs(cache_path, exist_ok=True)
48+
cache_path = os.path.join(cache_path, fname.replace(".tif", ".h5"))
49+
50+
output_key = "pred_masked" if with_mask else "pred"
51+
52+
if os.path.exists(cache_path):
53+
with h5py.File(cache_path, "r") as f:
54+
if output_key in f:
55+
pred = f[output_key][:]
56+
return pred
57+
58+
from torch_em.util import load_model
59+
from torch_em.util.prediction import predict_with_halo
60+
from torch_em.transform.raw import standardize
61+
62+
block_shape = (128, 128, 128)
63+
halo = (16, 32, 32)
64+
if with_mask:
65+
import nifty.tools as nt
66+
67+
mask = np.zeros(im.shape, dtype=bool)
68+
blocking = nt.blocking([0, 0, 0], im.shape, block_shape)
69+
70+
for block_id in range(blocking.numberOfBlocks):
71+
block = blocking.getBlock(block_id)
72+
bb = tuple(slice(beg, end) for beg, end in zip(block.begin, block.end))
73+
data = im[bb]
74+
max_ = np.percentile(data, 95)
75+
if max_ > 200:
76+
mask[bb] = 1
77+
else:
78+
mask = None
79+
80+
im = standardize(im)
81+
82+
model = load_model(model_path)
83+
84+
pred = predict_with_halo(
85+
im, model, gpu_ids=[0], block_shape=block_shape, halo=halo, preprocess=None, mask=mask
86+
)
87+
88+
with h5py.File(cache_path, "a") as f:
89+
f.create_dataset(output_key, data=pred, compression="lzf")
90+
91+
92+
def check_block_artifacts():
93+
image_path = os.path.join(IHC_ROOT, "Calretinin/M61L_CR_IHC_forannotations_C1.tif")
94+
im = imageio.imread(image_path)
95+
predictions = _require_prediction(im, image_path, with_mask=False)
96+
97+
seg_path = os.path.join(IHC_SEG, "Calretinin/M61L_CR_IHC_forannotations_C1.tif")
98+
seg_old = imageio.imread(seg_path)
99+
100+
v = napari.Viewer()
101+
v.add_image(im)
102+
v.add_image(predictions)
103+
v.add_labels(seg_old)
104+
napari.run()
105+
106+
107+
def _get_ihc_v_sgn_mask(seg, props, threshold, criterion="ratio"):
108+
sgn_ids = props.label[props[criterion] < threshold].values
109+
ihc_ids = props.label[props[criterion] >= threshold].values
110+
111+
ihc_v_sgn = np.zeros_like(seg, dtype="uint32")
112+
ihc_v_sgn[np.isin(seg, ihc_ids)] = 1
113+
ihc_v_sgn[np.isin(seg, sgn_ids)] = 2
114+
115+
return ihc_v_sgn
116+
117+
118+
# Too simple, need to learn this.
119+
def try_filtering():
120+
import pandas as pd
121+
from skimage.measure import regionprops_table
122+
from magicgui import magic_factory
123+
124+
seg_path = os.path.join(IHC_SEG, "Myo7a/3.1L_Myo7a_apex_HCAT_reslice_C2.tif")
125+
seg = imageio.imread(seg_path)
126+
127+
props = regionprops_table(
128+
seg, properties=["label", "area", "axis_major_length", "axis_minor_length"]
129+
)
130+
props = pd.DataFrame(props)
131+
props["ratio"] = props.axis_major_length / props.axis_minor_length
132+
133+
ratio_threshold = 1.5
134+
size_threshold = 5000
135+
ihc_v_sgn = _get_ihc_v_sgn_mask(seg, props, ratio_threshold, criterion="ratio")
136+
137+
@magic_factory(
138+
call_button="Update ratio threshold",
139+
threshold={"widget_type": "FloatSlider", "min": 1.0, "max": 5.0, "step": 0.1}
140+
)
141+
def update_ratio_threshold(threshold: float = ratio_threshold):
142+
ihc_v_sgn = _get_ihc_v_sgn_mask(seg, props, threshold, criterion="ratio")
143+
v.layers["ihc_v_sgn"].data = ihc_v_sgn
144+
145+
@magic_factory(
146+
call_button="Update size threshold",
147+
threshold={"widget_type": "FloatSlider", "min": 1000, "max": 20_000, "step": 100}
148+
)
149+
def update_size_threshold(threshold: float = size_threshold):
150+
ihc_v_sgn = _get_ihc_v_sgn_mask(seg, props, threshold, criterion="area")
151+
v.layers["ihc_v_sgn"].data = ihc_v_sgn
152+
153+
image_path = os.path.join(IHC_ROOT, "Myo7a/3.1L_Myo7a_apex_HCAT_reslice_C2.tif")
154+
im = imageio.imread(image_path)
155+
156+
v = napari.Viewer()
157+
v.add_image(im)
158+
v.add_labels(seg)
159+
v.add_labels(ihc_v_sgn)
160+
161+
ratio_widget = update_ratio_threshold()
162+
size_widget = update_size_threshold()
163+
v.window.add_dock_widget(ratio_widget, name="Ratio Threshold Slider")
164+
v.window.add_dock_widget(size_widget, name="Size Threshold Slider")
165+
166+
napari.run()
167+
168+
169+
def run_object_classifier():
170+
from flamingo_tools.classification import run_classification_gui
171+
172+
image_path = os.path.join(IHC_ROOT, "Myo7a/3.1L_Myo7a_apex_HCAT_reslice_C2.tif")
173+
seg_path = os.path.join(IHC_SEG, "Myo7a/3.1L_Myo7a_apex_HCAT_reslice_C2.tif")
174+
175+
run_classification_gui(image_path, seg_path, segmentation_name="IHCs")
176+
177+
178+
# From inspection:
179+
# - CR looks quite good, but also shows the blocking artifacts, and some merges:
180+
# Calretinin/M61L_CR_IHC_forannotations_C1.tif (blocking artifacts)
181+
# Calretinin/M63R_CR640_apexIHC_C2.tif (merges, but also weird looking stain)
182+
# Calretinin/M78L_CR488_apexIHC2_C6.tif (background structures are segmented)
183+
# Background is the case for some others too; it segments the hairs.
184+
# - Myo7a, looks good, but as we discussed the stain is not specific
185+
# Myo7a/3.1L_Myo7a_apex_HCAT_reslice_C2.tif (good candidate for filtering)
186+
# Myo7a/3.1L_Myo7a_mid_HCAT_reslice_C4.tif (good candidate for filtering)
187+
# - PV: Stain looks quite different, segmentations don't look so good.
188+
def main():
189+
# inspect_all_data()
190+
# check_block_artifacts()
191+
# try_filtering()
192+
run_object_classifier()
193+
194+
195+
if __name__ == "__main__":
196+
main()

0 commit comments

Comments
 (0)