263 changes: 263 additions & 0 deletions flamingo_tools/validation.py
@@ -0,0 +1,263 @@
import os
import re
from typing import Dict, List, Optional, Tuple

import imageio.v3 as imageio
import numpy as np
import pandas as pd
import zarr

from scipy.ndimage import distance_transform_edt
from scipy.optimize import linear_sum_assignment
from skimage.measure import regionprops_table
from skimage.segmentation import relabel_sequential
from tqdm import tqdm

from .s3_utils import get_s3_path, BUCKET_NAME, SERVICE_ENDPOINT


def _normalize_cochlea_name(name):
    # The name must contain a number; find the position of its first digit.
    match = re.search(r"\d+", name)
    pos = match.start() if match else None
    assert pos is not None, name
    # Separate the leading letter from the rest of the prefix.
    prefix = name[:pos]
    prefix = f"{prefix[0]}_{prefix[1:]}"
    # Zero-pad the number to six digits and keep the trailing letter as postfix.
    number = int(name[pos:-1])
    postfix = name[-1]
    return f"{prefix}_{number:06d}_{postfix}"


def parse_annotation_path(annotation_path):
fname = os.path.basename(annotation_path)
name_parts = fname.split("_")
cochlea = _normalize_cochlea_name(name_parts[0])
slice_id = int(name_parts[2][1:])
return cochlea, slice_id
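
# For example, the file name "MAMD58L_PV_z771_base_full_annotationsEK.csv" yields
# the cochlea "M_AMD_000058_L" and the slice id 771.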


# TODO enable table component filtering with MoBIE table
# NOTE: the main component is always #1
def fetch_data_for_evaluation(
annotation_path: str,
cache_path: Optional[str] = None,
seg_name: str = "SGN",
z_extent: int = 0,
components_for_postprocessing: Optional[List[int]] = None,
) -> Tuple[np.ndarray, pd.DataFrame]:
"""
"""
# Load the annotations and normalize them for the given z-extent.
annotations = pd.read_csv(annotation_path)
annotations = annotations.drop(columns="index")
if z_extent == 0: # If we don't have a z-extent then we just drop the first axis and rename the other two.
annotations = annotations.drop(columns="axis-0")
annotations = annotations.rename(columns={"axis-1": "axis-0", "axis-2": "axis-1"})
else: # Otherwise we have to center the first axis.
        # TODO
        raise NotImplementedError("Centering the annotations for z_extent > 0 is not implemented yet.")

    # Load the segmentation from the cache path if it is given and the data is already cached.
if cache_path is not None and os.path.exists(cache_path):
segmentation = imageio.imread(cache_path)
return segmentation, annotations

# Parse which ID and which cochlea from the name.
cochlea, slice_id = parse_annotation_path(annotation_path)

# Open the S3 connection, get the path to the SGN segmentation in S3.
internal_path = os.path.join(cochlea, "images", "ome-zarr", f"{seg_name}.ome.zarr")
s3_store, fs = get_s3_path(internal_path, bucket_name=BUCKET_NAME, service_endpoint=SERVICE_ENDPOINT)

# Compute the roi for the given z-extent.
if z_extent == 0:
roi = slice_id
else:
roi = slice(slice_id - z_extent, slice_id + z_extent)

# Download the segmentation for this slice and the given z-extent.
input_key = "s0"
with zarr.open(s3_store, mode="r") as f:
segmentation = f[input_key][roi]

if components_for_postprocessing is not None:
        # Filter the IDs so that only the ones that are part of 'components_for_postprocessing' remain.

# First, we download the MoBIE table for this segmentation.
internal_path = os.path.join(BUCKET_NAME, cochlea, "tables", seg_name, "default.tsv")
with fs.open(internal_path, "r") as f:
table = pd.read_csv(f, sep="\t")

        # Then we get the ids for the components and use them to filter the segmentation.
component_mask = np.isin(table.component_labels.values, components_for_postprocessing)
keep_label_ids = table.label_id.values[component_mask].astype("int64")
filter_mask = ~np.isin(segmentation, keep_label_ids)
segmentation[filter_mask] = 0

segmentation, _, _ = relabel_sequential(segmentation)

# Cache it if required.
if cache_path is not None:
imageio.imwrite(cache_path, segmentation, compression="zlib")

return segmentation, annotations
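
# A hedged usage sketch for fetch_data_for_evaluation; the paths below are hypothetical
# and S3 access is assumed to be configured as required by flamingo_tools.s3_utils:
# segmentation, annotations = fetch_data_for_evaluation(
#     "AnnotationsEK/MAMD58L_PV_z771_base_full_annotationsEK.csv",
#     cache_path="cache/M_AMD_000058_L_771.tif",
#     components_for_postprocessing=[1],
# )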


# We should use the Hungarian-matching based approach, but I can't find the bug in it right now.
def _naive_matching(annotations, segmentation, segmentation_ids, matching_tolerance, coordinates):
distances, indices = distance_transform_edt(segmentation == 0, return_indices=True)
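    # The EDT of (segmentation == 0) gives, for every pixel, the distance to the closest
    # foreground pixel; with return_indices=True we also get the coordinates of that pixel,
    # so each annotation below can be mapped to its nearest segmented object.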

matched_ids = {}
matched_distances = {}
annotation_id = 0
for _, row in annotations.iterrows():
coordinate = tuple(int(np.round(row[coord])) for coord in coordinates)
object_distance = distances[coordinate]
if object_distance <= matching_tolerance:
closest_object_coord = tuple(idx[coordinate] for idx in indices)
object_id = segmentation[closest_object_coord]
if object_id not in matched_ids or matched_distances[object_id] > object_distance:
matched_ids[object_id] = annotation_id
matched_distances[object_id] = object_distance
annotation_id += 1

tp_ids_objects = np.array(list(matched_ids.keys()))
tp_ids_annotations = np.array(list(matched_ids.values()))
return tp_ids_objects, tp_ids_annotations


# There is a bug in here that neither I nor o3 can figure out ...
def _assignment_based_matching(annotations, segmentation, segmentation_ids, matching_tolerance, coordinates):
n_objects, n_annotations = len(segmentation_ids), len(annotations)

    # In order to get the full distance matrix, we compute the distance to all objects for each annotation.
    # This is not very efficient, but it is the most straightforward and most rigorous approach.
    scores = np.zeros((n_objects, n_annotations), dtype="float")
    iterator = tqdm(annotations.iterrows(), total=n_annotations, desc="Compute pairwise distances")
    for i, (_, row) in enumerate(iterator):
        coordinate = tuple(int(np.round(row[coord])) for coord in coordinates)
        distance_input = np.ones(segmentation.shape, dtype="bool")
        distance_input[coordinate] = False
        distances = distance_transform_edt(distance_input)

        # The per-object minimum over the distance image is the distance of this
        # annotation to each segmented object.
        props = regionprops_table(segmentation, intensity_image=distances, properties=("label", "min_intensity"))
        distances = props["min_intensity"]
        assert len(distances) == scores.shape[0]
        scores[:, i] = distances

# Find the assignment of points to objects.
# These correspond to the TP ids in the point / object annotations.
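    # Note: linear_sum_assignment returns (row_indices, col_indices) that minimize the summed
    # distances; rows index the segmentation objects and columns the annotations.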
tp_ids_objects, tp_ids_annotations = linear_sum_assignment(scores)
match_ok = scores[tp_ids_objects, tp_ids_annotations] <= matching_tolerance
tp_ids_objects, tp_ids_annotations = tp_ids_objects[match_ok], tp_ids_annotations[match_ok]
tp_ids_objects = segmentation_ids[tp_ids_objects]

return tp_ids_objects, tp_ids_annotations


def compute_matches_for_annotated_slice(
segmentation: np.typing.ArrayLike,
annotations: pd.DataFrame,
matching_tolerance: float = 0.0,
) -> Dict[str, np.ndarray]:
"""Computes the ids of matches and non-matches for a annotated validation slice.
Computes true positive ids (for objects and annotations), false positive ids and false negative ids
by solving a linear cost assignment of distances between objects and annotations.
Args:
segmentation: The segmentation for this slide. We assume that it is relabeled consecutively.
annotations: The annotations, marking cell centers.
matching_tolerance: The maximum distance for matching an annotation to a segmented object.
Returns:
A dictionary with keys 'tp_objects', 'tp_annotations' 'fp' and 'fn', mapping to the respective ids.
"""
assert segmentation.ndim in (2, 3)
coordinates = ["axis-0", "axis-1"] if segmentation.ndim == 2 else ["axis-0", "axis-1", "axis-2"]
segmentation_ids = np.unique(segmentation)[1:]

# Crop to the minimal enclosing bounding box of points and segmented objects.
bb_seg = np.where(segmentation != 0)
bb_seg = tuple(slice(int(bb.min()), int(bb.max())) for bb in bb_seg)
bb_points = tuple(
slice(int(np.floor(annotations[coords].min())), int(np.ceil(annotations[coords].max())) + 1)
for coords in coordinates
)
bbox = tuple(slice(min(bbs.start, bbp.start), max(bbs.stop, bbp.stop)) for bbs, bbp in zip(bb_seg, bb_points))
segmentation = segmentation[bbox]

annotations = annotations.copy()
for coord, bb in zip(coordinates, bbox):
annotations[coord] -= bb.start
assert (annotations[coord] <= bb.stop).all()

# tp_ids_objects, tp_ids_annotations =\
# _assignment_based_matching(annotations, segmentation, segmentation_ids, matching_tolerance, coordinates)
tp_ids_objects, tp_ids_annotations =\
_naive_matching(annotations, segmentation, segmentation_ids, matching_tolerance, coordinates)
assert len(tp_ids_objects) == len(tp_ids_annotations)

# Find the false positives: objects that are not part of the matches.
fp_ids = np.setdiff1d(segmentation_ids, tp_ids_objects)

# Find the false negatives: annotations that are not part of the matches.
fn_ids = np.setdiff1d(np.arange(len(annotations)), tp_ids_annotations)

return {"tp_objects": tp_ids_objects, "tp_annotations": tp_ids_annotations, "fp": fp_ids, "fn": fn_ids}


def compute_scores_for_annotated_slice(
segmentation: np.typing.ArrayLike,
annotations: pd.DataFrame,
matching_tolerance: float = 0.0,
) -> Dict[str, int]:
"""Computes the scores for a annotated validation slice.
Computes true positives, false positives and false negatives for scoring.
Args:
segmentation: The segmentation for this slide. We assume that it is relabeled consecutively.
annotations: The annotations, marking cell centers.
matching_tolerance: The maximum distance for matching an annotation to a segmented object.
Returns:
A dictionary with keys 'tp', 'fp' and 'fn', mapping to the respective counts.
"""
result = compute_matches_for_annotated_slice(segmentation, annotations, matching_tolerance)

    # Determine the TP, FP and FN counts from the match ids.
tp = len(result["tp_objects"])
fp = len(result["fp"])
fn = len(result["fn"])
return {"tp": tp, "fp": fp, "fn": fn}


def for_visualization(segmentation, annotations, matches):
green_red = ["#00FF00", "#FF0000"]

seg_vis = np.zeros_like(segmentation)
tps, fps = matches["tp_objects"], matches["fp"]
seg_vis[np.isin(segmentation, tps)] = 1
seg_vis[np.isin(segmentation, fps)] = 2

seg_props = dict(colormap={1: green_red[0], 2: green_red[1]})

point_vis = annotations.copy()
tps = matches["tp_annotations"]
match_properties = ["tp" if aid in tps else "fn" for aid in range(len(annotations))]
# The color cycle assigns the first color to the first property etc.
# So we need to set the first color to red if the first id is a false negative and vice versa.
color_cycle = green_red[::-1] if match_properties[0] == "fn" else green_red
point_props = dict(
properties={
"id": list(range(len(annotations))),
"match": match_properties,
},
face_color="match",
face_color_cycle=color_cycle,
border_width=0.25,
size=10,
)

return seg_vis, point_vis, seg_props, point_props
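
# A hedged sketch of how the return values could be passed to napari; this assumes a
# napari version whose add_labels/add_points accept these keyword arguments, which may
# differ between releases:
# import napari
# seg_vis, point_vis, seg_props, point_props = for_visualization(segmentation, annotations, matches)
# v = napari.Viewer()
# v.add_labels(seg_vis, **seg_props)
# v.add_points(point_vis, **point_props)
# napari.run()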
20 changes: 20 additions & 0 deletions scripts/validation/analyze.py
@@ -0,0 +1,20 @@
import pandas as pd

# TODO: more logic to separate by annotator etc.
# For now this is just a simple script for global evaluation.
table = pd.read_csv("./results.csv")
print("Table:")
print(table)

tp = table.tps.sum()
fp = table.fps.sum()
fn = table.fns.sum()

precision = tp / (tp + fp)
recall = tp / (tp + fn)
f1_score = 2 * precision * recall / (precision + recall)

print("Evaluation:")
print("Precision:", precision)
print("Recall:", recall)
print("F1-Score:", f1_score)
27 changes: 27 additions & 0 deletions scripts/validation/check_annotations.py
@@ -0,0 +1,27 @@
import os

import imageio.v3 as imageio
import napari
import pandas as pd

# ROOT = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/AnnotatedImageCrops/F1Validation"
ROOT = "annotation_data"
TEST_ANNOTATION = os.path.join(ROOT, "AnnotationsEK/MAMD58L_PV_z771_base_full_annotationsEK.csv")


def check_annotation(image_path, annotation_path):
annotations = pd.read_csv(annotation_path)[["axis-0", "axis-1", "axis-2"]].values

image = imageio.imread(image_path)
v = napari.Viewer()
v.add_image(image)
v.add_points(annotations)
napari.run()


def main():
check_annotation(os.path.join(ROOT, "MAMD58L_PV_z771_base_full.tif"), TEST_ANNOTATION)


if __name__ == "__main__":
main()
65 changes: 65 additions & 0 deletions scripts/validation/run_evaluation.py
@@ -0,0 +1,65 @@
import os
from glob import glob

import pandas as pd
from flamingo_tools.validation import (
fetch_data_for_evaluation, parse_annotation_path, compute_scores_for_annotated_slice
)

ROOT = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/AnnotatedImageCrops/F1Validation"
ANNOTATION_FOLDERS = ["AnnotationsEK", "AnnotationsAMD"]


def run_evaluation(root, annotation_folders, result_file, cache_folder):
results = {
"annotator": [],
"cochlea": [],
"slice": [],
"tps": [],
"fps": [],
"fns": [],
}

if cache_folder is not None:
os.makedirs(cache_folder, exist_ok=True)

for folder in annotation_folders:
annotator = folder[len("Annotations"):]
annotations = sorted(glob(os.path.join(root, folder, "*.csv")))
for annotation_path in annotations:
cochlea, slice_id = parse_annotation_path(annotation_path)
# We don't have this cochlea in MoBIE yet
if cochlea == "M_LR_000169_R":
continue

print("Run evaluation for", annotator, cochlea, slice_id)
            # Use a name different from the outer 'annotations' list to avoid shadowing it.
            segmentation, annotation_df = fetch_data_for_evaluation(
                annotation_path, components_for_postprocessing=[1],
                cache_path=None if cache_folder is None else os.path.join(cache_folder, f"{cochlea}_{slice_id}.tif")
            )
            scores = compute_scores_for_annotated_slice(segmentation, annotation_df, matching_tolerance=5)
results["annotator"].append(annotator)
results["cochlea"].append(cochlea)
results["slice"].append(slice_id)
results["tps"].append(scores["tp"])
results["fps"].append(scores["fp"])
results["fns"].append(scores["fn"])

table = pd.DataFrame(results)
table.to_csv(result_file, index=False)
print(table)


def main():
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("-i", "--input", default=ROOT)
parser.add_argument("--folders", default=ANNOTATION_FOLDERS)
parser.add_argument("--result_file", default="results.csv")
parser.add_argument("--cache_folder")
args = parser.parse_args()
run_evaluation(args.input, args.folders, args.result_file, args.cache_folder)


if __name__ == "__main__":
main()