|
| 1 | +import multiprocessing as mp |
| 2 | +from typing import Optional, Sequence |
1 | 3 |
|
| 4 | +import h5py |
| 5 | +import numpy as np |
| 6 | +import pandas as pd |
| 7 | +from joblib import dump, load |
| 8 | +from sklearn.ensemble import RandomForestClassifier |
2 | 9 |
|
3 | | -# TODO train a classifier on all features and labels stored in h5 |
4 | | -def train_classifier(feature_paths): |
5 | | - pass |
| 10 | +from ..measurements import compute_object_measures |
6 | 11 |
|
7 | 12 |
|
8 | | -# TODO run prediction on a full cochlea |
9 | | -def predict_classifier(): |
10 | | - pass |
def train_classifier(feature_paths: Sequence[str], save_path: str, **rf_kwargs) -> None:
    """Train a random forest classifier on features and labels that were exported via the classification GUI.

    Each h5 file is expected to contain one group per annotated object set, with the
    datasets 'features' and 'labels' inside each group.

    Args:
        feature_paths: The path to the h5 files with features and labels.
        save_path: Where to save the trained random forest.
        rf_kwargs: Keyword arguments for creating the random forest.

    Raises:
        ValueError: If no features could be loaded from the given paths.
    """
    features, labels = [], []
    for path in feature_paths:
        with h5py.File(path, "r") as f:
            # The group names are irrelevant for training; only the data matters.
            for group in f.values():
                features.append(group["features"][:])
                labels.append(group["labels"][:])

    if not features:
        raise ValueError(f"Could not load any features from {feature_paths}.")

    features = np.concatenate(features)
    labels = np.concatenate(labels)

    rf = RandomForestClassifier(**rf_kwargs)
    rf.fit(features, labels)

    dump(rf, save_path)

| 36 | + |
def predict_classifier(
    rf_path: str,
    image_path: str,
    segmentation_path: str,
    feature_table_path: str,
    segmentation_table_path: Optional[str],
    image_key: Optional[str] = None,
    segmentation_key: Optional[str] = None,
    n_threads: Optional[int] = None,
    feature_set: str = "default",
) -> pd.DataFrame:
    """Run prediction with a trained classifier on an input volume with associated segmentation.

    Args:
        rf_path: The path to the trained random forest.
        image_path: The path to the image data.
        segmentation_path: The path to the segmentation.
        feature_table_path: The path for the features used for prediction.
            The features will be computed and saved if this table does not exist.
        segmentation_table_path: The path to the segmentation table (in MoBIE format).
            It will be computed on the fly if it is not given.
        image_key: The key / internal path for the image data. Not needed for tif data.
        segmentation_key: The key / internal path for the segmentation data. Not needed for tif data.
        n_threads: The number of threads for parallelization.
        feature_set: The feature set to use. Refer to `flamingo_tools.measurements.FEATURE_FUNCTIONS` for details.

    Returns:
        A dataframe with the prediction. It contains the columns 'label_id', 'prediction' and
        'probs-0', 'probs-1', ... . The latter columns contain the probabilities for the respective class.
    """
    compute_object_measures(
        image_path=image_path,
        segmentation_path=segmentation_path,
        segmentation_table_path=segmentation_table_path,
        output_table_path=feature_table_path,
        image_key=image_key,
        segmentation_key=segmentation_key,
        n_threads=n_threads,
        feature_set=feature_set,
    )

    features = pd.read_csv(feature_table_path, sep="\t")
    label_ids = features.label_id.values
    features = features.drop(columns=["label_id"]).values

    rf = load(rf_path)
    n_threads = mp.cpu_count() if n_threads is None else n_threads
    # Set the 'n_jobs' parameter (not 'n_jobs_', which sklearn would ignore)
    # so that predict_proba runs in parallel.
    rf.n_jobs = n_threads

    probs = rf.predict_proba(features)
    result = {"label_id": label_ids, "prediction": np.argmax(probs, axis=1)}
    # f-string is required here; a plain "probs-{i}" literal would collapse
    # all class columns onto a single key.
    result.update({f"probs-{i}": probs[:, i] for i in range(probs.shape[1])})
    return pd.DataFrame(result)
0 commit comments