diff --git a/flamingo_tools/segmentation/synapse_detection.py b/flamingo_tools/segmentation/synapse_detection.py
new file mode 100644
index 0000000..ef27db1
--- /dev/null
+++ b/flamingo_tools/segmentation/synapse_detection.py
@@ -0,0 +1,220 @@
+import os
+from typing import Optional, Tuple
+
+import numpy as np
+import pandas as pd
+import zarr
+from scipy.ndimage import binary_dilation
+
+from elf.parallel.local_maxima import find_local_maxima
+from elf.parallel.distance_transform import map_points_to_objects
+from flamingo_tools.file_utils import read_image_data
+from flamingo_tools.segmentation.unet_prediction import prediction_impl
+
+
+def map_and_filter_detections(
+    segmentation: np.ndarray,
+    detections: pd.DataFrame,
+    max_distance: float,
+    resolution: float = 0.38,
+    n_threads: Optional[int] = None,
+    verbose: bool = True,
+) -> pd.DataFrame:
+    """Map synapse detections to segmented IHCs and filter out detections that exceed a distance threshold to the IHCs.
+
+    Args:
+        segmentation: The IHC segmentation.
+        detections: The synapse marker detections.
+        max_distance: The maximal distance in micrometer for a valid match of synapse markers to IHCs.
+        resolution: The resolution / voxel size of the data in micrometer.
+        n_threads: The number of threads for parallelizing the mapping of detections to objects.
+        verbose: Whether to print the progress of the mapping procedure.
+
+    Returns:
+        The filtered dataframe with the detections mapped to the segmentation.
+    """
+    # Get the point coordinates.
+    points = detections[["z", "y", "x"]].values.astype("int")
+
+    # Set the block shape (this could also be exposed as a parameter; it should not matter much though).
+    block_shape = (64, 256, 256)
+
+    # Determine the halo. We set it to 2 pixels + the max-distance in pixels, to ensure all distances
+    # that are smaller than the max distance are measured.
+    halo = (2 + int(np.ceil(max_distance / resolution)),) * 3
+
+    # Map the detections to the objects in the (IHC) segmentation.
+    object_ids, object_distances = map_points_to_objects(
+        segmentation=segmentation,
+        points=points,
+        block_shape=block_shape,
+        halo=halo,
+        sampling=resolution,
+        n_threads=n_threads,
+        verbose=verbose,
+    )
+    assert len(object_ids) == len(points)
+    assert len(object_distances) == len(points)
+
+    # Add matched ids and distances to the dataframe.
+    detections["matched_ihc"] = object_ids
+    detections["distance_to_ihc"] = object_distances
+
+    # Filter the dataframe by the max distance.
+    detections = detections[detections.distance_to_ihc < max_distance]
+    return detections
+
+
+def run_prediction(
+    input_path: str,
+    input_key: str,
+    output_folder: str,
+    model_path: str,
+    block_shape: Optional[Tuple[int, int, int]] = None,
+    halo: Optional[Tuple[int, int, int]] = None,
+):
+    """Run prediction for synapse detection.
+
+    Args:
+        input_path: Input path to image channel for synapse detection.
+        input_key: Input key for resolution of image channel and mask channel.
+        output_folder: Output folder for synapse segmentation and marker detection.
+        model_path: Path to model for synapse detection.
+        block_shape: The block-shape for running the prediction.
+        halo: The halo (= block overlap) to use for prediction.
+ """ + if block_shape is None: + block_shape = (64, 256, 256) + if halo is None: + halo = (16, 64, 64) + + # Skip existing prediction, which is saved in output_folder/predictions.zarr + skip_prediction = False + output_path = os.path.join(output_folder, "predictions.zarr") + prediction_key = "prediction" + if os.path.exists(output_path) and prediction_key in zarr.open(output_path, "r"): + skip_prediction = True + + if not skip_prediction: + prediction_impl( + input_path, input_key, output_folder, model_path, + scale=None, block_shape=block_shape, halo=halo, + apply_postprocessing=False, output_channels=1, + ) + + detection_path = os.path.join(output_folder, "synapse_detection.tsv") + if not os.path.exists(detection_path): + input_ = zarr.open(output_path, "r")[prediction_key] + detections = find_local_maxima( + input_, block_shape=block_shape, min_distance=2, threshold_abs=0.5, verbose=True, n_threads=16, + ) + # Save the result in mobie compatible format. + detections = np.concatenate( + [np.arange(1, len(detections) + 1)[:, None], detections[:, ::-1]], axis=1 + ) + detections = pd.DataFrame(detections, columns=["spot_id", "x", "y", "z"]) + detections.to_csv(detection_path, index=False, sep="\t") + + +def marker_detection( + input_path: str, + input_key: str, + mask_path: str, + output_folder: str, + model_path: str, + mask_input_key: str = "s4", + max_distance: float = 20, + resolution: float = 0.38, +): + """Streamlined workflow for marker detection, mapping, and filtering. + + Args: + input_path: Input path to image channel for synapse detection. + input_key: Input key for resolution of image channel and mask channel. + mask_path: Path to IHC segmentation used to mask input. + output_folder: Output folder for synapse segmentation and marker detection. + model_path: Path to model for synapse detection. + mask_input_key: Key to undersampled IHC segmentation for masking input for synapse detection. + max_distance: The maximal distance for a valid match of synapse markers to IHCs. + resolution: The resolution / voxel size of the data in micrometer. + """ + + # 1.) Determine mask for inference based on the IHC segmentation. + # Best approach: load IHC segmentation at a low scale level, binarize it, + # dilate it and use this as mask. It can be mapped back to the full resolution + # with `elf.wrapper.ResizedVolume`. + + skip_masking = False + + mask_preprocess_key = "mask" + output_file = os.path.join(output_folder, "mask.zarr") + + if os.path.exists(output_file) and mask_preprocess_key in zarr.open(output_file, "r"): + skip_masking = True + + if not skip_masking: + mask_ = read_image_data(mask_path, mask_input_key) + new_mask = np.zeros(mask_.shape) + new_mask[mask_ != 0] = 1 + arr_bin = binary_dilation(mask_, structure=np.ones((9, 9, 9))).astype(int) + + with zarr.open(output_file, mode="w") as f_out: + f_out.create_dataset(mask_preprocess_key, data=arr_bin, compression="gzip") + + # 2.) Run inference and detection of maxima. + # This can be taken from 'scripts/synapse_marker_detection/run_prediction.py' + # (And the run prediction script should then be refactored). 
+ + block_shape = (64, 256, 256) + halo = (16, 64, 64) + + # Skip existing prediction, which is saved in output_folder/predictions.zarr + skip_prediction = False + output_path = os.path.join(output_folder, "predictions.zarr") + prediction_key = "prediction" + if os.path.exists(output_path) and prediction_key in zarr.open(output_path, "r"): + skip_prediction = True + + if not skip_prediction: + prediction_impl( + input_path, input_key, output_folder, model_path, + scale=None, block_shape=block_shape, halo=halo, + apply_postprocessing=False, output_channels=1, + ) + + detection_path = os.path.join(output_folder, "synapse_detection.tsv") + if not os.path.exists(detection_path): + input_ = zarr.open(output_path, "r")[prediction_key] + detections = find_local_maxima( + input_, block_shape=block_shape, min_distance=2, threshold_abs=0.5, verbose=True, n_threads=16, + ) + # Save the result in mobie compatible format. + detections = np.concatenate( + [np.arange(1, len(detections) + 1)[:, None], detections[:, ::-1]], axis=1 + ) + detections = pd.DataFrame(detections, columns=["spot_id", "x", "y", "z"]) + detections.to_csv(detection_path, index=False, sep="\t") + + else: + with open(detection_path, 'r') as f: + detections = pd.read_csv(f, sep="\t") + + # 3.) Map the detections to IHC and filter them based on a distance criterion. + # Use the function 'map_and_filter_detections' from above. + input_ = read_image_data(mask_path, input_key) + + detections_filtered = map_and_filter_detections( + segmentation=input_, + detections=detections, + max_distance=max_distance, + resolution=resolution, + ) + + # 4.) Add the filtered detections to MoBIE. + # IMPORTANT scale the coordinates with the resolution here. + detections_filtered["distance_to_ihc"] *= resolution + detections_filtered["x"] *= resolution + detections_filtered["y"] *= resolution + detections_filtered["z"] *= resolution + detection_path = os.path.join(output_folder, "synapse_detection_filtered.tsv") + detections_filtered.to_csv(detection_path, index=False, sep="\t") diff --git a/flamingo_tools/segmentation/unet_prediction.py b/flamingo_tools/segmentation/unet_prediction.py index 99a0048..38f66e6 100644 --- a/flamingo_tools/segmentation/unet_prediction.py +++ b/flamingo_tools/segmentation/unet_prediction.py @@ -77,15 +77,21 @@ def prediction_impl( else: model = torch.load(model_path, weights_only=False) + input_ = read_image_data(input_path, input_key) + chunks = getattr(input_, "chunks", (64, 64, 64)) mask_path = os.path.join(output_folder, "mask.zarr") + if os.path.exists(mask_path): image_mask = z5py.File(mask_path, "r")["mask"] + # resize mask + image_shape = input_.shape + mask_shape = image_mask.shape + if image_shape != mask_shape: + image_mask = ResizedVolume(image_mask, image_shape, order=0) + else: image_mask = None - input_ = read_image_data(input_path, input_key) - chunks = getattr(input_, "chunks", (64, 64, 64)) - if scale is None or np.isclose(scale, 1): original_shape = None else: diff --git a/flamingo_tools/validation.py b/flamingo_tools/validation.py index 2165d7a..87494ad 100644 --- a/flamingo_tools/validation.py +++ b/flamingo_tools/validation.py @@ -1,5 +1,6 @@ import os import re +from collections import defaultdict from typing import Dict, List, Optional, Tuple import imageio.v3 as imageio @@ -9,6 +10,7 @@ from scipy.ndimage import distance_transform_edt from scipy.optimize import linear_sum_assignment +from scipy.spatial import cKDTree from skimage.measure import regionprops_table from skimage.segmentation import 
relabel_sequential from tqdm import tqdm @@ -27,7 +29,7 @@ def _normalize_cochlea_name(name): return f"{prefix}_{number:06d}_{postfix}" -def parse_annotation_path(annotation_path): +def _parse_annotation_path(annotation_path): fname = os.path.basename(annotation_path) name_parts = fname.split("_") cochlea = _normalize_cochlea_name(name_parts[0]) @@ -42,7 +44,19 @@ def fetch_data_for_evaluation( z_extent: int = 0, components_for_postprocessing: Optional[List[int]] = None, ) -> Tuple[np.ndarray, pd.DataFrame]: - """ + """Fetch segmentation from S3 matching the annotation path for evaluation. + + Args: + annotation_path: The path to the manual annotations. + cache_path: An optional path for caching the downloaded segmentation. + seg_name: The name of the segmentation in the bucket. + z_extent: Additional z-slices to load from the segmentation. + components_for_postprocessing: The component ids for restricting the segmentation to. + Choose [1] for the default componentn containing the helix. + + Returns: + The segmentation downloaded from the S3 bucket. + The annotations loaded from pandas and matching the segmentation. """ # Load the annotations and normalize them for the given z-extent. annotations = pd.read_csv(annotation_path) @@ -60,7 +74,7 @@ def fetch_data_for_evaluation( return segmentation, annotations # Parse which ID and which cochlea from the name. - cochlea, slice_id = parse_annotation_path(annotation_path) + cochlea, slice_id = _parse_annotation_path(annotation_path) # Open the S3 connection, get the path to the SGN segmentation in S3. internal_path = os.path.join(cochlea, "images", "ome-zarr", f"{seg_name}.ome.zarr") @@ -176,13 +190,21 @@ def compute_matches_for_annotated_slice( segmentation_ids = np.unique(segmentation)[1:] # Crop to the minimal enclosing bounding box of points and segmented objects. - bb_seg = np.where(segmentation != 0) - bb_seg = tuple(slice(int(bb.min()), int(bb.max())) for bb in bb_seg) - bb_points = tuple( - slice(int(np.floor(annotations[coords].min())), int(np.ceil(annotations[coords].max())) + 1) - for coords in coordinates - ) - bbox = tuple(slice(min(bbs.start, bbp.start), max(bbs.stop, bbp.stop)) for bbs, bbp in zip(bb_seg, bb_points)) + seg_mask = segmentation != 0 + if seg_mask.sum() > 0: + bb_seg = np.where(seg_mask) + bb_seg = tuple(slice(int(bb.min()), int(bb.max())) for bb in bb_seg) + bb_points = tuple( + slice(int(np.floor(annotations[coords].min())), int(np.ceil(annotations[coords].max())) + 1) + for coords in coordinates + ) + bbox = tuple(slice(min(bbs.start, bbp.start), max(bbs.stop, bbp.stop)) for bbs, bbp in zip(bb_seg, bb_points)) + else: + print("The segmentation is empty!!!") + bbox = tuple( + slice(int(np.floor(annotations[coords].min())), int(np.ceil(annotations[coords].max())) + 1) + for coords in coordinates + ) segmentation = segmentation[bbox] annotations = annotations.copy() @@ -231,6 +253,100 @@ def compute_scores_for_annotated_slice( return {"tp": tp, "fp": fp, "fn": fn} +def create_consensus_annotations( + annotation_paths: Dict[str, str], + matching_distance: float = 5.0, + min_matches_for_consensus: int = 2, +) -> Tuple[pd.DataFrame, pd.DataFrame]: + """Create a consensus annotation from multiple manual annotations. + + Args: + annotation_paths: A dictionary that maps annotator names to the path to the manual annotations. + matching_distance: The maximum distance for matching annotations to a consensus annotation. 
+ min_matches_for_consensus: The minimum number of matching annotations to consider an annotation as consensus. + + Returns: + A dataframe with the consensus annotations. + A dataframe with the unmatched annotations. + """ + dfs, coords, ann_id = [], [], [] + for name, path in annotation_paths.items(): + df = pd.read_csv(path, usecols=["axis-0", "axis-1", "axis-2"]) + df["annotator"] = name + dfs.append(df) + big = pd.concat(dfs, ignore_index=True) + coords = big[["axis-0", "axis-1", "axis-2"]].values + ann_id = big["annotator"].values + + trees, idx_by_ann = {}, {} + for ann in np.unique(ann_id): + idx = np.where(ann_id == ann)[0] + idx_by_ann[ann] = idx + trees[ann] = cKDTree(coords[idx]) + + edges = [] + for i, annA in enumerate(trees): + idxA, treeA = idx_by_ann[annA], trees[annA] + for annB in list(trees)[i+1:]: + idxB, treeB = idx_by_ann[annB], trees[annB] + + # A -> B + dAB, jB = treeB.query(coords[idxA], distance_upper_bound=matching_distance) + # B -> A + dBA, jA = treeA.query(coords[idxB], distance_upper_bound=matching_distance) + + for k, (d, j) in enumerate(zip(dAB, jB)): + if np.isfinite(d): + a_idx = idxA[k] + b_idx = idxB[j] + # check reciprocity + if jA[j] == k and np.isfinite(dBA[j]): + edges.append((a_idx, b_idx)) + + # --- union–find to group --------------------------------- + parent = np.arange(len(coords)) + + def find(x): + while parent[x] != x: + parent[x] = parent[parent[x]] + x = parent[x] + return x + + def union(a, b): + ra, rb = find(a), find(b) + if ra != rb: + parent[rb] = ra + + for a, b in edges: + union(a, b) + + # --- collect results ------------------------------------- + cluster = defaultdict(list) + for i in range(len(coords)): + cluster[find(i)].append(i) + + consensus_rows, unmatched = [], [] + for members in cluster.values(): + if len(members) >= min_matches_for_consensus: + anns = {ann_id[m] for m in members} + # by construction anns are unique + subset = coords[members] + rep_pt = subset.mean(0) + consensus_rows.append({ + "axis-0": rep_pt[0], + "axis-1": rep_pt[1], + "axis-2": rep_pt[2], + "annotators": anns, + "member_indices": members + }) + else: + unmatched.extend(members) + + consensus_df = pd.DataFrame(consensus_rows) + unmatched_df = big.iloc[unmatched].reset_index(drop=True) + return consensus_df, unmatched_df + + def for_visualization(segmentation, annotations, matches): green_red = ["#00FF00", "#FF0000"] diff --git a/scripts/export_lower_resolution.py b/scripts/export_lower_resolution.py new file mode 100644 index 0000000..fdde906 --- /dev/null +++ b/scripts/export_lower_resolution.py @@ -0,0 +1,63 @@ +import argparse +import os + +import numpy as np +import pandas as pd +import tifffile +import zarr + +from flamingo_tools.s3_utils import get_s3_path, BUCKET_NAME, SERVICE_ENDPOINT +from skimage.segmentation import relabel_sequential + + +def filter_component(fs, segmentation, cochlea, seg_name): + # First, we download the MoBIE table for this segmentation. + internal_path = os.path.join(BUCKET_NAME, cochlea, "tables", seg_name, "default.tsv") + with fs.open(internal_path, "r") as f: + table = pd.read_csv(f, sep="\t") + + # Then we get the ids for the components and us them to filter the segmentation. 
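# A minimal usage sketch (with hypothetical file names) for create_consensus_annotations defined in
# flamingo_tools/validation.py above; each CSV is a napari-style point table with the columns
# "axis-0", "axis-1", "axis-2".
from flamingo_tools.validation import create_consensus_annotations

annotation_paths = {  # hypothetical annotator -> CSV mapping
    "AnnotationsAMD": "AnnotationsAMD/MLR169R_PV_z1899_annotations.csv",
    "AnnotationsEK": "AnnotationsEK/MLR169R_PV_z1899_annotations.csv",
    "AnnotationsLR": "AnnotationsLR/MLR169R_PV_z1899_annotations.csv",
}
consensus, unmatched = create_consensus_annotations(
    annotation_paths, matching_distance=8.0, min_matches_for_consensus=2,
)
print(len(consensus), "consensus points,", len(unmatched), "unmatched points")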
+    component_mask = np.isin(table.component_labels.values, [1])
+    keep_label_ids = table.label_id.values[component_mask].astype("int64")
+    filter_mask = ~np.isin(segmentation, keep_label_ids)
+    segmentation[filter_mask] = 0
+
+    segmentation, _, _ = relabel_sequential(segmentation)
+    return segmentation
+
+
+def export_lower_resolution(args):
+    output_folder = os.path.join(args.output_folder, args.cochlea, f"scale{args.scale}")
+    os.makedirs(output_folder, exist_ok=True)
+
+    input_key = f"s{args.scale}"
+    for channel in args.channels:
+        out_path = os.path.join(output_folder, f"{channel}.tif")
+        if os.path.exists(out_path):
+            continue
+
+        print("Exporting channel", channel)
+        internal_path = os.path.join(args.cochlea, "images", "ome-zarr", f"{channel}.ome.zarr")
+        s3_store, fs = get_s3_path(internal_path, bucket_name=BUCKET_NAME, service_endpoint=SERVICE_ENDPOINT)
+        with zarr.open(s3_store, mode="r") as f:
+            data = f[input_key][:]
+        print(data.shape)
+        if args.filter_by_component:
+            data = filter_component(fs, data, args.cochlea, channel)
+        tifffile.imwrite(out_path, data, bigtiff=True, compression="zlib")
+
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--cochlea", "-c", required=True)
+    parser.add_argument("--scale", "-s", type=int, required=True)
+    parser.add_argument("--output_folder", "-o", required=True)
+    parser.add_argument("--channels", nargs="+", default=["PV", "VGlut3", "CTBP2"])
+    parser.add_argument("--filter_by_component", action="store_true")
+    args = parser.parse_args()
+
+    export_lower_resolution(args)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/scripts/ihc_analysis/synapse_mapping.py b/scripts/ihc_analysis/synapse_mapping.py
new file mode 100644
index 0000000..8ad6c89
--- /dev/null
+++ b/scripts/ihc_analysis/synapse_mapping.py
@@ -0,0 +1,57 @@
+import os
+
+import imageio.v3 as imageio
+import pandas as pd
+
+from flamingo_tools.segmentation.synapse_detection import map_and_filter_detections
+
+
+ROOT = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/croppings/Synapse_crop"
+VGLUT_PATH = os.path.join(ROOT, "M_LR_000226_R_crop_1098-0926-0872_Vglut3.tif")
+CTBP2_PATH = os.path.join(ROOT, "M_LR_000226_R_crop_1098-0926-0872_CTBP2.tif")
+SEG_PATH = os.path.join(ROOT, "M_LR_000226_R_resized_crop_1098-0926-0872_IHC.tif")
+DET_PATH = os.path.join(ROOT, "synapses/synapse_detection.tsv")
+
+
+def check_data():
+    import napari
+
+    detections = pd.read_csv(DET_PATH, sep="\t")
+    detections = detections[["z", "y", "x"]].values
+
+    filtered_path = os.path.join(ROOT, "synapses/synapse_detection_filtered.tsv")
+    filtered_detections = pd.read_csv(filtered_path, sep="\t")
+    filtered_detections = filtered_detections[["z", "y", "x"]].values
+
+    vglut = imageio.imread(VGLUT_PATH)
+    ctbp2 = imageio.imread(CTBP2_PATH)
+    ihcs = imageio.imread(SEG_PATH)
+
+    v = napari.Viewer()
+    v.add_image(vglut)
+    v.add_image(ctbp2)
+    v.add_labels(ihcs)
+    v.add_points(detections)
+    v.add_points(filtered_detections)
+    napari.run()
+
+
+def map_synapses():
+    ihcs = imageio.imread(SEG_PATH)
+    detections = pd.read_csv(DET_PATH, sep="\t")
+    n_detections = len(detections)
+
+    detections = map_and_filter_detections(ihcs, detections, max_distance=2.0, n_threads=8)
+    print("Detections after mapping and filtering:", len(detections), "/", n_detections)
+
+    out_path = os.path.join(ROOT, "synapses/synapse_detection_filtered.tsv")
+    detections.to_csv(out_path, sep="\t", index=False)
+
+
+def main():
+    map_synapses()
+    # check_data()
+
+
+if __name__ == "__main__":
+    main()
diff
--git a/scripts/synapse_marker_detection/detection_dataset.py b/scripts/synapse_marker_detection/detection_dataset.py index b194bb1..ae9b361 100644 --- a/scripts/synapse_marker_detection/detection_dataset.py +++ b/scripts/synapse_marker_detection/detection_dataset.py @@ -7,22 +7,42 @@ from torch_em.util import ensure_tensor_with_channels -# Process labels stored in json napari style. -# I don't actually think that we need the epsilon here, but will leave it for now. -def process_labels(label_path, shape, sigma, eps, bb=None): - points = pd.read_csv(label_path) +class MinPointSampler: + """A sampler to reject samples with a low fraction of foreground pixels in the labels. + + Args: + min_fraction: The minimal fraction of foreground pixels for accepting a sample. + background_id: The id of the background label. + p_reject: The probability for rejecting a sample that does not meet the criterion. + """ + def __init__(self, min_points: int, p_reject: float = 1.0): + self.min_points = min_points + self.p_reject = p_reject + + def __call__(self, x: np.ndarray, n_points: int) -> bool: + """Check the sample. + + Args: + x: The raw data. + y: The label data. + + Returns: + Whether to accept this sample. + """ + + if n_points > self.min_points: + return True + else: + return np.random.rand() > self.p_reject - if bb: - (z_min, z_max), (y_min, y_max), (x_min, x_max) = [(s.start, s.stop) for s in bb] - restricted_shape = (z_max - z_min, y_max - y_min, x_max - x_min) - labels = np.zeros(restricted_shape, dtype="float32") - shape = restricted_shape - else: - labels = np.zeros(shape, dtype="float32") +def load_labels(label_path, shape, bb): + points = pd.read_csv(label_path) assert len(points.columns) == len(shape) - z_coords, y_coords, x_coords = points["axis-0"], points["axis-1"], points["axis-2"] + z_coords, y_coords, x_coords = points["axis-0"].values, points["axis-1"].values, points["axis-2"].values + if bb is not None: + (z_min, z_max), (y_min, y_max), (x_min, x_max) = [(s.start, s.stop) for s in bb] z_coords -= z_min y_coords -= y_min x_coords -= x_min @@ -32,13 +52,31 @@ def process_labels(label_path, shape, sigma, eps, bb=None): np.logical_and(x_coords >= 0, x_coords < (x_max - x_min)), ]) z_coords, y_coords, x_coords = z_coords[mask], y_coords[mask], x_coords[mask] + restricted_shape = (z_max - z_min, y_max - y_min, x_max - x_min) + shape = restricted_shape + n_points = len(z_coords) coords = tuple( np.clip(np.round(coord).astype("int"), 0, coord_max - 1) for coord, coord_max in zip( (z_coords, y_coords, x_coords), shape ) ) + return coords, n_points + + +# Process labels stored in json napari style. +# I don't actually think that we need the epsilon here, but will leave it for now. +def process_labels(coords, shape, sigma, eps, bb=None): + + if bb: + (z_min, z_max), (y_min, y_max), (x_min, x_max) = [(s.start, s.stop) for s in bb] + restricted_shape = (z_max - z_min, y_max - y_min, x_max - x_min) + labels = np.zeros(restricted_shape, dtype="float32") + shape = restricted_shape + else: + labels = np.zeros(shape, dtype="float32") + labels[coords] = 1 labels = gaussian(labels, sigma) # TODO better normalization? 
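# A small sketch of the rejection logic of the MinPointSampler defined above: a patch is accepted
# outright if it contains more than min_points annotated points, otherwise it is rejected with
# probability p_reject (the raw patch itself is not inspected by this sampler).
import numpy as np
from detection_dataset import MinPointSampler

sampler = MinPointSampler(min_points=1, p_reject=0.8)
raw_patch = np.zeros((32, 128, 128), dtype="float32")  # hypothetical raw patch
print(sampler(raw_patch, n_points=3))  # True: more than min_points points in the patch
print([sampler(raw_patch, n_points=0) for _ in range(5)])  # mostly False; kept with probability 0.2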
@@ -124,16 +162,10 @@ def _get_sample(self, index): raw, label_path = self.raw_path, self.label_path raw = zarr.open(raw)[self.raw_key] + have_raw_channels = raw.ndim == 4 # 3D with channels shape = raw.shape bb = self._sample_bounding_box(shape) - label = process_labels(label_path, shape, self.sigma, self.eps, bb=bb) - - have_raw_channels = raw.ndim == 4 # 3D with channels - have_label_channels = label.ndim == 4 - if have_label_channels: - raise NotImplementedError("Multi-channel labels are not supported.") - prefix_box = tuple() if have_raw_channels: if shape[-1] < 16: @@ -143,18 +175,25 @@ def _get_sample(self, index): prefix_box = (slice(None), ) raw_patch = np.array(raw[prefix_box + bb]) - label_patch = np.array(label) + coords, n_points = load_labels(label_path, shape, bb) if self.sampler is not None: - assert False, "Sampler not implemented" - # sample_id = 0 - # while not self.sampler(raw_patch, label_patch): - # bb = self._sample_bounding_box(shape) - # raw_patch = np.array(raw[prefix_box + bb]) - # label_patch = np.array(label[bb]) - # sample_id += 1 - # if sample_id > self.max_sampling_attempts: - # raise RuntimeError(f"Could not sample a valid batch in {self.max_sampling_attempts} attempts") + sample_id = 0 + while not self.sampler(raw_patch, n_points): + bb = self._sample_bounding_box(shape) + raw_patch = np.array(raw[prefix_box + bb]) + coords, n_points = load_labels(label_path, shape, bb) + sample_id += 1 + if sample_id > self.max_sampling_attempts: + raise RuntimeError(f"Could not sample a valid batch in {self.max_sampling_attempts} attempts") + + label = process_labels(coords, shape, self.sigma, self.eps, bb=bb) + + have_label_channels = label.ndim == 4 + if have_label_channels: + raise NotImplementedError("Multi-channel labels are not supported.") + + label_patch = np.array(label) if have_raw_channels and len(prefix_box) == 0: raw_patch = raw_patch.transpose((3, 0, 1, 2)) # Channels, Depth, Height, Width diff --git a/scripts/synapse_marker_detection/extract_training_data.py b/scripts/synapse_marker_detection/extract_training_data.py index 3017577..68090f7 100644 --- a/scripts/synapse_marker_detection/extract_training_data.py +++ b/scripts/synapse_marker_detection/extract_training_data.py @@ -3,6 +3,7 @@ from pathlib import Path import h5py +import imageio.v3 as imageio import napari import numpy as np import pandas as pd @@ -19,27 +20,61 @@ def get_voxel_size(imaris_file): return vsize -def extract_training_data(imaris_file, output_folder): +def get_transformation(imaris_file): + with h5py.File(imaris_file) as f: + info = f["DataSetInfo"]["Image"].attrs + ext_min = np.array([float(b"".join(info[f"ExtMin{i}"]).decode()) for i in range(3)]) + ext_max = np.array([float(b"".join(info[f"ExtMax{i}"]).decode()) for i in range(3)]) + size = [int(b"".join(info[dim]).decode()) for dim in ["X", "Y", "Z"]] + spacing = (ext_max - ext_min) / size # µm / voxel + + # build 4×4 affine: world → index + T = np.eye(4) + T[:3, :3] = np.diag(1/spacing) # scale + T[:3, 3] = -ext_min/spacing # translate + + return T + + +def extract_training_data(imaris_file, output_folder, tif_file=None, crop=True): + point_key = "/Scene/Content/Points0/CoordsXYZR" with h5py.File(imaris_file, "r") as f: - data = f["/DataSet/ResolutionLevel 0/TimePoint 0/Channel 0/Data"][:] - points = f["/Scene/Content/Points0/CoordsXYZR"][:] + if point_key not in f: + print("Skipping", imaris_file, "due to missing annotations") + return + points = f[point_key][:] points = points[:, :-1] - points = points[:, ::-1] - # TODO crop 
the data to the original shape. - # Can we just crop the zero-padding ?! - crop_box = np.where(data != 0) - crop_box = tuple(slice(0, int(cb.max() + 1)) for cb in crop_box) - data = data[crop_box] - print(data.shape) - - # Scale the points to match the image dimensions. - voxel_size = get_voxel_size(imaris_file) - points /= voxel_size[None] + g = f["/DataSet/ResolutionLevel 0/TimePoint 0"] + # The first channel is ctbp2 / the synapse marker channel. + data = g["Channel 0/Data"][:] + # The second channel is vglut / the ihc channel. + if "Channel 1" in g: + ihc_data = g["Channel 1/Data"][:] + else: + ihc_data = None + + T = get_transformation(imaris_file) + points = (T @ np.c_[points, np.ones(len(points))].T).T[:, :3] + points = points[:, ::-1] + + if crop: + crop_box = np.where(data != 0) + crop_box = tuple(slice(0, int(cb.max() + 1)) for cb in crop_box) + data = data[crop_box] + + if tif_file is None: + original_data = None + else: + original_data = imageio.imread(tif_file) if output_folder is None: v = napari.Viewer() v.add_image(data) + if ihc_data is not None: + v.add_image(ihc_data) + if original_data is not None: + v.add_image(original_data, visible=False) v.add_points(points) v.title = os.path.basename(imaris_file) napari.run() @@ -59,6 +94,8 @@ def extract_training_data(imaris_file, output_folder): f = zarr.open(image_file, "a") f.create_dataset("raw", data=data) + if ihc_data is not None: + f.create_dataset("raw_ihc", data=ihc_data) # Files that look good for training: @@ -69,11 +106,96 @@ def extract_training_data(imaris_file, output_folder): # - 4.2R_apex_IHCribboncount_Z.ims # - 6.2R_apex_IHCribboncount_Z.ims (very small crop) # - 6.2R_base_IHCribbons_Z.ims -def main(): +def process_training_data_v1(): files = sorted(glob("./data/synapse_stains/*.ims")) for ff in files: extract_training_data(ff, output_folder="./training_data") +def _match_tif(imaris_file): + folder = os.path.split(imaris_file)[0] + + fname = os.path.basename(imaris_file) + parts = fname.split("_") + cochlea = parts[0].upper() + region = parts[1] + + tif_name = f"{cochlea}_{region}_CTBP2.tif" + tif_path = os.path.join(folder, tif_name) + assert os.path.exists(tif_path), tif_path + + return tif_path + + +def process_training_data_v2(visualize=True): + input_root = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/ImageCropsIHC_synapses" + + train_output = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/training_data/synapses/training_data/v2" # noqa + test_output = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/training_data/synapses/test/v2" # noqa + + train_folders = ["M78L_IHC-synapse_crops"] + test_folders = ["M226L_IHC-synapse_crops", "M226R_IHC-synapsecrops"] + + valid_files = [ + "m78l_apexp2718_cr-ctbp2.ims", + "m226r_apex_p1268_pv-ctbp2.ims", + "m226r_base_p800_vglut3-ctbp2.ims", + ] + + for folder in train_folders + test_folders: + + if visualize: + output_folder = None + elif folder in train_folders: + output_folder = train_output + os.makedirs(output_folder, exist_ok=True) + else: + output_folder = test_output + os.makedirs(output_folder, exist_ok=True) + + imaris_files = sorted(glob(os.path.join(input_root, folder, "*.ims"))) + for imaris_file in imaris_files: + if os.path.basename(imaris_file) not in valid_files: + continue + extract_training_data(imaris_file, output_folder, tif_file=None, crop=True, scale=True) + + +# We have fixed the imaris data extraction problem and can use all the crops! 
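# A small worked example (with hypothetical extents and grid size) of the world -> voxel-index
# affine built in get_transformation above: spacing = (ext_max - ext_min) / size, the 4x4 matrix
# scales by 1 / spacing and translates by -ext_min / spacing.
import numpy as np

ext_min = np.array([7.6, 0.0, 0.0])   # hypothetical extents in micrometer (X, Y, Z)
ext_max = np.array([45.6, 38.0, 19.0])
size = np.array([100, 100, 50])       # hypothetical voxel grid size (X, Y, Z)
spacing = (ext_max - ext_min) / size  # 0.38 micrometer per voxel along each axis

T = np.eye(4)
T[:3, :3] = np.diag(1 / spacing)
T[:3, 3] = -ext_min / spacing

point_world = np.array([19.0, 3.8, 1.9, 1.0])  # homogeneous world coordinate (x, y, z, 1)
print(T @ point_world)                         # -> [30. 10.  5.  1.], i.e. voxel indices (x, y, z)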
+def process_training_data_v3(visualize=True): + input_root = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/ImageCropsIHC_synapses" + + train_output = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/training_data/synapses/training_data/v3" # noqa + test_output = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/training_data/synapses/test_data/v3" # noqa + + train_folders = ["synapse_stains", "M78L_IHC-synapse_crops", "M226R_IHC-synapsecrops"] + test_folders = ["M226L_IHC-synapse_crops"] + + exclude_names = ["220824_Ex3IL_rbCAST1635_mCtBP2580_chCR488_cell1_CtBP2spots.ims"] + + for folder in train_folders + test_folders: + + if visualize: + output_folder = None + elif folder in train_folders: + output_folder = train_output + os.makedirs(output_folder, exist_ok=True) + else: + output_folder = test_output + os.makedirs(output_folder, exist_ok=True) + + imaris_files = sorted(glob(os.path.join(input_root, folder, "*.ims"))) + for imaris_file in imaris_files: + if os.path.basename(imaris_file) in exclude_names: + print("Skipping", imaris_file) + continue + extract_training_data(imaris_file, output_folder, tif_file=None, crop=True) + + +def main(): + # process_training_data_v1() + # process_training_data_v2(visualize=True) + process_training_data_v3(visualize=False) + + if __name__ == "__main__": main() diff --git a/scripts/synapse_marker_detection/marker_detection.py b/scripts/synapse_marker_detection/marker_detection.py new file mode 100644 index 0000000..13e48ea --- /dev/null +++ b/scripts/synapse_marker_detection/marker_detection.py @@ -0,0 +1,47 @@ +import argparse + +import flamingo_tools.s3_utils as s3_utils +from flamingo_tools.segmentation.synapse_detection import marker_detection + + +def main(): + + parser = argparse.ArgumentParser() + parser.add_argument("-i", "--input", required=True) + parser.add_argument("-o", "--output_folder", required=True, help="Path to output folder.") + parser.add_argument("-s", "--mask", required=True, help="Path to IHC segmentation used for masking.") + parser.add_argument("-m", "--model", required=True, help="Path to synapse detection model.") + parser.add_argument("-k", "--input_key", default=None, + help="Input key for image data and mask data for marker detection.") + parser.add_argument("-d", "--max_distance", default=20, + help="The maximal distance for a valid match of synapse markers to IHCs.") + + parser.add_argument("--s3", action="store_true", help="Use S3 bucket.") + parser.add_argument("--s3_credentials", type=str, default=None, + help="Input file containing S3 credentials. " + "Optional if AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY were exported.") + parser.add_argument("--s3_bucket_name", type=str, default=None, + help="S3 bucket name. Optional if BUCKET_NAME was exported.") + parser.add_argument("--s3_service_endpoint", type=str, default=None, + help="S3 service endpoint. 
Optional if SERVICE_ENDPOINT was exported.") + + args = parser.parse_args() + + if args.s3: + input_path, fs = s3_utils.get_s3_path(args.input, bucket_name=args.s3_bucket_name, + service_endpoint=args.s3_service_endpoint, + credential_file=args.s3_credentials) + + mask_path, fs = s3_utils.get_s3_path(args.mask, bucket_name=args.s3_bucket_name, + service_endpoint=args.s3_service_endpoint, + credential_file=args.s3_credentials) + else: + input_path = args.input + mask_path = args.mask + + marker_detection(input_path=input_path, input_key=args.input_key, mask_path=mask_path, + output_folder=args.output_folder, model_path=args.model) + + +if __name__ == "__main__": + main() diff --git a/scripts/synapse_marker_detection/run_prediction.py b/scripts/synapse_marker_detection/run_prediction.py index 1195f31..649362c 100644 --- a/scripts/synapse_marker_detection/run_prediction.py +++ b/scripts/synapse_marker_detection/run_prediction.py @@ -1,52 +1,43 @@ import argparse -import os -import pandas as pd -import numpy as np -import zarr - -from elf.parallel.local_maxima import find_local_maxima -from flamingo_tools.segmentation.unet_prediction import prediction_impl +import flamingo_tools.s3_utils as s3_utils +from flamingo_tools.segmentation.synapse_detection import run_prediction def main(): parser = argparse.ArgumentParser() - parser.add_argument("-i", "--input", required=True) - parser.add_argument("-o", "--output_folder", required=True) - parser.add_argument("-m", "--model", required=True) - parser.add_argument("-k", "--input_key", default=None) + parser.add_argument("-i", "--input", required=True, help="Path to image data to be segmented.") + parser.add_argument("-o", "--output_folder", required=True, help="Path to output folder.") + parser.add_argument("-m", "--model", required=True, + help="Path to synapse detection model.") + parser.add_argument("-k", "--input_key", default=None, + help="The key / internal path to image data.") + + parser.add_argument("--s3", action="store_true", help="Use S3 bucket.") + parser.add_argument("--s3_credentials", type=str, default=None, + help="Input file containing S3 credentials. " + "Optional if AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY were exported.") + parser.add_argument("--s3_bucket_name", type=str, default=None, + help="S3 bucket name. Optional if BUCKET_NAME was exported.") + parser.add_argument("--s3_service_endpoint", type=str, default=None, + help="S3 service endpoint. Optional if SERVICE_ENDPOINT was exported.") + args = parser.parse_args() block_shape = (64, 256, 256) halo = (16, 64, 64) - # Skip existing prediction, which is saved in output_folder/predictions.zarr - skip_prediction = False - output_path = os.path.join(args.output_folder, "predictions.zarr") - prediction_key = "prediction" - if os.path.exists(output_path) and prediction_key in zarr.open(output_path, "r"): - skip_prediction = True - - if not skip_prediction: - prediction_impl( - args.input, args.input_key, args.output_folder, args.model, - scale=None, block_shape=block_shape, halo=halo, - apply_postprocessing=False, output_channels=1, - ) - - detection_path = os.path.join(args.output_folder, "synapse_detection.tsv") - if not os.path.exists(detection_path): - input_ = zarr.open(output_path, "r")[prediction_key] - detections = find_local_maxima( - input_, block_shape=block_shape, min_distance=2, threshold_abs=0.5, verbose=True, n_threads=16, - ) - # Save the result in mobie compatible format. 
- detections = np.concatenate( - [np.arange(1, len(detections) + 1)[:, None], detections[:, ::-1]], axis=1 - ) - detections = pd.DataFrame(detections, columns=["spot_id", "x", "y", "z"]) - detections.to_csv(detection_path, index=False, sep="\t") + if args.s3: + input_path, fs = s3_utils.get_s3_path(args.input, bucket_name=args.s3_bucket_name, + service_endpoint=args.s3_service_endpoint, + credential_file=args.s3_credentials) + + else: + input_path = args.input + + run_prediction(input_path=input_path, input_key=args.input_key, output_folder=args.output_folder, + model_path=args.model, block_shape=block_shape, halo=halo) if __name__ == "__main__": diff --git a/scripts/synapse_marker_detection/train_synapse_detection.py b/scripts/synapse_marker_detection/train_synapse_detection.py index 2a7d6af..159ccf8 100644 --- a/scripts/synapse_marker_detection/train_synapse_detection.py +++ b/scripts/synapse_marker_detection/train_synapse_detection.py @@ -1,45 +1,42 @@ import os import sys +from glob import glob -from detection_dataset import DetectionDataset +from sklearn.model_selection import train_test_split +from detection_dataset import DetectionDataset, MinPointSampler sys.path.append("/home/pape/Work/my_projects/czii-protein-challenge") sys.path.append("/user/pape41/u12086/Work/my_projects/czii-protein-challenge") from utils.training.training import supervised_training # noqa -ROOT = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/training_data/synapses/training_data/v1" # noqa +ROOT = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/training_data/synapses/training_data/v3" # noqa TRAIN_ROOT = os.path.join(ROOT, "images") LABEL_ROOT = os.path.join(ROOT, "labels") def get_paths(split): - file_names = [ - "4.1L_apex_IHCribboncount_Z", - "4.1L_base_IHCribbons_Z", - "4.1L_mid_IHCribboncount_Z", - "4.2R_apex_IHCribboncount_Z", - "4.2R_apex_IHCribboncount_Z", - "6.2R_apex_IHCribboncount_Z", - "6.2R_base_IHCribbons_Z", - ] - image_paths = [os.path.join(TRAIN_ROOT, f"{fname}.zarr") for fname in file_names] - label_paths = [os.path.join(LABEL_ROOT, f"{fname}.csv") for fname in file_names] + image_paths = sorted(glob(os.path.join(TRAIN_ROOT, "*.zarr"))) + label_paths = sorted(glob(os.path.join(LABEL_ROOT, "*.csv"))) + assert len(image_paths) == len(label_paths) + + train_images, val_images, train_labels, val_labels = train_test_split( + image_paths, label_paths, test_size=2, random_state=42 + ) if split == "train": - image_paths = image_paths[:-1] - label_paths = label_paths[:-1] + image_paths = train_images + label_paths = train_labels else: - image_paths = image_paths[-1:] - label_paths = label_paths[-1:] + image_paths = val_images + label_paths = val_labels return image_paths, label_paths -# TODO maybe add a sampler for the label data def train(): - model_name = "synapse_detection_v1" + model_name = "synapse_detection_v3" train_paths, train_label_paths = get_paths("train") val_paths, val_label_paths = get_paths("val") @@ -64,7 +61,7 @@ def train(): patch_shape=patch_shape, batch_size=batch_size, check=check, lr=1e-4, - n_iterations=int(5e4), + n_iterations=int(1e5), out_channels=1, augmentations=None, eps=1e-5, @@ -77,6 +74,7 @@ def train(): dataset_class=DetectionDataset, n_samples_train=3200, n_samples_val=160, + sampler=MinPointSampler(min_points=1, p_reject=0.8), ) diff --git a/scripts/validation/IHCs/compare_annotations.py b/scripts/validation/IHCs/compare_annotations.py new file mode 100644 index 0000000..370a262 --- /dev/null +++ 
b/scripts/validation/IHCs/compare_annotations.py @@ -0,0 +1,81 @@ +import os +from glob import glob + +import napari +import pandas as pd +import tifffile + +ROOT = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/AnnotatedImageCrops/F1ValidationIHCs" +# ANNOTATION_FOLDERS = ["AnnotationsEK", "AnnotationsAMD", "AnnotationsLR"] +ANNOTATION_FOLDERS = ["Annotations_AMD", "Annotations_LR"] +COLOR = ["green", "yellow", "orange"] + + +def _match_annotations(image_path): + prefix = os.path.basename(image_path).split("_")[:3] + prefix = "_".join(prefix) + + annotations = {} + for annotation_folder in ANNOTATION_FOLDERS: + all_annotations = glob(os.path.join(ROOT, annotation_folder, "*.csv")) + matches = [ann for ann in all_annotations if os.path.basename(ann).startswith(prefix)] + if len(matches) == 0: + continue + assert len(matches) == 1 + annotation_path = matches[0] + + annotation = pd.read_csv(annotation_path)[["axis-0", "axis-1", "axis-2"]].values + annotations[annotation_folder] = annotation + + return annotations + + +def compare_annotations(image_path): + annotations = _match_annotations(image_path) + + image = tifffile.memmap(image_path) + v = napari.Viewer() + v.add_image(image) + for i, (name, annotation) in enumerate(annotations.items()): + v.add_points(annotation, name=name, face_color=COLOR[i]) + v.title = os.path.basename(image_path) + napari.run() + + +def visualize(image_paths): + for image_path in image_paths: + compare_annotations(image_path) + + +def check_annotations(image_paths): + annotation_status = {"file": []} + annotation_status.update({ann: [] for ann in ANNOTATION_FOLDERS}) + for image_path in image_paths: + annotations = _match_annotations(image_path) + annotation_status["file"].append(os.path.basename(image_path)) + for ann in ANNOTATION_FOLDERS: + annotation_status[ann].append("Yes" if ann in annotations else "No") + annotation_status = pd.DataFrame(annotation_status) + print(annotation_status) + + +def main(): + import argparse + parser = argparse.ArgumentParser() + parser.add_argument("--images", nargs="+") + parser.add_argument("--check", action="store_true") + args = parser.parse_args() + + if args.images is None: + image_paths = sorted(glob(os.path.join(ROOT, "*.tif"))) + else: + image_paths = args.images + + if args.check: + check_annotations(image_paths) + else: + visualize(image_paths) + + +if __name__ == "__main__": + main() diff --git a/scripts/validation/IHCs/run_evaluation.py b/scripts/validation/IHCs/run_evaluation.py index 15c2afe..d87e5ae 100644 --- a/scripts/validation/IHCs/run_evaluation.py +++ b/scripts/validation/IHCs/run_evaluation.py @@ -3,11 +3,12 @@ import pandas as pd from flamingo_tools.validation import ( - fetch_data_for_evaluation, parse_annotation_path, compute_scores_for_annotated_slice + fetch_data_for_evaluation, _parse_annotation_path, compute_scores_for_annotated_slice ) ROOT = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/AnnotatedImageCrops/F1ValidationIHCs" -ANNOTATION_FOLDERS = ["Annotations_LR"] +# ANNOTATION_FOLDERS = ["AnnotationsEK", "AnnotationsAMD", "AnnotationsLR"] +ANNOTATION_FOLDERS = ["Annotations_AMD", "Annotations_LR"] def run_evaluation(root, annotation_folders, result_file, cache_folder): @@ -28,7 +29,7 @@ def run_evaluation(root, annotation_folders, result_file, cache_folder): annotations = sorted(glob(os.path.join(root, folder, "*.csv"))) for annotation_path in annotations: print(annotation_path) - cochlea, slice_id = parse_annotation_path(annotation_path) + cochlea, slice_id = 
_parse_annotation_path(annotation_path) # For the cochlea M_LR_000226_R the actual component is 2, not 1 component = 2 if "226_R" in cochlea else 1 diff --git a/scripts/validation/IHCs/visualize_validation.py b/scripts/validation/IHCs/visualize_validation.py index f33b96f..0b37345 100644 --- a/scripts/validation/IHCs/visualize_validation.py +++ b/scripts/validation/IHCs/visualize_validation.py @@ -6,7 +6,7 @@ import tifffile from flamingo_tools.validation import ( - fetch_data_for_evaluation, compute_matches_for_annotated_slice, for_visualization, parse_annotation_path + fetch_data_for_evaluation, compute_matches_for_annotated_slice, for_visualization, _parse_annotation_path ) ROOT = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/AnnotatedImageCrops/F1ValidationIHCs" @@ -17,15 +17,15 @@ def _match_image_path(annotation_path): prefix = os.path.basename(annotation_path).split("_")[:-3] prefix = "_".join(prefix) matches = [path for path in all_files if os.path.basename(path).startswith(prefix)] - # if len(matches) != 1: - # breakpoint() + if len(matches) != 1: + breakpoint() assert len(matches) == 1, f"{prefix}: {len(matches)}" return matches[0] def visualize_anotation(annotation_path, cache_folder): print("Checking", annotation_path) - cochlea, slice_id = parse_annotation_path(annotation_path) + cochlea, slice_id = _parse_annotation_path(annotation_path) cache_path = None if cache_folder is None else os.path.join(cache_folder, f"{cochlea}_{slice_id}.tif") image_path = _match_image_path(annotation_path) diff --git a/scripts/validation/SGNs/analyze.py b/scripts/validation/SGNs/analyze.py index 4a5ea94..3bb9085 100644 --- a/scripts/validation/SGNs/analyze.py +++ b/scripts/validation/SGNs/analyze.py @@ -24,9 +24,15 @@ def compute_scores(table, annotator=None): def main(): parser = argparse.ArgumentParser() parser.add_argument("result_file") + parser.add_argument("--all", action="store_true") args = parser.parse_args() table = pd.read_csv(args.result_file) + if args.all: + print(table) + print() + print() + annotators = pd.unique(table.annotator) results = [] diff --git a/scripts/validation/SGNs/compare_annotations.py b/scripts/validation/SGNs/compare_annotations.py index 2b2b4bc..b262d2e 100644 --- a/scripts/validation/SGNs/compare_annotations.py +++ b/scripts/validation/SGNs/compare_annotations.py @@ -6,8 +6,8 @@ import tifffile ROOT = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/AnnotatedImageCrops/F1ValidationSGNs" -ANNOTATION_FOLDERS = ["AnnotationsEK", "AnnotationsAMD", "AnnotationLR"] -COLOR = ["green", "yellow", "orange"] +ANNOTATION_FOLDERS = ["AnnotationsAMD", "AnnotationsEK", "AnnotationsLR"] +COLOR = ["blue", "yellow", "orange"] def _match_annotations(image_path): @@ -18,8 +18,9 @@ def _match_annotations(image_path): for annotation_folder in ANNOTATION_FOLDERS: all_annotations = glob(os.path.join(ROOT, annotation_folder, "*.csv")) matches = [ann for ann in all_annotations if os.path.basename(ann).startswith(prefix)] - if len(matches) != 1: + if len(matches) == 0: continue + assert len(matches) == 1 annotation_path = matches[0] annotation = pd.read_csv(annotation_path)[["axis-0", "axis-1", "axis-2"]].values @@ -40,10 +41,28 @@ def compare_annotations(image_path): napari.run() +def visualize(image_paths): + for image_path in image_paths: + compare_annotations(image_path) + + +def check_annotations(image_paths): + annotation_status = {"file": []} + annotation_status.update({ann: [] for ann in ANNOTATION_FOLDERS}) + for image_path in image_paths: + 
annotations = _match_annotations(image_path) + annotation_status["file"].append(os.path.basename(image_path)) + for ann in ANNOTATION_FOLDERS: + annotation_status[ann].append("Yes" if ann in annotations else "No") + annotation_status = pd.DataFrame(annotation_status) + print(annotation_status) + + def main(): import argparse parser = argparse.ArgumentParser() parser.add_argument("--images", nargs="+") + parser.add_argument("--check", action="store_true") args = parser.parse_args() if args.images is None: @@ -51,8 +70,10 @@ def main(): else: image_paths = args.images - for image_path in image_paths: - compare_annotations(image_path) + if args.check: + check_annotations(image_paths) + else: + visualize(image_paths) if __name__ == "__main__": diff --git a/scripts/validation/SGNs/consensus_annotations.py b/scripts/validation/SGNs/consensus_annotations.py new file mode 100644 index 0000000..8316bf2 --- /dev/null +++ b/scripts/validation/SGNs/consensus_annotations.py @@ -0,0 +1,93 @@ +import os +from glob import glob + +import numpy as np +import pandas as pd +from flamingo_tools.validation import create_consensus_annotations + +ROOT = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/AnnotatedImageCrops/F1ValidationSGNs" +ANNOTATION_FOLDERS = ["AnnotationsAMD", "AnnotationsEK", "AnnotationsLR"] +COLOR = ["blue", "yellow", "orange"] +OUTPUT_FOLDER = os.path.join(ROOT, "Consensus") + + +def match_annotations(image_path): + annotations = {} + prefix = os.path.basename(image_path).split("_")[:3] + prefix = "_".join(prefix) + + annotations = {} + for annotation_folder in ANNOTATION_FOLDERS: + all_annotations = glob(os.path.join(ROOT, annotation_folder, "*.csv")) + matches = [ann for ann in all_annotations if os.path.basename(ann).startswith(prefix)] + assert len(matches) == 1 + annotation_path = matches[0] + annotations[annotation_folder] = annotation_path + + return annotations + + +def consensus_annotations(image_path, check): + print("Compute consensus annotations for", image_path) + annotation_paths = match_annotations(image_path) + consensus_annotations, unmatched_annotations = create_consensus_annotations( + annotation_paths, matching_distance=8.0, min_matches_for_consensus=2, + ) + fname = os.path.basename(image_path) + + if check: + import napari + import tifffile + + consensus_annotations = consensus_annotations[["axis-0", "axis-1", "axis-2"]].values + unmatched_annotators = unmatched_annotations.annotator.values + unmatched_annotations = unmatched_annotations[["axis-0", "axis-1", "axis-2"]].values + + image = tifffile.imread(image_path) + v = napari.Viewer() + v.add_image(image) + v.add_points(consensus_annotations, face_color="green") + v.add_points( + unmatched_annotations, + properties={"annotator": unmatched_annotators}, + face_color="annotator", + face_color_cycle=COLOR, # TODO reorder + ) + v.title = os.path.basename(fname) + napari.run() + + else: + # Save the combined consensus and unmatched annotation. 
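# A minimal sketch of the combined table saved below, assuming two consensus points and one
# unmatched point from a hypothetical annotator: consensus rows carry the value "consensus" in
# the annotator column, while unmatched rows keep the name of the annotator who placed them.
import pandas as pd

combined = pd.DataFrame({
    "axis-0": [12.0, 30.5, 14.0],
    "axis-1": [101.2, 88.0, 99.0],
    "axis-2": [54.3, 61.7, 50.1],
    "annotator": ["consensus", "consensus", "AnnotationsAMD"],
})
print(combined)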
+ combined_annotations = consensus_annotations[["axis-0", "axis-1", "axis-2", "annotators"]] + combined_annotations["annotators"] = "consensus" + combined_annotations = combined_annotations.rename(columns={"annotators": "annotator"}) + combined_annotations = pd.concat([combined_annotations, unmatched_annotations]) + + print("Saving consensus annotations for", fname, ":") + for name, count in zip(*np.unique(combined_annotations.annotator.values, return_counts=True)): + print(name, count) + + os.makedirs(OUTPUT_FOLDER, exist_ok=True) + output_path = os.path.join(OUTPUT_FOLDER, fname.replace(".tif", ".csv")) + combined_annotations.to_csv(output_path, index=False) + + +# NOTE: we need to treat the rescaled ones differently. +def main(): + import argparse + parser = argparse.ArgumentParser() + parser.add_argument("--images", nargs="+") + parser.add_argument("--check", action="store_true") + args = parser.parse_args() + + if args.images is None: + image_paths = sorted(glob(os.path.join(ROOT, "*.tif"))) + else: + image_paths = args.images + + for image_path in image_paths: + consensus_annotations(image_path, args.check) + + +if __name__ == "__main__": + main() diff --git a/scripts/validation/SGNs/proofread_consensus_annotations.py b/scripts/validation/SGNs/proofread_consensus_annotations.py new file mode 100644 index 0000000..c92522d --- /dev/null +++ b/scripts/validation/SGNs/proofread_consensus_annotations.py @@ -0,0 +1,58 @@ +import os +from glob import glob + +import napari +import pandas as pd +import tifffile + +ROOT = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/AnnotatedImageCrops/F1ValidationSGNs" +CONSENSUS_FOLDER = os.path.join(ROOT, "Consensus") +COLOR = ["blue", "yellow", "orange"] + + +def proofread_consensus_annotations(image_path, annotation_path, color_by_annotator): + image = tifffile.memmap(image_path) + annotations = pd.read_csv(annotation_path) + + consensus_annotations = annotations[annotations.annotator == "consensus"][["axis-0", "axis-1", "axis-2"]].values + unmatched_annotations = annotations[annotations.annotator != "consensus"] + + unmatched_annotators = unmatched_annotations.annotator.values + unmatched_annotations = unmatched_annotations[["axis-0", "axis-1", "axis-2"]].values + + image = tifffile.imread(image_path) + v = napari.Viewer() + v.add_image(image) + v.add_points(consensus_annotations, face_color="green") + if color_by_annotator: + v.add_points( + unmatched_annotations, + properties={"annotator": unmatched_annotators}, + face_color="annotator", + face_color_cycle=COLOR, # TODO reorder + ) + else: + v.add_points(unmatched_annotations) + fname = os.path.basename(annotation_path) + v.title = os.path.basename(fname) + napari.run() + + +# TODO enable skipping the ones already stored in the output folder and specifying a different root path +# TODO set reasonable contrast limits +def main(): + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument("-c", "--color_by_annotator", action="store_true") + args = parser.parse_args() + + annotations = sorted(glob(os.path.join(CONSENSUS_FOLDER, "*.csv"))) + for annotation_path in annotations: + fname = os.path.basename(annotation_path) + image_path = os.path.join(ROOT, fname.replace(".csv", ".tif")) + proofread_consensus_annotations(image_path, annotation_path, args.color_by_annotator) + + +if __name__ == "__main__": + main() diff --git a/scripts/validation/SGNs/rescale_annotations.py b/scripts/validation/SGNs/rescale_annotations.py new file mode 100644 index 0000000..d0d22b1 --- 
/dev/null +++ b/scripts/validation/SGNs/rescale_annotations.py @@ -0,0 +1,84 @@ +import os +import shutil +from glob import glob + +import numpy as np +import pandas as pd +import tifffile +import zarr + +from flamingo_tools.s3_utils import get_s3_path, BUCKET_NAME, SERVICE_ENDPOINT + + +def get_scale_factor(): + original_path = "/mnt/ceph-hdd/cold/nim00007/cochlea-lightsheet/M_LR_000169_R/MLR000169R_PV.tif" + original_shape = tifffile.memmap(original_path).shape + + cochlea = "M_LR_000169_R" + internal_path = os.path.join(cochlea, "images", "ome-zarr", "SGN_v2.ome.zarr") + s3_store, fs = get_s3_path(internal_path, bucket_name=BUCKET_NAME, service_endpoint=SERVICE_ENDPOINT) + + input_key = "s0" + with zarr.open(s3_store, mode="r") as f: + new_shape = f[input_key].shape + + scale_factor = tuple( + float(nsh) / float(osh) for nsh, osh in zip(new_shape, original_shape) + ) + return scale_factor + + +def rescale_annotations(input_path, scale_factor, bkp_folder): + annotations = pd.read_csv(input_path) + + annotations_rescaled = annotations.copy() + annotations_rescaled["axis-1"] = annotations["axis-1"] * scale_factor[1] + annotations_rescaled["axis-2"] = annotations["axis-2"] * scale_factor[2] + + fname = os.path.basename(input_path) + name_components = fname.split("_") + z = int(name_components[2][1:]) + new_z = int(np.round(z * scale_factor[0])) + + name_components[2] = f"z{new_z}" + name_components = name_components[:-1] + ["rescaled"] + name_components[-1:] + new_fname = "_".join(name_components) + + input_folder = os.path.split(input_path)[0] + out_path = os.path.join(input_folder, new_fname) + bkp_path = os.path.join(bkp_folder, fname) + + # print(input_path) + # print(out_path) + # print(bkp_path) + # print() + # return + + shutil.move(input_path, bkp_path) + annotations_rescaled.to_csv(out_path, index=False) + + +def main(): + # scale_factor = get_scale_factor() + # print(scale_factor) + scale_factor = (2.6314,) * 3 + + root = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/AnnotatedImageCrops/F1ValidationSGNs" + annotation_folders = ["AnnotationsEK", "AnnotationsAMD", "AnnotationsLR"] + for folder in annotation_folders: + bkp_folder = os.path.join(root, folder, "rescaled_bkp") + os.makedirs(bkp_folder, exist_ok=True) + + files = glob(os.path.join(root, folder, "*.csv")) + for annotation_file in files: + fname = os.path.basename(annotation_file) + if not fname.startswith(("MLR169R_PV_z722", "MLR169R_PV_z979")): + continue + print("Rescaling", annotation_file) + rescale_annotations(annotation_file, scale_factor, bkp_folder) + + +# Rescale the point annotations for the cochlea MLR169R, which was +# annotated at the original scale, but then rescaled for segmentation. 
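# A small worked example (with hypothetical shapes) of the rescaling applied above: the per-axis
# scale factor is new_shape / original_shape, and both the in-plane point coordinates and the z
# index encoded in the file name are multiplied by it.
import numpy as np

original_shape = np.array([1000, 2000, 2000])  # hypothetical shape the annotations were made on (z, y, x)
new_shape = np.array([2631, 5263, 5263])       # hypothetical shape used for segmentation
scale_factor = new_shape / original_shape      # ~ (2.63, 2.63, 2.63)

z_annotated = 722  # slice index taken from the annotation file name, e.g. MLR169R_PV_z722
print(int(np.round(z_annotated * scale_factor[0])))  # -> 1900, the z used in the rescaled file name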
+if __name__ == "__main__": + main() diff --git a/scripts/validation/SGNs/run_evaluation.py b/scripts/validation/SGNs/run_evaluation.py index 2153c24..fa65017 100644 --- a/scripts/validation/SGNs/run_evaluation.py +++ b/scripts/validation/SGNs/run_evaluation.py @@ -7,7 +7,7 @@ ) ROOT = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/AnnotatedImageCrops/F1ValidationSGNs" -ANNOTATION_FOLDERS = ["AnnotationsEK", "AnnotationsAMD", "AnnotationLR"] +ANNOTATION_FOLDERS = ["AnnotationsEK", "AnnotationsAMD", "AnnotationsLR"] def run_evaluation(root, annotation_folders, result_file, cache_folder): @@ -27,11 +27,7 @@ def run_evaluation(root, annotation_folders, result_file, cache_folder): annotator = folder[len("Annotations"):] annotations = sorted(glob(os.path.join(root, folder, "*.csv"))) for annotation_path in annotations: - print(annotation_path) cochlea, slice_id = parse_annotation_path(annotation_path) - # We don't have this cochlea in MoBIE yet - if cochlea == "M_LR_000169_R": - continue print("Run evaluation for", annotator, cochlea, "z=", slice_id) segmentation, annotations = fetch_data_for_evaluation( diff --git a/scripts/validation/SGNs/visualize_validation.py b/scripts/validation/SGNs/visualize_validation.py index 2f8c23e..0a20de0 100644 --- a/scripts/validation/SGNs/visualize_validation.py +++ b/scripts/validation/SGNs/visualize_validation.py @@ -30,8 +30,11 @@ def visualize_anotation(annotation_path, cache_folder): image_path = _match_image_path(annotation_path) + # For debugging. + components = [1] + # components = None segmentation, annotations = fetch_data_for_evaluation( - annotation_path, cache_path=cache_path, components_for_postprocessing=[1], + annotation_path, cache_path=cache_path, components_for_postprocessing=components, ) image = tifffile.memmap(image_path) diff --git a/scripts/validation/synapses/prediction.py b/scripts/validation/synapses/prediction.py new file mode 100644 index 0000000..2ea3273 --- /dev/null +++ b/scripts/validation/synapses/prediction.py @@ -0,0 +1,155 @@ +import os +import sys +from glob import glob +from pathlib import Path + +import numpy as np +import pandas as pd + +from elf.io import open_file +from elf.parallel.local_maxima import find_local_maxima +from flamingo_tools.segmentation.unet_prediction import prediction_impl, run_unet_prediction + +INPUT_ROOT = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/training_data/synapses/test_data/v3/images" # noqa +GT_ROOT = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/training_data/synapses/test_data/v3/labels" +OUTPUT_ROOT = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/AnnotatedImageCrops/SynapseValidation" + +sys.path.append("/user/pape41/u12086/Work/my_projects/czii-protein-challenge") +sys.path.append("../../synapse_marker_detection") + + +def pred_synapse_impl(input_path, output_folder): + model_path = "/mnt/vast-nhr/home/pape41/u12086/Work/my_projects/flamingo-tools/scripts/synapse_marker_detection/checkpoints/synapse_detection_v3" # noqa + input_key = "raw" + + block_shape = (32, 128, 128) + halo = (16, 64, 64) + + prediction_impl( + input_path, input_key, output_folder, model_path, + scale=None, block_shape=block_shape, halo=halo, + apply_postprocessing=False, output_channels=1, + ) + + output_path = os.path.join(output_folder, "predictions.zarr") + prediction_key = "prediction" + input_ = open_file(output_path, "r")[prediction_key] + + detections = find_local_maxima( + input_, block_shape=block_shape, min_distance=2, 
+        threshold_abs=0.5, verbose=True, n_threads=4,
+    )
+    # Save the result in mobie compatible format.
+    detections = np.concatenate(
+        [np.arange(1, len(detections) + 1)[:, None], detections[:, ::-1]], axis=1
+    )
+    detections = pd.DataFrame(detections, columns=["spot_id", "x", "y", "z"])
+
+    detection_path = os.path.join(output_folder, "synapse_detection.tsv")
+    detections.to_csv(detection_path, index=False, sep="\t")
+
+
+def predict_synapses():
+    files = sorted(glob(os.path.join(INPUT_ROOT, "*.zarr")))
+    for ff in files:
+        print("Segmenting", ff)
+        output_folder = os.path.join(OUTPUT_ROOT, Path(ff).stem)
+        pred_synapse_impl(ff, output_folder)
+
+
+def pred_ihc_impl(input_path, output_folder):
+    model_path = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/trained_models/IHC/v2_cochlea_distance_unet_IHC_supervised_2025-05-21"  # noqa
+
+    run_unet_prediction(
+        input_path, input_key="raw_ihc", output_folder=output_folder, model_path=model_path, min_size=1000,
+        seg_class="ihc", center_distance_threshold=0.5, boundary_distance_threshold=0.5,
+    )
+
+
+def predict_ihcs():
+    files = sorted(glob(os.path.join(INPUT_ROOT, "*.zarr")))
+    for ff in files:
+        print("Segmenting", ff)
+        output_folder = os.path.join(OUTPUT_ROOT, f"{Path(ff).stem}_ihc")
+        pred_ihc_impl(ff, output_folder)
+
+
+def _filter_synapse_impl(detections, ihc_file, output_path):
+    from flamingo_tools.segmentation.synapse_detection import map_and_filter_detections
+
+    with open_file(ihc_file, mode="r") as f:
+        if "segmentation_filtered" in f:
+            print("Using filtered segmentation!")
+            segmentation = f["segmentation_filtered"][:]
+        else:
+            segmentation = f["segmentation"][:]
+
+    max_distance = 5  # 5 micrometer
+    filtered_detections = map_and_filter_detections(segmentation, detections, max_distance=max_distance)
+    filtered_detections.to_csv(output_path, index=False, sep="\t")
+
+
+def filter_synapses():
+    input_files = sorted(glob(os.path.join(INPUT_ROOT, "*.zarr")))
+    for ff in input_files:
+        ihc = os.path.join(OUTPUT_ROOT, f"{Path(ff).stem}_ihc", "segmentation.zarr")
+        output_folder = os.path.join(OUTPUT_ROOT, Path(ff).stem)
+        synapses = os.path.join(output_folder, "synapse_detection.tsv")
+        synapses = pd.read_csv(synapses, sep="\t")
+        output_path = os.path.join(output_folder, "filtered_synapse_detection.tsv")
+        _filter_synapse_impl(synapses, ihc, output_path)
+
+
+def filter_gt():
+    input_files = sorted(glob(os.path.join(INPUT_ROOT, "*.zarr")))
+    gt_files = sorted(glob(os.path.join(GT_ROOT, "*.csv")))
+    for ff, gt in zip(input_files, gt_files):
+        ihc = os.path.join(OUTPUT_ROOT, f"{Path(ff).stem}_ihc", "segmentation.zarr")
+        output_folder, fname = os.path.split(gt)
+        output_path = os.path.join(output_folder, fname.replace(".csv", "_filtered.tsv"))
+
+        gt = pd.read_csv(gt)
+        gt = gt.rename(columns={"axis-0": "z", "axis-1": "y", "axis-2": "x"})
+        gt.insert(0, "spot_id", np.arange(1, len(gt) + 1))
+
+        _filter_synapse_impl(gt, ihc, output_path)
+
+
+def _check_prediction(input_file, ihc_file, detection_file):
+    import napari
+
+    synapses = pd.read_csv(detection_file, sep="\t")[["z", "y", "x"]].values
+
+    vglut = open_file(input_file)["raw_ihc"][:]
+    ctbp2 = open_file(input_file)["raw"][:]
+    ihcs = open_file(ihc_file)["segmentation"][:]
+
+    v = napari.Viewer()
+    v.add_image(vglut)
+    v.add_image(ctbp2)
+    v.add_labels(ihcs)
+    v.add_points(synapses)
+    napari.run()
+
+
+def check_predictions():
+    input_files = sorted(glob(os.path.join(INPUT_ROOT, "*.zarr")))
+    for ff in input_files:
+        ihc = os.path.join(OUTPUT_ROOT, f"{Path(ff).stem}_ihc", "segmentation.zarr")
+        synapses = os.path.join(OUTPUT_ROOT, Path(ff).stem, "filtered_synapse_detection.tsv")
+        _check_prediction(ff, ihc, synapses)
+
+
+def process_everything():
+    predict_synapses()
+    predict_ihcs()
+    filter_synapses()
+    filter_gt()
+
+
+def main():
+    process_everything()
+    # check_predictions()
+
+
+if __name__ == "__main__":
+    main()
diff --git a/scripts/validation/synapses/run_evaluation.py b/scripts/validation/synapses/run_evaluation.py
new file mode 100644
index 0000000..3407d1a
--- /dev/null
+++ b/scripts/validation/synapses/run_evaluation.py
@@ -0,0 +1,170 @@
+import os
+
+import numpy as np
+import pandas as pd
+
+from elf.io import open_file
+from scipy.spatial import cKDTree
+from scipy.optimize import linear_sum_assignment
+
+
+# TODO refactor
+def match_detections(
+    detections: np.ndarray,
+    annotations: np.ndarray,
+    max_dist: float
+):
+    """One-to-one matching between 3-D detections and ground-truth points.
+
+    Args:
+        detections: (N, 3) array-like. Candidate points produced by the model.
+        annotations: (M, 3) array-like. Ground-truth reference points.
+        max_dist: Maximum Euclidean distance allowed for a match.
+
+    Returns:
+        tp_det_ids : 1-D ndarray. Indices in `detections` that were matched (true positives).
+        tp_ann_ids : 1-D ndarray. Indices in `annotations` that were matched (true positives).
+        fp_det_ids : 1-D ndarray. Unmatched detection indices (false positives).
+        fn_ann_ids : 1-D ndarray. Unmatched annotation indices (false negatives).
+    """
+    det = np.asarray(detections, dtype=float)
+    ann = np.asarray(annotations, dtype=float)
+    N, M = len(det), len(ann)
+
+    # trivial corner cases --------------------------------------------------------
+    if N == 0:
+        return np.empty(0, int), np.empty(0, int), np.empty(0, int), np.arange(M)
+    if M == 0:
+        return np.empty(0, int), np.empty(0, int), np.arange(N), np.empty(0, int)
+
+    # 1. build sparse radius-filtered distance matrix -----------------------------
+    tree_det = cKDTree(det)
+    tree_ann = cKDTree(ann)
+    coo = tree_det.sparse_distance_matrix(tree_ann, max_dist, output_type="coo_matrix")
+
+    if coo.nnz == 0:  # nothing is close enough
+        return np.empty(0, int), np.empty(0, int), np.arange(N), np.arange(M)
+
+    cost = np.full((N, M), 5 * max_dist, dtype=float)
+    cost[coo.row, coo.col] = coo.data  # fill only existing edges
+
+    # 2. optimal one-to-one assignment (Hungarian) --------------------------------
+    row_ind, col_ind = linear_sum_assignment(cost)
+
+    # Discard assignments that only exist because of the dummy cost (5 * max_dist)
+    # used for pairs without an edge, i.e. keep only matches within the distance threshold.
+    valid_mask = cost[row_ind, col_ind] <= max_dist
+    tp_det_ids = row_ind[valid_mask]
+    tp_ann_ids = col_ind[valid_mask]
+    assert len(tp_det_ids) == len(tp_ann_ids)
+
+    # 3. derive FP / FN -----------------------------------------------------------
+    fp_det_ids = np.setdiff1d(np.arange(N), tp_det_ids, assume_unique=True)
+    fn_ann_ids = np.setdiff1d(np.arange(M), tp_ann_ids, assume_unique=True)
+
+    return tp_det_ids, tp_ann_ids, fp_det_ids, fn_ann_ids
+
+
+def evaluate_synapse_detections(pred, gt):
+    fname = os.path.basename(gt)
+
+    pred = pd.read_csv(pred, sep="\t")[["z", "y", "x"]].values
+    gt = pd.read_csv(gt, sep="\t")[["z", "y", "x"]].values
+    tps_pred, tps_gt, fps, fns = match_detections(pred, gt, max_dist=3)
+
+    return pd.DataFrame({
+        "name": [fname], "tp": [len(tps_pred)], "fp": [len(fps)], "fn": [len(fns)],
+    })
+
+
+def run_evaluation(pred_files, gt_files):
+    results = []
+    for pred, gt in zip(pred_files, gt_files):
+        res = evaluate_synapse_detections(pred, gt)
+        results.append(res)
+    results = pd.concat(results)
+
+    tp = results.tp.sum()
+    fp = results.fp.sum()
+    fn = results.fn.sum()
+
+    precision = tp / (tp + fp)
+    recall = tp / (tp + fn)
+    f1_score = 2 * precision * recall / (precision + recall)
+
+    print("All results:")
+    print(results)
+    print("Evaluation:")
+    print("Precision:", precision)
+    print("Recall:", recall)
+    print("F1-Score:", f1_score)
+
+
+def visualize_synapse_detections(pred, gt, heatmap_path=None, ctbp2_path=None):
+    import napari
+
+    fname = os.path.basename(gt)
+
+    pred = pd.read_csv(pred, sep="\t")[["z", "y", "x"]].values
+    gt = pd.read_csv(gt, sep="\t")[["z", "y", "x"]].values
+    tps_pred, tps_gt, fps, fns = match_detections(pred, gt, max_dist=5)
+
+    tps = pred[tps_pred]
+    fps = pred[fps]
+    fns = gt[fns]
+
+    if heatmap_path is None:
+        heatmap = None
+    else:
+        heatmap = open_file(heatmap_path)["prediction"][:]
+
+    if ctbp2_path is None:
+        ctbp2 = None
+    else:
+        ctbp2 = open_file(ctbp2_path)["raw"][:]
+
+    v = napari.Viewer()
+    if ctbp2 is not None:
+        v.add_image(ctbp2)
+    if heatmap is not None:
+        v.add_image(heatmap)
+    v.add_points(pred, visible=False)
+    v.add_points(gt, visible=False)
+    v.add_points(tps, name="TPs", face_color="green")
+    v.add_points(fps, name="FPs", face_color="orange")
+    v.add_points(fns, name="FNs", face_color="yellow")
+    v.title = f"{fname}: tps={len(tps)}, fps={len(fps)}, fns={len(fns)}"
+    napari.run()
+
+
+def visualize_evaluation(pred_files, gt_files, ctbp2_files):
+    for pred, gt, ctbp2 in zip(pred_files, gt_files, ctbp2_files):
+        pred_folder = os.path.split(pred)[0]
+        heatmap = os.path.join(pred_folder, "predictions.zarr")
+        visualize_synapse_detections(pred, gt, heatmap, ctbp2)
+
+
+def main():
+    import argparse
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--visualize", action="store_true")
+    args = parser.parse_args()
+
+    pred_files = [
+        "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/AnnotatedImageCrops/SynapseValidation/m226l_midp330_vglut3-ctbp2/filtered_synapse_detection.tsv",  # noqa
+    ]
+    gt_files = [
+        "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/training_data/synapses/test_data/v3/labels/m226l_midp330_vglut3-ctbp2_filtered.tsv",  # noqa
+    ]
+    ctbp2_files = [
+        "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/training_data/synapses/test_data/v3/images/m226l_midp330_vglut3-ctbp2.zarr",  # noqa
+    ]
+
+    if args.visualize:
+        visualize_evaluation(pred_files, gt_files, ctbp2_files)
+    else:
+        run_evaluation(pred_files, gt_files)
+
+
+if __name__ == "__main__":
+    main()
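
As a sanity check for the matching and scoring logic above, the following minimal sketch shows how the index arrays returned by match_detections translate into the precision/recall/F1 numbers reported by run_evaluation. The toy coordinates are made up for illustration, and match_detections is assumed to be available in the current session (e.g. by running the snippet inside this script); it is not part of the evaluation script itself.

import numpy as np

# Three detections vs. three annotations; two pairs lie within max_dist = 3 voxels.
detections = np.array([[10, 10, 10], [20, 20, 20], [40, 40, 40]], dtype=float)
annotations = np.array([[10, 11, 10], [21, 20, 20], [80, 80, 80]], dtype=float)

tp_det, tp_ann, fp_det, fn_ann = match_detections(detections, annotations, max_dist=3)

tp, fp, fn = len(tp_det), len(fp_det), len(fn_ann)  # here: 2, 1, 1
precision = tp / (tp + fp)                          # 2 / 3
recall = tp / (tp + fn)                             # 2 / 3
f1_score = 2 * precision * recall / (precision + recall)  # 2 / 3
print(precision, recall, f1_score)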