Skip to content

Commit 1eda4e6

Browse files
Implement RF classification workflow
1 parent 20ac761 commit 1eda4e6

File tree

6 files changed

+292
-136
lines changed

6 files changed

+292
-136
lines changed
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,2 @@
11
from .classification_gui import run_classification_gui
2+
from .training_and_prediction import train_classifier, predict_classifier

flamingo_tools/classification/classification_gui.py

Lines changed: 26 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -7,38 +7,29 @@
77
import imageio.v3 as imageio
88
import napari
99
import numpy as np
10-
import pandas as pd
1110

1211
from joblib import dump
1312
from magicgui import magic_factory
14-
from skimage.measure import regionprops_table
1513

1614
import micro_sam.sam_annotator.object_classifier as classifier_util
1715
from micro_sam.object_classification import project_prediction_to_segmentation
1816
from micro_sam.sam_annotator._widgets import _generate_message
1917

18+
from ..measurements import compute_object_measures_impl
19+
2020
IMAGE_LAYER_NAME = None
2121
SEGMENTATION_LAYER_NAME = None
2222
FEATURES = None
2323
SEG_IDS = None
2424
CLASSIFIER = None
2525
LABELS = None
26+
FEATURE_SET = None
2627

2728

28-
# TODO refactor
2929
def _compute_features(segmentation, image):
30-
features = pd.DataFrame(regionprops_table(
31-
segmentation, image, properties=[
32-
"label", "area", "axis_major_length", "axis_minor_length",
33-
"equivalent_diameter_area", "euler_number", "extent",
34-
"feret_diameter_max", "inertia_tensor_eigvals",
35-
"intensity_max", "intensity_mean", "intensity_min",
36-
"intensity_std", "moments_central",
37-
"moments_weighted", "solidity",
38-
]
39-
))
40-
seg_ids = features.label.values.astype(int)
41-
features = features.drop(columns="label").values
30+
features = compute_object_measures_impl(image, segmentation, feature_set=FEATURE_SET)
31+
seg_ids = features.label_id.values.astype(int)
32+
features = features.drop(columns="label_id").values
4233
return features, seg_ids
4334

4435

@@ -92,12 +83,29 @@ def _create_export_feature_widget(export_path: Optional[Path] = None) -> None:
9283
export_path = Path(export_path).with_suffix(".h5")
9384
with h5py.File(export_path, "a") as f:
9485
g = f.create_group(IMAGE_LAYER_NAME)
86+
g.attrs["feature_set"] = FEATURE_SET
9587
g.create_dataset("features", data=features, compression="lzf")
9688
g.create_dataset("labels", data=labels, compression="lzf")
9789

9890

99-
def run_classification_gui(image_path, segmentation_path, image_name=None, segmentation_name=None):
100-
global IMAGE_LAYER_NAME, SEGMENTATION_LAYER_NAME
91+
def run_classification_gui(
92+
image_path: str,
93+
segmentation_path: str,
94+
image_name: Optional[str] = None,
95+
segmentation_name: Optional[str] = None,
96+
feature_set: str = "default",
97+
) -> None:
98+
"""Start the classification GUI.
99+
100+
Args:
101+
image_path: The path to the image data.
102+
segmentation_path: The path to the segmentation.
103+
image_name: The name for the image layer. Will use the filename if not given.
104+
segmentation_name: The name of the label layer with the segmentation.
105+
Will use the filename if not given.
106+
feature_set: The feature set to use. Refer to `flamingo_tools.measurements.FEATURE_FUNCTIONS` for details.
107+
"""
108+
global IMAGE_LAYER_NAME, SEGMENTATION_LAYER_NAME, FEATURE_SET
101109

102110
image = imageio.imread(image_path)
103111
segmentation = imageio.imread(segmentation_path)
@@ -107,6 +115,7 @@ def run_classification_gui(image_path, segmentation_path, image_name=None, segme
107115

108116
IMAGE_LAYER_NAME = image_name
109117
SEGMENTATION_LAYER_NAME = segmentation_name
118+
FEATURE_SET = feature_set
110119

111120
viewer = napari.Viewer()
112121
viewer.add_image(image, name=image_name)
Lines changed: 85 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,89 @@
1+
import multiprocessing as mp
2+
from typing import Optional, Sequence
13

4+
import h5py
5+
import numpy as np
6+
import pandas as pd
7+
from joblib import dump, load
8+
from sklearn.ensemble import RandomForestClassifier
29

3-
# TODO train a classifier on all features and labels stored in h5
4-
def train_classifier(feature_paths):
5-
pass
10+
from ..measurements import compute_object_measures
611

712

8-
# TODO run prediction on a full cochlea
9-
def predict_classifier():
10-
pass
13+
def train_classifier(feature_paths: Sequence[str], save_path: str, **rf_kwargs) -> None:
14+
"""Train a random forest classifier on features and labels that were exported via the classification GUI.
15+
16+
Args:
17+
feature_paths: The path to the h5 files with features and labels.
18+
save_path: Where to save the trained random forest.
19+
rf_kwargs: Keyword arguments for creating the random forest.
20+
"""
21+
features, labels = [], []
22+
for path in feature_paths:
23+
with h5py.File(path, "r") as f:
24+
for name, group in f.items():
25+
features.append(group["features"][:])
26+
labels.append(group["labels"][:])
27+
28+
features = np.concatenate(features)
29+
labels = np.concatenate(labels)
30+
31+
rf = RandomForestClassifier(**rf_kwargs)
32+
rf.fit(features, labels)
33+
34+
dump(rf, save_path)
35+
36+
37+
def predict_classifier(
38+
rf_path: str,
39+
image_path: str,
40+
segmentation_path: str,
41+
feature_table_path: str,
42+
segmentation_table_path: Optional[str],
43+
image_key: Optional[str] = None,
44+
segmentation_key: Optional[str] = None,
45+
n_threads: Optional[int] = None,
46+
feature_set: str = "default",
47+
) -> pd.DataFrame:
48+
"""Run prediction with a trained classifier on an input volume with associated segmentation.
49+
50+
Args:
51+
rf_path: The path to the trained random forest.
52+
image_path: The path to the image data.
53+
segmentation_path: The path to the segmentation.
54+
feature_table_path: The path for the features used for prediction.
55+
The features will be computed and saved if this table does not exist.
56+
segmentation_table_path: The path to the segmentation table (in MoBIE format).
57+
It will be computed on the fly if it is not given.
58+
image_key: The key / internal path for the image data. Not needed for tif data.
59+
segmentation_key: The key / internal path for the segmentation data. Not needed for tif data.
60+
n_threads: The number of threads for parallelization.
61+
feature_set: The feature set to use. Refer to `flamingo_tools.measurements.FEATURE_FUNCTIONS` for details.
62+
63+
Returns:
64+
A dataframe with the prediction. It contains the columns 'label_id', 'predictions' and
65+
'probs-0', 'probs-1', ... . The latter columns contain the probabilities for the respective class.
66+
"""
67+
compute_object_measures(
68+
image_path=image_path,
69+
segmentation_path=segmentation_path,
70+
segmentation_table_path=segmentation_table_path,
71+
output_table_path=feature_table_path,
72+
image_key=image_key,
73+
segmentation_key=segmentation_key,
74+
n_threads=n_threads,
75+
feature_set=feature_set,
76+
)
77+
78+
features = pd.read_csv(feature_table_path, sep="\t")
79+
label_ids = features.label_id.values
80+
features = features.drop(columns=["label_id"]).values
81+
82+
rf = load(rf_path)
83+
n_threads = mp.cpu_count() if n_threads is None else n_threads
84+
rf.n_jobs_ = n_threads
85+
86+
probs = rf.predict_proba(features)
87+
result = {"label_id": label_ids, "prediction": np.argmax(probs, axis=1)}
88+
result.update({"probs-{i}": probs[:, i] for i in range(probs.shape[1])})
89+
return pd.DataFrame(result)

0 commit comments

Comments
 (0)