diff --git a/flamingo_tools/classification/__init__.py b/flamingo_tools/classification/__init__.py new file mode 100644 index 0000000..20081b0 --- /dev/null +++ b/flamingo_tools/classification/__init__.py @@ -0,0 +1,2 @@ +from .classification_gui import run_classification_gui +from .training_and_prediction import train_classifier, predict_classifier diff --git a/flamingo_tools/classification/classification_gui.py b/flamingo_tools/classification/classification_gui.py new file mode 100644 index 0000000..a47f429 --- /dev/null +++ b/flamingo_tools/classification/classification_gui.py @@ -0,0 +1,137 @@ +import os +from multiprocessing import cpu_count +from pathlib import Path +from typing import Optional + +import h5py +import imageio.v3 as imageio +import napari +import numpy as np + +from joblib import dump +from magicgui import magic_factory + +import micro_sam.sam_annotator.object_classifier as classifier_util +from micro_sam.object_classification import project_prediction_to_segmentation +from micro_sam.sam_annotator._widgets import _generate_message + +from ..measurements import compute_object_measures_impl + +IMAGE_LAYER_NAME = None +SEGMENTATION_LAYER_NAME = None +FEATURES = None +SEG_IDS = None +CLASSIFIER = None +LABELS = None +FEATURE_SET = None + + +def _compute_features(segmentation, image): + features = compute_object_measures_impl(image, segmentation, feature_set=FEATURE_SET) + seg_ids = features.label_id.values.astype(int) + features = features.drop(columns="label_id").values + return features, seg_ids + + +@magic_factory(call_button="Train and predict") +def _train_and_predict_rf_widget(viewer: "napari.viewer.Viewer") -> None: + global FEATURES, SEG_IDS, CLASSIFIER, LABELS + + annotations = viewer.layers["annotations"].data + segmentation = viewer.layers[SEGMENTATION_LAYER_NAME].data + labels = classifier_util._accumulate_labels(segmentation, annotations) + LABELS = labels + + if FEATURES is None: + print("Computing features ...") + image = viewer.layers[IMAGE_LAYER_NAME].data + FEATURES, SEG_IDS = _compute_features(segmentation, image) + + print("Training random forest ...") + rf = classifier_util._train_rf(FEATURES, labels, n_estimators=200, max_depth=10, n_jobs=cpu_count()) + CLASSIFIER = rf + + # Run and set the prediction. + print("Run prediction ...") + pred = rf.predict(FEATURES) + prediction_data = project_prediction_to_segmentation(segmentation, pred, SEG_IDS) + viewer.layers["prediction"].data = prediction_data + + +@magic_factory(call_button="Export Classifier") +def _create_export_rf_widget(export_path: Optional[Path] = None) -> None: + rf = CLASSIFIER + if rf is None: + return _generate_message("error", "You have not run training yet.") + if export_path is None or export_path == "": + return _generate_message("error", "You have to provide an export path.") + # Do we add an extension? .joblib? 
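One way to resolve the extension question above (a sketch, not part of the patch; it mirrors the `.h5` handling in the feature export widget below and assumes the common joblib naming convention):

    export_path = Path(export_path).with_suffix(".joblib")
    dump(rf, export_path)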
+ dump(rf, export_path) + + +@magic_factory(call_button="Export Features") +def _create_export_feature_widget(export_path: Optional[Path] = None) -> None: + + if FEATURES is None or LABELS is None: + return _generate_message("error", "You have not run training yet.") + if export_path is None or export_path == "": + return _generate_message("error", "You have to provide an export path.") + + valid = LABELS != 0 + features, labels = FEATURES[valid], LABELS[valid] + + export_path = Path(export_path).with_suffix(".h5") + with h5py.File(export_path, "a") as f: + g = f.create_group(IMAGE_LAYER_NAME) + g.attrs["feature_set"] = FEATURE_SET + g.create_dataset("features", data=features, compression="lzf") + g.create_dataset("labels", data=labels, compression="lzf") + + +def run_classification_gui( + image_path: str, + segmentation_path: str, + image_name: Optional[str] = None, + segmentation_name: Optional[str] = None, + feature_set: str = "default", +) -> None: + """Start the classification GUI. + + Args: + image_path: The path to the image data. + segmentation_path: The path to the segmentation. + image_name: The name for the image layer. Will use the filename if not given. + segmentation_name: The name of the label layer with the segmentation. + Will use the filename if not given. + feature_set: The feature set to use. Refer to `flamingo_tools.measurements.FEATURE_FUNCTIONS` for details. + """ + global IMAGE_LAYER_NAME, SEGMENTATION_LAYER_NAME, FEATURE_SET + + image = imageio.imread(image_path) + segmentation = imageio.imread(segmentation_path) + + image_name = os.path.basename(image_path) if image_name is None else image_name + segmentation_name = os.path.basename(segmentation_path) if segmentation_name is None else segmentation_name + + IMAGE_LAYER_NAME = image_name + SEGMENTATION_LAYER_NAME = segmentation_name + FEATURE_SET = feature_set + + viewer = napari.Viewer() + viewer.add_image(image, name=image_name) + viewer.add_labels(segmentation, name=segmentation_name) + + shape = image.shape + viewer.add_labels(name="prediction", data=np.zeros(shape, dtype="uint8")) + viewer.add_labels(name="annotations", data=np.zeros(shape, dtype="uint8")) + + # Add the gui elements. + train_widget = _train_and_predict_rf_widget() + rf_export_widget = _create_export_rf_widget() + feature_export_widget = _create_export_feature_widget() + + viewer.window.add_dock_widget(train_widget) + viewer.window.add_dock_widget(feature_export_widget) + viewer.window.add_dock_widget(rf_export_widget) + + napari.run() diff --git a/flamingo_tools/classification/training_and_prediction.py b/flamingo_tools/classification/training_and_prediction.py new file mode 100644 index 0000000..72fdf0f --- /dev/null +++ b/flamingo_tools/classification/training_and_prediction.py @@ -0,0 +1,89 @@ +import multiprocessing as mp +from typing import Optional, Sequence + +import h5py +import numpy as np +import pandas as pd +from joblib import dump, load +from sklearn.ensemble import RandomForestClassifier + +from ..measurements import compute_object_measures + + +def train_classifier(feature_paths: Sequence[str], save_path: str, **rf_kwargs) -> None: + """Train a random forest classifier on features and labels that were exported via the classification GUI. + + Args: + feature_paths: The path to the h5 files with features and labels. + save_path: Where to save the trained random forest. + rf_kwargs: Keyword arguments for creating the random forest. 
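For reference, a minimal usage sketch of train_classifier (the feature files would come from the "Export Features" widget above; the file names and forest parameters here are hypothetical):

    train_classifier(
        ["features_sample1.h5", "features_sample2.h5"],
        save_path="ihc_classifier.joblib",
        n_estimators=200, max_depth=10,
    )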
+    """
+    features, labels = [], []
+    for path in feature_paths:
+        with h5py.File(path, "r") as f:
+            for name, group in f.items():
+                features.append(group["features"][:])
+                labels.append(group["labels"][:])
+
+    features = np.concatenate(features)
+    labels = np.concatenate(labels)
+
+    rf = RandomForestClassifier(**rf_kwargs)
+    rf.fit(features, labels)
+
+    dump(rf, save_path)
+
+
+def predict_classifier(
+    rf_path: str,
+    image_path: str,
+    segmentation_path: str,
+    feature_table_path: str,
+    segmentation_table_path: Optional[str],
+    image_key: Optional[str] = None,
+    segmentation_key: Optional[str] = None,
+    n_threads: Optional[int] = None,
+    feature_set: str = "default",
+) -> pd.DataFrame:
+    """Run prediction with a trained classifier on an input volume with associated segmentation.
+
+    Args:
+        rf_path: The path to the trained random forest.
+        image_path: The path to the image data.
+        segmentation_path: The path to the segmentation.
+        feature_table_path: The path for the features used for prediction.
+            The features will be computed and saved if this table does not exist.
+        segmentation_table_path: The path to the segmentation table (in MoBIE format).
+            It will be computed on the fly if it is not given.
+        image_key: The key / internal path for the image data. Not needed for tif data.
+        segmentation_key: The key / internal path for the segmentation data. Not needed for tif data.
+        n_threads: The number of threads for parallelization.
+        feature_set: The feature set to use. Refer to `flamingo_tools.measurements.FEATURE_FUNCTIONS` for details.
+
+    Returns:
+        A dataframe with the prediction. It contains the columns 'label_id', 'prediction' and
+        'probs-0', 'probs-1', ... . The latter columns contain the probabilities for the respective class.
+    """
+    compute_object_measures(
+        image_path=image_path,
+        segmentation_path=segmentation_path,
+        segmentation_table_path=segmentation_table_path,
+        output_table_path=feature_table_path,
+        image_key=image_key,
+        segmentation_key=segmentation_key,
+        n_threads=n_threads,
+        feature_set=feature_set,
+    )
+
+    features = pd.read_csv(feature_table_path, sep="\t")
+    label_ids = features.label_id.values
+    features = features.drop(columns=["label_id"]).values
+
+    rf = load(rf_path)
+    n_threads = mp.cpu_count() if n_threads is None else n_threads
+    rf.n_jobs = n_threads
+
+    probs = rf.predict_proba(features)
+    result = {"label_id": label_ids, "prediction": np.argmax(probs, axis=1)}
+    result.update({f"probs-{i}": probs[:, i] for i in range(probs.shape[1])})
+    return pd.DataFrame(result)
diff --git a/flamingo_tools/measurements.py b/flamingo_tools/measurements.py
index 442b4fe..a7ed2d4 100644
--- a/flamingo_tools/measurements.py
+++ b/flamingo_tools/measurements.py
@@ -1,11 +1,13 @@
 import multiprocessing as mp
+import os
 from concurrent import futures
+from functools import partial
 from typing import Optional
 
 import numpy as np
 import pandas as pd
 import trimesh
-from skimage.measure import marching_cubes
+from skimage.measure import marching_cubes, regionprops_table
 from tqdm import tqdm
 
 from .file_utils import read_image_data
@@ -26,12 +28,98 @@ def _measure_volume_and_surface(mask, resolution):
     return volume, surface
+
+def _get_bounding_box(table, seg_id, resolution, shape):
+    row = table[table.label_id == seg_id]
+
+    bb_min = np.array([
+        row.bb_min_z.item(), row.bb_min_y.item(), row.bb_min_x.item()
+    ]).astype("float32") / resolution
+    bb_min = np.round(bb_min, 0).astype("int32")
+
+    bb_max = np.array([
+        row.bb_max_z.item(), row.bb_max_y.item(), row.bb_max_x.item()
+    ]).astype("float32") / resolution
+    bb_max = np.round(bb_max, 0).astype("int32")
+
+    bb = tuple(
+        slice(max(bmin - 1, 0), min(bmax + 1, sh))
+        for bmin, bmax, sh in zip(bb_min, bb_max, shape)
+    )
+    return bb
+
+
+def _default_object_features(seg_id, table, image, segmentation, resolution):
+    bb = _get_bounding_box(table, seg_id, resolution, image.shape)
+
+    local_image = image[bb]
+    mask = segmentation[bb] == seg_id
+    assert mask.sum() > 0, f"Segmentation ID {seg_id} is empty."
+    masked_intensity = local_image[mask]
+
+    # Do the base intensity measurements.
+    measures = {
+        "label_id": seg_id,
+        "mean": np.mean(masked_intensity),
+        "stdev": np.std(masked_intensity),
+        "min": np.min(masked_intensity),
+        "max": np.max(masked_intensity),
+        "median": np.median(masked_intensity),
+    }
+    for percentile in (5, 10, 25, 75, 90, 95):
+        measures[f"percentile-{percentile}"] = np.percentile(masked_intensity, percentile)
+
+    # Do the volume and surface measurement.
+    volume, surface = _measure_volume_and_surface(mask, resolution)
+    measures["volume"] = volume
+    measures["surface"] = surface
+    return measures
+
+
+def _regionprops_features(seg_id, table, image, segmentation, resolution):
+    bb = _get_bounding_box(table, seg_id, resolution, image.shape)
+
+    local_image = image[bb]
+    local_segmentation = segmentation[bb].copy()  # Copy to avoid modifying the input segmentation.
+    mask = local_segmentation == seg_id
+    assert mask.sum() > 0, f"Segmentation ID {seg_id} is empty."
+    local_segmentation[~mask] = 0
+
+    features = regionprops_table(
+        local_segmentation, local_image, properties=[
+            "label", "area", "axis_major_length", "axis_minor_length",
+            "equivalent_diameter_area", "euler_number", "extent",
+            "feret_diameter_max", "inertia_tensor_eigvals",
+            "intensity_max", "intensity_mean", "intensity_min",
+            "intensity_std", "moments_central",
+            "moments_weighted", "solidity",
+        ]
+    )
+
+    features["label_id"] = features.pop("label")
+    return features
+
+
+# Maybe also support:
+# - spherical harmonics
+# - line profiles
+FEATURE_FUNCTIONS = {
+    "default": _default_object_features,
+    "skimage": _regionprops_features,
+}
+"""The different feature functions that are supported in `compute_object_measures` and
+that can be selected via the feature_set argument. Currently this supports:
+- 'default': The default features which compute standard intensity statistics and volume + surface.
+- 'skimage': The scikit-image regionprops features.
+"""
+
+
 def compute_object_measures_impl(
     image: np.typing.ArrayLike,
     segmentation: np.typing.ArrayLike,
     n_threads: Optional[int] = None,
     resolution: float = 0.38,
     table: Optional[pd.DataFrame] = None,
+    feature_set: str = "default",
 ) -> pd.DataFrame:
     """Compute simple intensity and morphology measures for each segmented cell in a segmentation.
@@ -43,6 +131,7 @@
         n_threads: The number of threads to use for computation.
         resolution: The resolution / voxel size of the data.
         table: The segmentation table. Will be computed on the fly if it is not given.
+        feature_set: The features to compute for each object. Refer to `FEATURE_FUNCTIONS` for details.
 
     Returns:
         The table with per object measurements.
@@ -50,54 +139,22 @@
     if table is None:
         table = compute_table_on_the_fly(segmentation, resolution=resolution)
 
-    def intensity_measures(seg_id):
-        # Get the bounding box.
-        row = table[table.label_id == seg_id]
-
-        bb_min = np.array([
-            row.bb_min_z.item(), row.bb_min_y.item(), row.bb_min_x.item()
-        ]).astype("float32") / resolution
-        bb_min = np.round(bb_min, 0).astype("int32")
-
-        bb_max = np.array([
-            row.bb_max_z.item(), row.bb_max_y.item(), row.bb_max_x.item()
-        ]).astype("float32") / resolution
-        bb_max = np.round(bb_max, 0).astype("int32")
-
-        bb = tuple(
-            slice(max(bmin - 1, 0), min(bmax + 1, sh))
-            for bmin, bmax, sh in zip(bb_min, bb_max, image.shape)
-        )
-
-        local_image = image[bb]
-        mask = segmentation[bb] == seg_id
-        assert mask.sum() > 0, f"Segmentation ID {seg_id} is empty."
-        masked_intensity = local_image[mask]
-
-        # Do the base intensity measurements.
-        measures = {
-            "label_id": seg_id,
-            "mean": np.mean(masked_intensity),
-            "stdev": np.std(masked_intensity),
-            "min": np.min(masked_intensity),
-            "max": np.max(masked_intensity),
-            "median": np.median(masked_intensity),
-        }
-        for percentile in (5, 10, 25, 75, 90, 95):
-            measures[f"percentile-{percentile}"] = np.percentile(masked_intensity, percentile)
-
-        # Do the volume and surface measurement.
-        volume, surface = _measure_volume_and_surface(mask, resolution)
-        measures["volume"] = volume
-        measures["surface"] = surface
-        return measures
+    if feature_set not in FEATURE_FUNCTIONS:
+        raise ValueError(f"Invalid feature_set {feature_set}; expected one of {list(FEATURE_FUNCTIONS)}.")
+    measure_function = partial(
+        FEATURE_FUNCTIONS[feature_set],
+        table=table,
+        image=image,
+        segmentation=segmentation,
+        resolution=resolution
+    )
 
     seg_ids = table.label_id.values
     assert len(seg_ids) > 0, "The segmentation table is empty."
     n_threads = mp.cpu_count() if n_threads is None else n_threads
     with futures.ThreadPoolExecutor(n_threads) as pool:
         measures = list(tqdm(
-            pool.map(intensity_measures, seg_ids), total=len(seg_ids), desc="Compute intensity measures"
+            pool.map(measure_function, seg_ids), total=len(seg_ids), desc="Compute object measures"
         ))
 
     # Create the result table and save it.
@@ -110,18 +167,21 @@ def intensity_measures(seg_id):
 def compute_object_measures(
     image_path: str,
     segmentation_path: str,
-    segmentation_table_path: str,
+    segmentation_table_path: Optional[str],
     output_table_path: str,
     image_key: Optional[str] = None,
     segmentation_key: Optional[str] = None,
     n_threads: Optional[int] = None,
     resolution: float = 0.38,
+    force: bool = False,
+    feature_set: str = "default",
 ) -> None:
     """Compute simple intensity and morphology measures for each segmented cell in a segmentation.
 
-    This computes the mean, standard deviation, minimum, maximum, median and
+    By default, this computes the mean, standard deviation, minimum, maximum, median and
     5th, 10th, 25th, 75th, 90th and 95th percentile of the intensity image per cell,
     as well as the volume and surface.
+    Other measurements can be computed by changing the feature_set argument.
 
     Args:
         image_path: The filepath to the image data. Either a tif or hdf5/zarr/n5 file.
@@ -132,15 +192,23 @@ def compute_object_measures(
         segmentation_key: The key (= internal path) for the segmentation data. Not needed for tif.
         n_threads: The number of threads to use for computation.
         resolution: The resolution / voxel size of the data.
+        force: Whether to overwrite an existing output table.
+        feature_set: The features to compute for each object. Refer to `FEATURE_FUNCTIONS` for details.
     """
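A minimal usage sketch of the extended compute_object_measures (paths are hypothetical; with segmentation_table_path=None the segmentation table is computed on the fly, and feature_set="skimage" selects the regionprops features defined above):

    compute_object_measures(
        image_path="image.tif",
        segmentation_path="segmentation.tif",
        segmentation_table_path=None,
        output_table_path="object_features.tsv",
        feature_set="skimage",
        n_threads=8,
    )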
+    if os.path.exists(output_table_path) and not force:
+        return
+
     # First, we load the pre-computed segmentation table from MoBIE.
-    table = pd.read_csv(segmentation_table_path, sep="\t")
+    if segmentation_table_path is None:
+        table = None
+    else:
+        table = pd.read_csv(segmentation_table_path, sep="\t")
 
     # Then, open the volumes.
     image = read_image_data(image_path, image_key)
     segmentation = read_image_data(segmentation_path, segmentation_key)
 
     measures = compute_object_measures_impl(
-        image, segmentation, n_threads, resolution, table=table
+        image, segmentation, n_threads, resolution, table=table, feature_set=feature_set,
     )
     measures.to_csv(output_table_path, sep="\t", index=False)
diff --git a/flamingo_tools/segmentation/nucleus_segmentation.py b/flamingo_tools/segmentation/nucleus_segmentation.py
new file mode 100644
index 0000000..a315fc0
--- /dev/null
+++ b/flamingo_tools/segmentation/nucleus_segmentation.py
@@ -0,0 +1,98 @@
+from concurrent import futures
+from multiprocessing import cpu_count
+from typing import Optional
+
+import numpy as np
+import pandas as pd
+from elf.io import open_file
+from scipy.ndimage import binary_opening
+from skimage.filters import gaussian, threshold_otsu
+from skimage.measure import label
+from tqdm import tqdm
+
+from ..file_utils import read_image_data
+from ..measurements import _get_bounding_box
+from .postprocessing import compute_table_on_the_fly
+
+
+def _naive_nucleus_segmentation_impl(image, segmentation, table, output, n_threads, resolution):
+    opening_iterations = 3
+
+    # Compute the table on the fly if it wasn't given.
+    if table is None:
+        table = compute_table_on_the_fly(segmentation, resolution=resolution)
+
+    def segment_nucleus(seg_id):
+        bb = _get_bounding_box(table, seg_id, resolution, image.shape)
+        image_local, seg_local = image[bb], segmentation[bb]
+        mask = seg_local == seg_id
+
+        # Smooth before computing the threshold.
+        image_local = gaussian(image_local)
+        # Compute threshold only in the mask.
+        threshold = threshold_otsu(image_local[mask])
+
+        nucleus_mask = np.logical_and(image_local < threshold, mask)
+        nucleus_mask = label(nucleus_mask)
+        ids, sizes = np.unique(nucleus_mask, return_counts=True)
+        ids, sizes = ids[1:], sizes[1:]
+        nucleus_mask = (nucleus_mask == ids[np.argmax(sizes)])
+        nucleus_mask = binary_opening(nucleus_mask, iterations=opening_iterations)
+        # Write the block back explicitly, so that this also works for hdf5 / zarr outputs.
+        nucleus_seg = output[bb]
+        nucleus_seg[nucleus_mask] = seg_id
+        output[bb] = nucleus_seg
+
+    n_threads = cpu_count() if n_threads is None else n_threads
+    seg_ids = table.label_id.values
+    with futures.ThreadPoolExecutor(n_threads) as tp:
+        list(tqdm(tp.map(segment_nucleus, seg_ids), total=len(seg_ids), desc="Segment nuclei"))
+
+    return output
+
+
+def naive_nucleus_segmentation(
+    image_path: str,
+    segmentation_path: str,
+    segmentation_table_path: Optional[str],
+    output_path: str,
+    output_key: str,
+    image_key: Optional[str] = None,
+    segmentation_key: Optional[str] = None,
+    n_threads: Optional[int] = None,
+    resolution: float = 0.38,
+):
+    """Segment nuclei per object with an Otsu threshold.
+
+    This assumes that the nucleus is stained significantly less than the rest of the cell.
+
+    Args:
+        image_path: The filepath to the image data. Either a tif or hdf5/zarr/n5 file.
+        segmentation_path: The filepath to the segmentation data. Either a tif or hdf5/zarr/n5 file.
+        segmentation_table_path: The path to the segmentation table in MoBIE format.
+        output_path: The path for saving the nucleus segmentation.
+        output_key: The key for saving the nucleus segmentation.
+        image_key: The key (= internal path) for the image data. Not needed for tif.
+        segmentation_key: The key (= internal path) for the segmentation data. Not needed for tif.
+ n_threads: The number of threads to use for computation. + resolution: The resolution / voxel size of the data. + """ + # First, we load the pre-computed segmentation table from MoBIE. + if segmentation_table_path is None: + table = None + else: + table = pd.read_csv(segmentation_table_path, sep="\t") + + # Then, open the volumes. + image = read_image_data(image_path, image_key) + segmentation = read_image_data(segmentation_path, segmentation_key) + + # Create the output volume. + with open_file(output_path, mode="a") as f: + output = f.create_dataset( + output_key, shape=segmentation.shape, dtype=segmentation.dtype, compression="gzip", + chunks=segmentation.chunks + ) + + # And run the nucleus segmentation. + _naive_nucleus_segmentation_impl(image, segmentation, table, output, n_threads, resolution) diff --git a/flamingo_tools/validation.py b/flamingo_tools/validation.py index 6248549..2165d7a 100644 --- a/flamingo_tools/validation.py +++ b/flamingo_tools/validation.py @@ -35,12 +35,10 @@ def parse_annotation_path(annotation_path): return cochlea, slice_id -# TODO enable table component filtering with MoBIE table -# NOTE: the main component is always #1 def fetch_data_for_evaluation( annotation_path: str, cache_path: Optional[str] = None, - seg_name: str = "SGN", + seg_name: str = "SGN_v2", z_extent: int = 0, components_for_postprocessing: Optional[List[int]] = None, ) -> Tuple[np.ndarray, pd.DataFrame]: diff --git a/scripts/ihc_analysis/analyze_myo7a.py b/scripts/ihc_analysis/analyze_myo7a.py new file mode 100644 index 0000000..92218a8 --- /dev/null +++ b/scripts/ihc_analysis/analyze_myo7a.py @@ -0,0 +1,64 @@ +import os +from check_ihc_seg import IHC_ROOT, IHC_SEG + + +def run_object_classifier(): + from flamingo_tools.classification import run_classification_gui + + image_path = os.path.join(IHC_ROOT, "Myo7a/3.1L_Myo7a_apex_HCAT_reslice_C2.tif") + seg_path = os.path.join(IHC_SEG, "Myo7a/3.1L_Myo7a_apex_HCAT_reslice_C2.tif") + + run_classification_gui(image_path, seg_path, segmentation_name="IHCs") + + +def train_random_forest(): + from flamingo_tools.classification import train_classifier + + feature_path = "data/features.h5" + save_path = "data/rf.joblib" + train_classifier([feature_path], save_path=save_path) + + +def apply_random_forest(): + from flamingo_tools.classification import predict_classifier + + image_path = os.path.join(IHC_ROOT, "Myo7a/3.1L_Myo7a_mid_HCAT_reslice_C4.tif") + seg_path = os.path.join(IHC_SEG, "Myo7a/3.1L_Myo7a_mid_HCAT_reslice_C4.tif") + + rf_path = "data/rf.joblib" + results = predict_classifier( + rf_path, image_path, seg_path, feature_table_path="data/features.csv", + segmentation_table_path=None, n_threads=4, + ) + + import imageio.v3 as imageio + import napari + import nifty.tools as nt + + image = imageio.imread(image_path) + seg = imageio.imread(seg_path) + + relabel_dict = {label_id: pred + 1 for label_id, pred in zip(results.label_id, results.prediction)} + relabel_dict[0] = 0 + pred = nt.takeDict(relabel_dict, seg) + + v = napari.Viewer() + v.add_image(image) + v.add_labels(seg) + v.add_labels(pred) + napari.run() + + +def main(): + # 1.) Start the classifier GUI to extract features for a random forest. + # run_object_classifier() + + # 2.) Train a random forest on the features. + # train_random_forest() + + # 3.) Apply the random forest to another dataset. 
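The prediction table returned by predict_classifier in step 3 can also be used to filter the segmentation directly. A sketch (not part of this script; it assumes `results` and `seg` as loaded in apply_random_forest above, numpy imported as np, and that column index 1, i.e. the second annotated class, marks the objects to keep, which depends on how the annotations were drawn):

    keep_ids = results.label_id[results.prediction == 1].values
    filtered_seg = np.where(np.isin(seg, keep_ids), seg, 0)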
+    apply_random_forest()
+
+
+if __name__ == "__main__":
+    main()
diff --git a/scripts/ihc_analysis/check_ihc_seg.py b/scripts/ihc_analysis/check_ihc_seg.py
new file mode 100644
index 0000000..f592478
--- /dev/null
+++ b/scripts/ihc_analysis/check_ihc_seg.py
@@ -0,0 +1,135 @@
+import os
+from glob import glob
+
+import h5py
+import imageio.v3 as imageio
+import napari
+import numpy as np
+
+IHC_ROOT = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/croppings/IHC_crop"
+IHC_SEG = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/croppings/IHC_seg"
+
+
+def inspect_all_data():
+
+    images = sorted(glob(os.path.join(IHC_ROOT, "**/*.tif"), recursive=True))
+    segmentations = sorted(glob(os.path.join(IHC_SEG, "**/*.tif"), recursive=True))
+
+    skip_names = ["Calretinin"]
+
+    for im_path, seg_path in zip(images, segmentations):
+        print("Loading", im_path)
+        root, fname = os.path.split(im_path)
+        folder = os.path.basename(root)
+        if folder in skip_names:
+            continue
+
+        try:
+            im = imageio.imread(im_path)
+            seg = imageio.imread(seg_path).astype("uint32")
+
+            v = napari.Viewer()
+            v.add_image(im)
+            v.add_labels(seg)
+            v.title = f"{folder}/{fname}"
+            napari.run()
+        except ValueError:
+            continue
+
+
+def _require_prediction(im, image_path, with_mask):
+    model_path = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/trained_models/IHC/v2_cochlea_distance_unet_IHC_supervised_2025-05-21" # noqa
+
+    root, fname = os.path.split(image_path)
+    folder = os.path.basename(root)
+
+    cache_path = f"/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/croppings/predictions/{folder}"
+    os.makedirs(cache_path, exist_ok=True)
+    cache_path = os.path.join(cache_path, fname.replace(".tif", ".h5"))
+
+    output_key = "pred_masked" if with_mask else "pred"
+
+    if os.path.exists(cache_path):
+        with h5py.File(cache_path, "r") as f:
+            if output_key in f:
+                pred = f[output_key][:]
+                return pred
+
+    from torch_em.util import load_model
+    from torch_em.util.prediction import predict_with_halo
+    from torch_em.transform.raw import standardize
+
+    block_shape = (128, 128, 128)
+    halo = (16, 32, 32)
+    if with_mask:
+        import nifty.tools as nt
+
+        mask = np.zeros(im.shape, dtype=bool)
+        blocking = nt.blocking([0, 0, 0], im.shape, block_shape)
+
+        for block_id in range(blocking.numberOfBlocks):
+            block = blocking.getBlock(block_id)
+            bb = tuple(slice(beg, end) for beg, end in zip(block.begin, block.end))
+            data = im[bb]
+            max_ = np.percentile(data, 95)
+            if max_ > 200:
+                mask[bb] = 1
+    else:
+        mask = None
+
+    im = standardize(im)
+
+    model = load_model(model_path)
+
+    pred = predict_with_halo(
+        im, model, gpu_ids=[0], block_shape=block_shape, halo=halo, preprocess=None, mask=mask
+    )
+
+    with h5py.File(cache_path, "a") as f:
+        f.create_dataset(output_key, data=pred, compression="lzf")
+    return pred
+
+
+def check_block_artifacts():
+    image_path = os.path.join(IHC_ROOT, "Calretinin/M61L_CR_IHC_forannotations_C1.tif")
+    im = imageio.imread(image_path)
+    predictions = _require_prediction(im, image_path, with_mask=False)
+
+    seg_path = os.path.join(IHC_SEG, "Calretinin/M61L_CR_IHC_forannotations_C1.tif")
+    seg_old = imageio.imread(seg_path)
+
+    v = napari.Viewer()
+    v.add_image(im)
+    v.add_image(predictions)
+    v.add_labels(seg_old)
+    napari.run()
+
+
+def _get_ihc_v_sgn_mask(seg, props, threshold, criterion="ratio"):
+    sgn_ids = props.label[props[criterion] < threshold].values
+    ihc_ids = props.label[props[criterion] >= threshold].values
+
+    ihc_v_sgn = np.zeros_like(seg, dtype="uint32")
+
ihc_v_sgn[np.isin(seg, ihc_ids)] = 1 + ihc_v_sgn[np.isin(seg, sgn_ids)] = 2 + + return ihc_v_sgn + + +# From inspection: +# - CR looks quite good, but also shows the blocking artifacts, and some merges: +# Calretinin/M61L_CR_IHC_forannotations_C1.tif (blocking artifacts) +# Calretinin/M63R_CR640_apexIHC_C2.tif (merges, but also weird looking stain) +# Calretinin/M78L_CR488_apexIHC2_C6.tif (background structures are segmented) +# Background is the case for some others too; it segments the hairs. +# - Myo7a, looks good, but as we discussed the stain is not specific +# Myo7a/3.1L_Myo7a_apex_HCAT_reslice_C2.tif (good candidate for filtering) +# Myo7a/3.1L_Myo7a_mid_HCAT_reslice_C4.tif (good candidate for filtering) +# - PV: Stain looks quite different, segmentations don't look so good. +def main(): + inspect_all_data() + # check_block_artifacts() + + +if __name__ == "__main__": + main() diff --git a/scripts/sgn_stain_predictions/check_segmentation.py b/scripts/sgn_stain_predictions/check_segmentation.py index 3a4799c..24a77ec 100644 --- a/scripts/sgn_stain_predictions/check_segmentation.py +++ b/scripts/sgn_stain_predictions/check_segmentation.py @@ -6,7 +6,8 @@ ROOT = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/LS_sampleprepcomparison_crops" -SAVE_ROOT = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/LS_sampleprepcomparison_crops/segmentations" +SAVE_ROOT1 = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/LS_sampleprepcomparison_crops/segmentations" # noqa +SAVE_ROOT2 = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/LS_sampleprepcomparison_crops/segmentations_v2" # noqa def main(): @@ -16,18 +17,22 @@ def main(): return print("Visualizing", ff) rel_path = os.path.relpath(ff, ROOT) - seg_path = os.path.join(SAVE_ROOT, rel_path) + seg_path1 = os.path.join(SAVE_ROOT1, rel_path) + seg_path2 = os.path.join(SAVE_ROOT2, rel_path) + print("Load raw") image = imageio.imread(ff) - if os.path.exists(seg_path): - seg = imageio.imread(seg_path) - else: - seg = None + print("Load segmentation 1") + seg1 = imageio.imread(seg_path1) if os.path.exists(seg_path1) else None + print("Load segmentation 2") + seg2 = imageio.imread(seg_path2) if os.path.exists(seg_path2) else None v = napari.Viewer() v.add_image(image) - if seg is not None: - v.add_labels(seg) + if seg1 is not None: + v.add_labels(seg1, name="original") + if seg2 is not None: + v.add_labels(seg2, name="adapted") napari.run() diff --git a/scripts/sgn_stain_predictions/measure_intensities.py b/scripts/sgn_stain_predictions/measure_intensities.py index 99f67ac..0745411 100644 --- a/scripts/sgn_stain_predictions/measure_intensities.py +++ b/scripts/sgn_stain_predictions/measure_intensities.py @@ -6,12 +6,12 @@ ROOT = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/LS_sampleprepcomparison_crops" -SAVE_ROOT = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/LS_sampleprepcomparison_crops/segmentations" +SAVE_ROOT = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/LS_sampleprepcomparison_crops/segmentations_v2" # noqa def measure_intensities(ff): rel_path = os.path.relpath(ff, ROOT) - out_path = os.path.join("./measurements", rel_path.replace(".tif", ".xlsx")) + out_path = os.path.join("./measurements_v2", rel_path.replace(".tif", ".xlsx")) if os.path.exists(out_path): return diff --git a/scripts/sgn_stain_predictions/run_prediction.py b/scripts/sgn_stain_predictions/run_prediction.py index 1badae3..c46d5f5 100644 --- 
a/scripts/sgn_stain_predictions/run_prediction.py +++ b/scripts/sgn_stain_predictions/run_prediction.py @@ -7,9 +7,10 @@ from flamingo_tools.segmentation import run_unet_prediction ROOT = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/LS_sampleprepcomparison_crops" -MODEL_PATH = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/trained_models/SGN/cochlea_distance_unet_SGN_March2025Model" # noqa +# MODEL_PATH = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/trained_models/SGN/cochlea_distance_unet_SGN_March2025Model" # noqa +MODEL_PATH = "/mnt/vast-nhr/home/pape41/u12086/Work/my_projects/flamingo-tools/scripts/training/sgn_model.pt" # noqa -SAVE_ROOT = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/LS_sampleprepcomparison_crops/segmentations" +SAVE_ROOT = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/LS_sampleprepcomparison_crops/segmentations_v2" # noqa def check_data(): diff --git a/scripts/validation/.gitignore b/scripts/validation/.gitignore new file mode 100644 index 0000000..79ff55b --- /dev/null +++ b/scripts/validation/.gitignore @@ -0,0 +1,2 @@ +cache/ +results/ diff --git a/scripts/validation/IHCs/run_evaluation.py b/scripts/validation/IHCs/run_evaluation.py new file mode 100644 index 0000000..15c2afe --- /dev/null +++ b/scripts/validation/IHCs/run_evaluation.py @@ -0,0 +1,66 @@ +import os +from glob import glob + +import pandas as pd +from flamingo_tools.validation import ( + fetch_data_for_evaluation, parse_annotation_path, compute_scores_for_annotated_slice +) + +ROOT = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/AnnotatedImageCrops/F1ValidationIHCs" +ANNOTATION_FOLDERS = ["Annotations_LR"] + + +def run_evaluation(root, annotation_folders, result_file, cache_folder): + results = { + "annotator": [], + "cochlea": [], + "slice": [], + "tps": [], + "fps": [], + "fns": [], + } + + if cache_folder is not None: + os.makedirs(cache_folder, exist_ok=True) + + for folder in annotation_folders: + annotator = folder[len("Annotations"):] + annotations = sorted(glob(os.path.join(root, folder, "*.csv"))) + for annotation_path in annotations: + print(annotation_path) + cochlea, slice_id = parse_annotation_path(annotation_path) + + # For the cochlea M_LR_000226_R the actual component is 2, not 1 + component = 2 if "226_R" in cochlea else 1 + print("Run evaluation for", annotator, cochlea, "z=", slice_id) + segmentation, annotations = fetch_data_for_evaluation( + annotation_path, components_for_postprocessing=[component], + seg_name="IHC_v2", + cache_path=None if cache_folder is None else os.path.join(cache_folder, f"{cochlea}_{slice_id}.tif") + ) + scores = compute_scores_for_annotated_slice(segmentation, annotations, matching_tolerance=5) + results["annotator"].append(annotator) + results["cochlea"].append(cochlea) + results["slice"].append(slice_id) + results["tps"].append(scores["tp"]) + results["fps"].append(scores["fp"]) + results["fns"].append(scores["fn"]) + + table = pd.DataFrame(results) + table.to_csv(result_file, index=False) + print(table) + + +def main(): + import argparse + parser = argparse.ArgumentParser() + parser.add_argument("-i", "--input", default=ROOT) + parser.add_argument("--folders", default=ANNOTATION_FOLDERS) + parser.add_argument("--result_file", default="results.csv") + parser.add_argument("--cache_folder") + args = parser.parse_args() + run_evaluation(args.input, args.folders, args.result_file, args.cache_folder) + + +if __name__ == "__main__": + main() diff --git 
a/scripts/validation/IHCs/visualize_validation.py b/scripts/validation/IHCs/visualize_validation.py new file mode 100644 index 0000000..f33b96f --- /dev/null +++ b/scripts/validation/IHCs/visualize_validation.py @@ -0,0 +1,79 @@ +import argparse +import os +from glob import glob + +import napari +import tifffile + +from flamingo_tools.validation import ( + fetch_data_for_evaluation, compute_matches_for_annotated_slice, for_visualization, parse_annotation_path +) + +ROOT = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/AnnotatedImageCrops/F1ValidationIHCs" + + +def _match_image_path(annotation_path): + all_files = glob(os.path.join(ROOT, "*.tif")) + prefix = os.path.basename(annotation_path).split("_")[:-3] + prefix = "_".join(prefix) + matches = [path for path in all_files if os.path.basename(path).startswith(prefix)] + # if len(matches) != 1: + # breakpoint() + assert len(matches) == 1, f"{prefix}: {len(matches)}" + return matches[0] + + +def visualize_anotation(annotation_path, cache_folder): + print("Checking", annotation_path) + cochlea, slice_id = parse_annotation_path(annotation_path) + cache_path = None if cache_folder is None else os.path.join(cache_folder, f"{cochlea}_{slice_id}.tif") + + image_path = _match_image_path(annotation_path) + + component = 2 if "226_R" in cochlea else 1 + segmentation, annotations = fetch_data_for_evaluation( + annotation_path, cache_path=cache_path, components_for_postprocessing=[component], seg_name="IHC_v2", + ) + + image = tifffile.memmap(image_path) + if segmentation.ndim == 2: + image = image[image.shape[0] // 2] + assert image.shape == segmentation.shape, f"{image.shape}, {segmentation.shape}" + + matches = compute_matches_for_annotated_slice(segmentation, annotations, matching_tolerance=5) + vis_segmentation, vis_points, seg_props, point_props = for_visualization(segmentation, annotations, matches) + + # tps, fns = matches["tp_annotations"], matches["fn"] + # print("True positive annotations:") + # print(tps) + # print("False negative annotations:") + # print(fns) + + v = napari.Viewer() + v.add_image(image) + v.add_labels(vis_segmentation, **seg_props) + v.add_points(vis_points, **point_props) + v.add_labels(segmentation, visible=False) + v.add_points(annotations, visible=False) + v.title = os.path.relpath(annotation_path, ROOT) + napari.run() + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--annotations", nargs="+") + parser.add_argument("--cache_folder") + args = parser.parse_args() + cache_folder = args.cache_folder + + if args.annotations is None: + annotation_paths = sorted(glob(os.path.join(ROOT, "**", "*.csv"), recursive=True)) + else: + annotation_paths = args.annotations + + for annotation_path in annotation_paths: + visualize_anotation(annotation_path, cache_folder) + + +if __name__ == "__main__": + main() diff --git a/scripts/validation/SGNs/analyze.py b/scripts/validation/SGNs/analyze.py new file mode 100644 index 0000000..4a5ea94 --- /dev/null +++ b/scripts/validation/SGNs/analyze.py @@ -0,0 +1,43 @@ +import argparse +import pandas as pd + + +def compute_scores(table, annotator=None): + if annotator is None: + annotator = "all" + else: + table = table[table.annotator == annotator] + + tp = table.tps.sum() + fp = table.fps.sum() + fn = table.fns.sum() + + precision = tp / (tp + fp) + recall = tp / (tp + fn) + f1_score = 2 * precision * recall / (precision + recall) + + return pd.DataFrame({ + "annotator": [annotator], "precision": [precision], "recall": [recall], "f1-score": [f1_score] + 
}) + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("result_file") + args = parser.parse_args() + + table = pd.read_csv(args.result_file) + annotators = pd.unique(table.annotator) + + results = [] + for annotator in annotators: + scores_annotator = compute_scores(table, annotator) + results.append(scores_annotator) + results.append(compute_scores(table, annotator=None)) + + results = pd.concat(results) + print(results) + + +if __name__ == "__main__": + main() diff --git a/scripts/validation/SGNs/compare_annotations.py b/scripts/validation/SGNs/compare_annotations.py new file mode 100644 index 0000000..2b2b4bc --- /dev/null +++ b/scripts/validation/SGNs/compare_annotations.py @@ -0,0 +1,59 @@ +import os +from glob import glob + +import napari +import pandas as pd +import tifffile + +ROOT = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/AnnotatedImageCrops/F1ValidationSGNs" +ANNOTATION_FOLDERS = ["AnnotationsEK", "AnnotationsAMD", "AnnotationLR"] +COLOR = ["green", "yellow", "orange"] + + +def _match_annotations(image_path): + prefix = os.path.basename(image_path).split("_")[:3] + prefix = "_".join(prefix) + + annotations = {} + for annotation_folder in ANNOTATION_FOLDERS: + all_annotations = glob(os.path.join(ROOT, annotation_folder, "*.csv")) + matches = [ann for ann in all_annotations if os.path.basename(ann).startswith(prefix)] + if len(matches) != 1: + continue + annotation_path = matches[0] + + annotation = pd.read_csv(annotation_path)[["axis-0", "axis-1", "axis-2"]].values + annotations[annotation_folder] = annotation + + return annotations + + +def compare_annotations(image_path): + annotations = _match_annotations(image_path) + + image = tifffile.memmap(image_path) + v = napari.Viewer() + v.add_image(image) + for i, (name, annotation) in enumerate(annotations.items()): + v.add_points(annotation, name=name, face_color=COLOR[i]) + v.title = os.path.basename(image_path) + napari.run() + + +def main(): + import argparse + parser = argparse.ArgumentParser() + parser.add_argument("--images", nargs="+") + args = parser.parse_args() + + if args.images is None: + image_paths = sorted(glob(os.path.join(ROOT, "*.tif"))) + else: + image_paths = args.images + + for image_path in image_paths: + compare_annotations(image_path) + + +if __name__ == "__main__": + main() diff --git a/scripts/validation/run_evaluation.py b/scripts/validation/SGNs/run_evaluation.py similarity index 90% rename from scripts/validation/run_evaluation.py rename to scripts/validation/SGNs/run_evaluation.py index 72cde9f..2153c24 100644 --- a/scripts/validation/run_evaluation.py +++ b/scripts/validation/SGNs/run_evaluation.py @@ -6,8 +6,8 @@ fetch_data_for_evaluation, parse_annotation_path, compute_scores_for_annotated_slice ) -ROOT = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/AnnotatedImageCrops/F1Validation" -ANNOTATION_FOLDERS = ["AnnotationsEK", "AnnotationsAMD"] +ROOT = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/AnnotatedImageCrops/F1ValidationSGNs" +ANNOTATION_FOLDERS = ["AnnotationsEK", "AnnotationsAMD", "AnnotationLR"] def run_evaluation(root, annotation_folders, result_file, cache_folder): @@ -27,12 +27,13 @@ def run_evaluation(root, annotation_folders, result_file, cache_folder): annotator = folder[len("Annotations"):] annotations = sorted(glob(os.path.join(root, folder, "*.csv"))) for annotation_path in annotations: + print(annotation_path) cochlea, slice_id = parse_annotation_path(annotation_path) # We don't have this cochlea in 
MoBIE yet if cochlea == "M_LR_000169_R": continue - print("Run evaluation for", annotator, cochlea, slice_id) + print("Run evaluation for", annotator, cochlea, "z=", slice_id) segmentation, annotations = fetch_data_for_evaluation( annotation_path, components_for_postprocessing=[1], cache_path=None if cache_folder is None else os.path.join(cache_folder, f"{cochlea}_{slice_id}.tif") diff --git a/scripts/validation/SGNs/visualize_validation.py b/scripts/validation/SGNs/visualize_validation.py new file mode 100644 index 0000000..2f8c23e --- /dev/null +++ b/scripts/validation/SGNs/visualize_validation.py @@ -0,0 +1,78 @@ +import argparse +import os +from glob import glob + +import napari +import tifffile + +from flamingo_tools.validation import ( + fetch_data_for_evaluation, compute_matches_for_annotated_slice, for_visualization, parse_annotation_path +) + +ROOT = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/AnnotatedImageCrops/F1ValidationSGNs" + + +def _match_image_path(annotation_path): + all_files = glob(os.path.join(ROOT, "*.tif")) + prefix = os.path.basename(annotation_path).split("_")[:-3] + prefix = "_".join(prefix) + matches = [path for path in all_files if os.path.basename(path).startswith(prefix)] + # if len(matches) != 1: + # breakpoint() + assert len(matches) == 1, f"{prefix}: {len(matches)}" + return matches[0] + + +def visualize_anotation(annotation_path, cache_folder): + print("Checking", annotation_path) + cochlea, slice_id = parse_annotation_path(annotation_path) + cache_path = None if cache_folder is None else os.path.join(cache_folder, f"{cochlea}_{slice_id}.tif") + + image_path = _match_image_path(annotation_path) + + segmentation, annotations = fetch_data_for_evaluation( + annotation_path, cache_path=cache_path, components_for_postprocessing=[1], + ) + + image = tifffile.memmap(image_path) + if segmentation.ndim == 2: + image = image[image.shape[0] // 2] + assert image.shape == segmentation.shape, f"{image.shape}, {segmentation.shape}" + + matches = compute_matches_for_annotated_slice(segmentation, annotations, matching_tolerance=5) + vis_segmentation, vis_points, seg_props, point_props = for_visualization(segmentation, annotations, matches) + + # tps, fns = matches["tp_annotations"], matches["fn"] + # print("True positive annotations:") + # print(tps) + # print("False negative annotations:") + # print(fns) + + v = napari.Viewer() + v.add_image(image) + v.add_labels(vis_segmentation, **seg_props) + v.add_points(vis_points, **point_props) + v.add_labels(segmentation, visible=False) + v.add_points(annotations, visible=False) + v.title = os.path.relpath(annotation_path, ROOT) + napari.run() + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--annotations", nargs="+") + parser.add_argument("--cache_folder") + args = parser.parse_args() + cache_folder = args.cache_folder + + if args.annotations is None: + annotation_paths = sorted(glob(os.path.join(ROOT, "**", "*.csv"), recursive=True)) + else: + annotation_paths = args.annotations + + for annotation_path in annotation_paths: + visualize_anotation(annotation_path, cache_folder) + + +if __name__ == "__main__": + main() diff --git a/scripts/validation/analyze.py b/scripts/validation/analyze.py deleted file mode 100644 index e15a407..0000000 --- a/scripts/validation/analyze.py +++ /dev/null @@ -1,20 +0,0 @@ -import pandas as pd - -# TODO more logic to separate by annotator etc. 
-# For now this is just a simple script for global eval -table = pd.read_csv("./results.csv") -print("Table:") -print(table) - -tp = table.tps.sum() -fp = table.fps.sum() -fn = table.fns.sum() - -precision = tp / (tp + fp) -recall = tp / (tp + fn) -f1_score = 2 * precision * recall / (precision + recall) - -print("Evaluation:") -print("Precision:", precision) -print("Recall:", recall) -print("F1-Score:", f1_score) diff --git a/scripts/validation/check_annotations.py b/scripts/validation/check_annotations.py deleted file mode 100644 index 069a2fb..0000000 --- a/scripts/validation/check_annotations.py +++ /dev/null @@ -1,27 +0,0 @@ -import os - -import imageio.v3 as imageio -import napari -import pandas as pd - -# ROOT = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/AnnotatedImageCrops/F1Validation" -ROOT = "annotation_data" -TEST_ANNOTATION = os.path.join(ROOT, "AnnotationsEK/MAMD58L_PV_z771_base_full_annotationsEK.csv") - - -def check_annotation(image_path, annotation_path): - annotations = pd.read_csv(annotation_path)[["axis-0", "axis-1", "axis-2"]].values - - image = imageio.imread(image_path) - v = napari.Viewer() - v.add_image(image) - v.add_points(annotations) - napari.run() - - -def main(): - check_annotation(os.path.join(ROOT, "MAMD58L_PV_z771_base_full.tif"), TEST_ANNOTATION) - - -if __name__ == "__main__": - main() diff --git a/scripts/validation/check_nucleus_segmentation.py b/scripts/validation/check_nucleus_segmentation.py new file mode 100644 index 0000000..de1fc94 --- /dev/null +++ b/scripts/validation/check_nucleus_segmentation.py @@ -0,0 +1,44 @@ +import os + +import imageio.v3 as imageio +import numpy as np + +ROOT = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/croppings/226R_SGN_crop" +IMAGE_PATH = os.path.join(ROOT, "M_LR_000226_R_crop_0802-1067-0776_PV.tif") +SEG_PATH = os.path.join(ROOT, "M_LR_000226_R_crop_0802-1067-0776_SGN_v2.tif") +NUC_PATH = os.path.join(ROOT, "M_LR_000226_R_crop_0802-1067-0776_NUCLEI.tif") + + +def segment_nuclei(): + from flamingo_tools.segmentation.nucleus_segmentation import _naive_nucleus_segmentation_impl + + image = imageio.imread(IMAGE_PATH) + segmentation = imageio.imread(SEG_PATH) + + nuclei = np.zeros_like(segmentation, dtype=segmentation.dtype) + _naive_nucleus_segmentation_impl(image, segmentation, table=None, output=nuclei, n_threads=8, resolution=0.38) + + imageio.imwrite(NUC_PATH, nuclei, compression="zlib") + + +def check_segmentation(): + import napari + + image = imageio.imread(IMAGE_PATH) + segmentation = imageio.imread(SEG_PATH) + nuclei = imageio.imread(NUC_PATH) + + v = napari.Viewer() + v.add_image(image) + v.add_labels(segmentation) + v.add_labels(nuclei) + napari.run() + + +def main(): + segment_nuclei() + check_segmentation() + + +if __name__ == "__main__": + main() diff --git a/scripts/validation/visualize_validation.py b/scripts/validation/visualize_validation.py deleted file mode 100644 index fcae5e3..0000000 --- a/scripts/validation/visualize_validation.py +++ /dev/null @@ -1,50 +0,0 @@ -import argparse -import os - -import imageio.v3 as imageio -import napari - -from flamingo_tools.validation import ( - fetch_data_for_evaluation, compute_matches_for_annotated_slice, for_visualization, parse_annotation_path -) - -# ROOT = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/AnnotatedImageCrops/F1Validation" -ROOT = "annotation_data" - - -def main(): - parser = argparse.ArgumentParser() - parser.add_argument("--image", required=True) - parser.add_argument("--annotation", 
required=True) - parser.add_argument("--cache_folder") - args = parser.parse_args() - cache_folder = args.cache_folder - - cochlea, slice_id = parse_annotation_path(args.annotation) - cache_path = None if cache_folder is None else os.path.join(cache_folder, f"{cochlea}_{slice_id}.tif") - - image = imageio.imread(args.image) - segmentation, annotations = fetch_data_for_evaluation( - args.annotation, cache_path=cache_path, components_for_postprocessing=[1], - ) - - matches = compute_matches_for_annotated_slice(segmentation, annotations, matching_tolerance=5) - tps, fns = matches["tp_annotations"], matches["fn"] - vis_segmentation, vis_points, seg_props, point_props = for_visualization(segmentation, annotations, matches) - - print("True positive annotations:") - print(tps) - print("False negative annotations:") - print(fns) - - v = napari.Viewer() - v.add_image(image) - v.add_labels(vis_segmentation, **seg_props) - v.add_points(vis_points, **point_props) - v.add_labels(segmentation, visible=False) - v.add_points(annotations, visible=False) - napari.run() - - -if __name__ == "__main__": - main()
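For reference, the precision, recall and F1 values computed by SGNs/analyze.py from the tp, fp and fn counts in results.csv follow the standard definitions; a small worked example with made-up counts:

    tp, fp, fn = 90, 10, 20
    precision = tp / (tp + fp)                                 # 0.90
    recall = tp / (tp + fn)                                    # ~0.818
    f1_score = 2 * precision * recall / (precision + recall)   # ~0.857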