Skip to content

Commit 80d205d

Browse files
Merge pull request #34 from computational-cell-analytics/rf-classification
Implement RF based classification
2 parents 3cdefa4 + 28f9fb7 commit 80d205d

22 files changed

+1030
-162
lines changed
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
from .classification_gui import run_classification_gui
2+
from .training_and_prediction import train_classifier, predict_classifier
Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,137 @@
1+
import os
2+
from multiprocessing import cpu_count
3+
from pathlib import Path
4+
from typing import Optional
5+
6+
import h5py
7+
import imageio.v3 as imageio
8+
import napari
9+
import numpy as np
10+
11+
from joblib import dump
12+
from magicgui import magic_factory
13+
14+
import micro_sam.sam_annotator.object_classifier as classifier_util
15+
from micro_sam.object_classification import project_prediction_to_segmentation
16+
from micro_sam.sam_annotator._widgets import _generate_message
17+
18+
from ..measurements import compute_object_measures_impl
19+
20+
# Module-level state shared between the napari widgets below.
# Everything starts out unset and is filled in by `run_classification_gui`
# and by the "Train and predict" widget.

# Names of the napari layers holding the image and the segmentation,
# and the name of the feature set; set by `run_classification_gui`.
IMAGE_LAYER_NAME = SEGMENTATION_LAYER_NAME = FEATURE_SET = None

# Per-object features, the matching segmentation ids and the accumulated
# labels; set by the training widget.
FEATURES = SEG_IDS = LABELS = None

# The most recently trained random forest; set by the training widget.
CLASSIFIER = None
27+
28+
29+
def _compute_features(segmentation, image):
    """Compute the per-object feature matrix for a segmentation.

    The measurements are computed with `compute_object_measures_impl` using the
    module-level `FEATURE_SET`.

    Args:
        segmentation: The segmentation to compute the object features for.
        image: The image data the measurements are derived from.

    Returns:
        The feature values (all columns of the measurement table except 'label_id').
        The ids from the 'label_id' column, cast to int.
    """
    measures = compute_object_measures_impl(image, segmentation, feature_set=FEATURE_SET)
    # The id column identifies the objects; it is not a feature itself.
    object_ids = measures.label_id.values.astype(int)
    feature_matrix = measures.drop(columns="label_id").values
    return feature_matrix, object_ids
34+
35+
36+
@magic_factory(call_button="Train and predict")
def _train_and_predict_rf_widget(viewer: "napari.viewer.Viewer") -> None:
    """Train a random forest on the current annotations and show its prediction.

    The labels are accumulated from the 'annotations' layer for the objects of the
    segmentation layer. The features are computed on the first call only and are
    re-used from the module-level cache afterwards. The trained forest and the
    labels are stored in module-level state so that the export widgets can access
    them, and the prediction is written to the 'prediction' layer.

    Args:
        viewer: The napari viewer that holds the image, segmentation,
            annotation and prediction layers.
    """
    global FEATURES, SEG_IDS, CLASSIFIER, LABELS

    annotations = viewer.layers["annotations"].data
    segmentation = viewer.layers[SEGMENTATION_LAYER_NAME].data
    LABELS = classifier_util._accumulate_labels(segmentation, annotations)

    # Features are computed once and cached for subsequent training rounds.
    if FEATURES is None:
        print("Computing features ...")
        FEATURES, SEG_IDS = _compute_features(segmentation, viewer.layers[IMAGE_LAYER_NAME].data)

    print("Training random forest ...")
    CLASSIFIER = classifier_util._train_rf(
        FEATURES, LABELS, n_estimators=200, max_depth=10, n_jobs=cpu_count()
    )

    # Predict for all objects and map the result back onto the segmentation.
    print("Run prediction ...")
    object_predictions = CLASSIFIER.predict(FEATURES)
    viewer.layers["prediction"].data = project_prediction_to_segmentation(
        segmentation, object_predictions, SEG_IDS
    )
59+
60+
61+
@magic_factory(call_button="Export Classifier")
def _create_export_rf_widget(export_path: Optional[Path] = None) -> None:
    """Save the most recently trained random forest to disk with joblib.

    Args:
        export_path: Where to save the classifier.

    Returns:
        The result of `_generate_message` if no classifier was trained yet or
        no export path was given, otherwise None.
    """
    if CLASSIFIER is None:
        return _generate_message("error", "You have not run training yet.")

    path_missing = export_path is None or export_path == ""
    if path_missing:
        return _generate_message("error", "You have to provide an export path.")

    # TODO: decide whether a file extension (e.g. '.joblib') should be enforced.
    dump(CLASSIFIER, export_path)
70+
71+
72+
@magic_factory(call_button="Export Features")
def _create_export_feature_widget(export_path: Optional[Path] = None) -> None:
    """Export the features and labels of the labeled objects to an hdf5 file.

    The data is written to a group named after the image layer, so that exports
    for several images can be collected in one file. The file is opened in append
    mode; exporting again for the same image replaces the earlier export of that image.

    Args:
        export_path: Where to save the features. The suffix is replaced by '.h5'.

    Returns:
        The result of `_generate_message` if training was not run yet or
        no export path was given, otherwise None.
    """
    if FEATURES is None or LABELS is None:
        return _generate_message("error", "You have not run training yet.")
    if export_path is None or export_path == "":
        return _generate_message("error", "You have to provide an export path.")

    # Keep only the entries with a non-zero label
    # (presumably 0 marks objects without an annotation).
    valid = LABELS != 0
    features, labels = FEATURES[valid], LABELS[valid]

    export_path = Path(export_path).with_suffix(".h5")
    with h5py.File(export_path, "a") as f:
        # `create_group` raises if the group already exists, which happens when
        # the features for this image were exported to the same file before.
        # Drop the stale group so that the latest annotations are exported.
        if IMAGE_LAYER_NAME in f:
            del f[IMAGE_LAYER_NAME]
        g = f.create_group(IMAGE_LAYER_NAME)
        g.attrs["feature_set"] = FEATURE_SET
        g.create_dataset("features", data=features, compression="lzf")
        g.create_dataset("labels", data=labels, compression="lzf")
89+
90+
91+
def run_classification_gui(
    image_path: str,
    segmentation_path: str,
    image_name: Optional[str] = None,
    segmentation_name: Optional[str] = None,
    feature_set: str = "default",
) -> None:
    """Open napari with the layers and widgets for interactive object classification.

    Loads the image and the segmentation, adds empty 'prediction' and 'annotations'
    label layers with the shape of the image, docks the training and export widgets
    and starts the napari event loop.

    Args:
        image_path: The path to the image data.
        segmentation_path: The path to the segmentation.
        image_name: The name for the image layer. Will use the filename if not given.
        segmentation_name: The name of the label layer with the segmentation.
            Will use the filename if not given.
        feature_set: The feature set to use. Refer to `flamingo_tools.measurements.FEATURE_FUNCTIONS` for details.
    """
    global IMAGE_LAYER_NAME, SEGMENTATION_LAYER_NAME, FEATURE_SET

    image = imageio.imread(image_path)
    segmentation = imageio.imread(segmentation_path)

    # Fall back to the file names for the layer names.
    if image_name is None:
        image_name = os.path.basename(image_path)
    if segmentation_name is None:
        segmentation_name = os.path.basename(segmentation_path)

    # Publish the settings for the widgets, which read them from module state.
    IMAGE_LAYER_NAME, SEGMENTATION_LAYER_NAME, FEATURE_SET = image_name, segmentation_name, feature_set

    viewer = napari.Viewer()
    viewer.add_image(image, name=image_name)
    viewer.add_labels(segmentation, name=segmentation_name)

    # Empty label layers for the classifier output and the user annotations.
    for layer_name in ("prediction", "annotations"):
        viewer.add_labels(name=layer_name, data=np.zeros(image.shape, dtype="uint8"))

    # Add the gui elements.
    train_widget = _train_and_predict_rf_widget()
    rf_export_widget = _create_export_rf_widget()
    feature_export_widget = _create_export_feature_widget()
    for widget in (train_widget, feature_export_widget, rf_export_widget):
        viewer.window.add_dock_widget(widget)

    napari.run()
Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
import multiprocessing as mp
2+
from typing import Optional, Sequence
3+
4+
import h5py
5+
import numpy as np
6+
import pandas as pd
7+
from joblib import dump, load
8+
from sklearn.ensemble import RandomForestClassifier
9+
10+
from ..measurements import compute_object_measures
11+
12+
13+
def train_classifier(feature_paths: Sequence[str], save_path: str, **rf_kwargs) -> None:
    """Train a random forest classifier on features and labels that were exported via the classification GUI.

    The features and labels of all groups in all given files are concatenated
    and used to fit a single random forest, which is then saved with joblib.

    Args:
        feature_paths: The path to the h5 files with features and labels.
        save_path: Where to save the trained random forest.
        rf_kwargs: Keyword arguments for creating the random forest.

    Raises:
        ValueError: If the given files do not contain any features.
    """
    features, labels = [], []
    for path in feature_paths:
        with h5py.File(path, "r") as f:
            # Only the group contents are needed, not the group names.
            for group in f.values():
                features.append(group["features"][:])
                labels.append(group["labels"][:])

    # Fail with a clear message instead of the opaque error from np.concatenate.
    if not features:
        raise ValueError("No features found: feature_paths is empty or the files contain no groups.")

    features = np.concatenate(features)
    labels = np.concatenate(labels)

    rf = RandomForestClassifier(**rf_kwargs)
    rf.fit(features, labels)

    dump(rf, save_path)
35+
36+
37+
def predict_classifier(
    rf_path: str,
    image_path: str,
    segmentation_path: str,
    feature_table_path: str,
    segmentation_table_path: Optional[str],
    image_key: Optional[str] = None,
    segmentation_key: Optional[str] = None,
    n_threads: Optional[int] = None,
    feature_set: str = "default",
) -> pd.DataFrame:
    """Run prediction with a trained classifier on an input volume with associated segmentation.

    Args:
        rf_path: The path to the trained random forest.
        image_path: The path to the image data.
        segmentation_path: The path to the segmentation.
        feature_table_path: The path for the features used for prediction.
            The features will be computed and saved if this table does not exist.
        segmentation_table_path: The path to the segmentation table (in MoBIE format).
            It will be computed on the fly if it is not given.
        image_key: The key / internal path for the image data. Not needed for tif data.
        segmentation_key: The key / internal path for the segmentation data. Not needed for tif data.
        n_threads: The number of threads for parallelization.
            Uses the number of available cpus if not given.
        feature_set: The feature set to use. Refer to `flamingo_tools.measurements.FEATURE_FUNCTIONS` for details.

    Returns:
        A dataframe with the prediction. It contains the columns 'label_id', 'prediction' and
        'probs-0', 'probs-1', ... . The latter columns contain the probabilities for the respective class.
    """
    compute_object_measures(
        image_path=image_path,
        segmentation_path=segmentation_path,
        segmentation_table_path=segmentation_table_path,
        output_table_path=feature_table_path,
        image_key=image_key,
        segmentation_key=segmentation_key,
        n_threads=n_threads,
        feature_set=feature_set,
    )

    # The feature table is tab separated; 'label_id' identifies the objects
    # and is not a feature itself.
    features = pd.read_csv(feature_table_path, sep="\t")
    label_ids = features.label_id.values
    features = features.drop(columns=["label_id"]).values

    # NOTE: joblib.load unpickles the file, so only load classifiers from trusted sources.
    rf = load(rf_path)
    n_threads = mp.cpu_count() if n_threads is None else n_threads
    # The parallelism of the forest is controlled by its `n_jobs` parameter.
    rf.n_jobs = n_threads

    probs = rf.predict_proba(features)
    # NOTE: 'prediction' is the index of the most probable column of `probs`,
    # i.e. an index into `rf.classes_`, not necessarily the class label itself.
    result = {"label_id": label_ids, "prediction": np.argmax(probs, axis=1)}
    # One probability column per class: 'probs-0', 'probs-1', ...
    result.update({f"probs-{i}": probs[:, i] for i in range(probs.shape[1])})
    return pd.DataFrame(result)

0 commit comments

Comments
 (0)