From d748003bfd9a002f092090121ad116dedc583322 Mon Sep 17 00:00:00 2001 From: Constantin Pape Date: Tue, 2 Sep 2025 18:26:49 +0200 Subject: [PATCH 01/13] Add intensity masking script --- scripts/intensity_masking.py | 50 ++++++++++++++++++++++++++++++++++++ 1 file changed, 50 insertions(+) create mode 100644 scripts/intensity_masking.py diff --git a/scripts/intensity_masking.py b/scripts/intensity_masking.py new file mode 100644 index 0000000..4cf12dc --- /dev/null +++ b/scripts/intensity_masking.py @@ -0,0 +1,50 @@ +import argparse + +import imageio.v3 as imageio +import numpy as np +from scipy.ndimage import binary_dilation, binary_closing, distance_transform_edt + + +def intensity_masking(image_path, seg_path, out_path, modulation_strength=10, dilation=2, view=False): + seg = imageio.imread(seg_path) + mask = binary_dilation(seg != 0, iterations=2) + mask = binary_closing(mask, iterations=4) + + image = imageio.imread(image_path) + lo, hi = np.percentile(image, 2), np.percentile(image, 98) + print(lo, hi) + image_modulated = np.clip(image, lo, hi).astype("float32") + image_modulated -= lo + image_modulated /= image_modulated.max() + + modulation_mask = distance_transform_edt(~mask) + modulation_mask /= modulation_mask.max() + modulation_mask = 1 - modulation_mask + modulation_mask[mask] = 1 + modulation_mask = np.pow(modulation_mask, 3) + modulation_mask *= modulation_strength + image_modulated *= modulation_mask + + if view: + import napari + v = napari.Viewer() + v.add_image(modulation_mask) + v.add_image(image, visible=False) + v.add_image(image_modulated) + v.add_labels(mask, visible=False) + napari.run() + return + imageio.imwrite(out_path, image_modulated, compression="zlib") + + +# image_path = "M_LR_000227_R/scale3/PV.tif" +# seg_path = "M_LR_000227_R/scale3/SGN_v2.tif" +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("image_path") + parser.add_argument("seg_path") + parser.add_argument("out_path") + parser.add_argument("--view", "-v", action="store_true") + parser.add_argument("--dilation", type=int, default=2) + args = parser.parse_args() + intensity_masking(args.image_path, args.seg_path, args.out_path, view=args.view, dilation=args.dilation) From de131814342837ca2fd115d9bba49f8f6e8e3906 Mon Sep 17 00:00:00 2001 From: Constantin Pape Date: Wed, 3 Sep 2025 10:17:52 +0200 Subject: [PATCH 02/13] Add stardist postprocessing script --- scripts/la-vision/postprocess_stardist.py | 29 +++++++++++++++++++++++ 1 file changed, 29 insertions(+) create mode 100644 scripts/la-vision/postprocess_stardist.py diff --git a/scripts/la-vision/postprocess_stardist.py b/scripts/la-vision/postprocess_stardist.py new file mode 100644 index 0000000..34eeab1 --- /dev/null +++ b/scripts/la-vision/postprocess_stardist.py @@ -0,0 +1,29 @@ +import os +from glob import glob + +from tifffile import imread, imwrite +from csbdeep.utils import normalize +from stardist.models import StarDist3D + +model = StarDist3D.from_pretrained("3D_demo") + + +def segment_with_stardist(ff, out): + axis_norm = (0, 1, 2) # normalize channels independently + x = imread(ff) + img = normalize(x[0], 1, 99.8, axis=axis_norm) + labels, details = model.predict_instances(img) + imwrite(out, labels) + + +def main(): + files = glob("predictions/sgn-new/*.tif") + out_folder = "./predictions/stardist" + os.makedirs(out_folder, exist_ok=True) + for ff in files: + out = imread(ff) + segment_with_stardist(ff, out) + + +if __name__ == "__main__": + main() From 2d929c4b3b02eac152dfe8835d7d93fdd416e40f Mon 
Sep 17 00:00:00 2001
From: Martin Schilling
Date: Wed, 3 Sep 2025 12:48:42 +0200
Subject: [PATCH 03/13] Fixed output path

---
 scripts/la-vision/postprocess_stardist.py | 16 +++++++++++-----
 1 file changed, 11 insertions(+), 5 deletions(-)

diff --git a/scripts/la-vision/postprocess_stardist.py b/scripts/la-vision/postprocess_stardist.py
index 34eeab1..6a75e61 100644
--- a/scripts/la-vision/postprocess_stardist.py
+++ b/scripts/la-vision/postprocess_stardist.py
@@ -17,12 +17,18 @@ def segment_with_stardist(ff, out):
 
 
 def main():
-    files = glob("predictions/sgn-new/*.tif")
-    out_folder = "./predictions/stardist"
+    direc = "/mnt/vast-nhr/projects/nim00007/data/moser/predictions/sgn-new"
+    out_folder = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/predictions/stardist"
+
+    files = [entry.path for entry in os.scandir(direc) if ".tif" in entry.name]
+    files.sort()
+    file_names = [entry.name.split(".tif")[0] for entry in os.scandir(direc) if ".tif" in entry.name]
+    file_names.sort()
 
     os.makedirs(out_folder, exist_ok=True)
-    for ff in files:
-        out = imread(ff)
-        segment_with_stardist(ff, out)
+    for f_path, f_name in zip(files, file_names):
+        out = os.path.join(out_folder, f"{f_name}_seg.tif")
+        segment_with_stardist(f_path, out)
 
 
 if __name__ == "__main__":

From cab37c0eeba529e0bbe67be5d4049a3fcdcc69c9 Mon Sep 17 00:00:00 2001
From: Constantin Pape
Date: Mon, 8 Sep 2025 22:23:40 +0200
Subject: [PATCH 04/13] Implement splitting of non-convex objects

---
 flamingo_tools/segmentation/postprocessing.py | 119 +++++++++++++++-
 scripts/la-vision/fix_segmentation.py         | 129 ++++++++++++++++++
 2 files changed, 247 insertions(+), 1 deletion(-)
 create mode 100644 scripts/la-vision/fix_segmentation.py

diff --git a/flamingo_tools/segmentation/postprocessing.py b/flamingo_tools/segmentation/postprocessing.py
index 07e5913..edf8b5a 100644
--- a/flamingo_tools/segmentation/postprocessing.py
+++ b/flamingo_tools/segmentation/postprocessing.py
@@ -1,7 +1,8 @@
 import math
 import multiprocessing as mp
+import threading
 from concurrent import futures
-from typing import Callable, List, Optional, Tuple
+from typing import Callable, Dict, List, Optional, Sequence, Tuple, Union
 
 import elf.parallel as parallel
 import numpy as np
@@ -15,6 +16,9 @@
 from scipy.spatial import distance
 from scipy.spatial import cKDTree, ConvexHull
 from skimage import measure
+from skimage.filters import gaussian
+from skimage.feature import peak_local_max
+from skimage.segmentation import find_boundaries, watershed
 from sklearn.neighbors import NearestNeighbors
 from tqdm import tqdm
 
@@ -734,3 +738,116 @@ def filter_cochlea_volume(
     combined_dilated[combined_dilated > 0] = 1
 
     return combined_dilated
+
+
+def split_nonconvex_objects(
+    segmentation: np.typing.ArrayLike,
+    output: np.typing.ArrayLike,
+    segmentation_table: pd.DataFrame,
+    min_size: int,
+    resolution: Union[float, Sequence[float]],
+    height_map: Optional[np.typing.ArrayLike] = None,
+    component_labels: Optional[List[int]] = None,
+    n_threads: Optional[int] = None,
+) -> Dict[int, List[int]]:
+    """Split non-convex objects into multiple parts in place.
+
+    Args:
+        segmentation: The input segmentation.
+        output: The output array the split segmentation is written to.
+        segmentation_table: The segmentation table in MoBIE format, with object sizes and bounding boxes.
+        min_size: The minimal object size; smaller objects are not split and split fragments below it are re-assigned.
+        resolution: The resolution of the segmentation, either isotropic (float) or given per axis.
+        height_map: Optional height map for the watershed; by default the inverted seed distances are used.
+        component_labels: Optional component labels to restrict the objects that are split.
+        n_threads: The number of threads for parallelization; defaults to the CPU count.
+    """
+    if isinstance(resolution, float):
+        resolution = [resolution] * 3
+    assert len(resolution) == 3
+    resolution = np.array(resolution)
+
+    lock = threading.Lock()
+    offset = len(segmentation_table)
+
+    def split_object(object_id):
+        nonlocal offset
+
+        row = segmentation_table[segmentation_table.label_id == object_id]
+        if min_size and row.n_pixels.values[0] < min_size:
+            return [object_id]
+
+        bb_min = np.array([
+            row.bb_min_z.values[0], row.bb_min_y.values[0], row.bb_min_x.values[0],
+        ]) / resolution
+        bb_max = np.array([
+            row.bb_max_z.values[0], row.bb_max_y.values[0], row.bb_max_x.values[0],
+        ]) / resolution
+
+        bb_min = np.maximum(bb_min.astype(int) - 1, np.array([0, 0, 0]))
+        bb_max = np.minimum(bb_max.astype(int) + 1, np.array(list(segmentation.shape)))
+        bb = tuple(slice(mi, ma) for mi, ma in zip(bb_min, bb_max))
+
+        seg = segmentation[bb]
+        mask = ~find_boundaries(seg)
+        dist = distance_transform_edt(mask, sampling=resolution)
+
+        seg_mask = seg == object_id
+        dist[~seg_mask] = 0
+        dist = gaussian(dist, (0.6, 1.2, 1.2))
+        maxima = peak_local_max(dist, min_distance=3, exclude_border=True)
+
+        if len(maxima) == 1:
+            return [object_id]
+
+        with lock:
+            old_offset = offset
+            offset += len(maxima)
+
+        seeds = np.zeros(seg.shape, dtype=int)
+        for i, pos in enumerate(maxima, 1):
+            seeds[tuple(pos)] = old_offset + i
+
+        if height_map is None:
+            hmap = dist.max() - dist
+        else:
+            hmap = height_map[bb]
+        new_seg = watershed(hmap, markers=seeds, mask=seg_mask)
+
+        seg_ids, sizes = np.unique(new_seg, return_counts=True)
+        seg_ids, sizes = seg_ids[1:], sizes[1:]
+
+        keep_ids = seg_ids[sizes > min_size]
+        if len(keep_ids) < 2:
+            return [object_id]
+
+        elif len(keep_ids) != len(seg_ids):
+            new_seg[~np.isin(new_seg, keep_ids)] = 0
+            new_seg = watershed(hmap, markers=new_seg, mask=seg_mask)
+
+        output[bb][seg_mask] = new_seg[seg_mask]
+        return seg_ids.tolist()
+
+        # import napari
+        # v = napari.Viewer()
+        # v.add_image(hmap)
+        # v.add_labels(seg)
+        # v.add_labels(new_seg)
+        # v.add_points(maxima)
+        # napari.run()
+
+    if component_labels is None:
+        object_ids = segmentation_table.label_id.values
+    else:
+        object_ids = segmentation_table[segmentation_table.isin(component_labels)].label_id.values
+
+    if n_threads is None:
+        n_threads = mp.cpu_count()
+
+    with futures.ThreadPoolExecutor(n_threads) as tp:
+        new_id_mapping = list(
+            tqdm(tp.map(split_object, object_ids), total=len(object_ids), desc="Split non-convex objects")
+        )
+
+    new_id_mapping = {object_id: mapped_ids for object_id, mapped_ids in zip(object_ids, new_id_mapping)}
+    return new_id_mapping

diff --git a/scripts/la-vision/fix_segmentation.py b/scripts/la-vision/fix_segmentation.py
new file mode 100644
index 0000000..05bcb74
--- /dev/null
+++ b/scripts/la-vision/fix_segmentation.py
@@ -0,0 +1,129 @@
+import os
+from glob import glob
+
+import napari
+import numpy as np
+import imageio.v3 as imageio
+import vigra
+
+from skimage.filters import gaussian
+from skimage.segmentation import find_boundaries, watershed
+from scipy.ndimage import distance_transform_edt
+from skimage.feature import peak_local_max
+from skimage.measure import regionprops, label
+
+
+def _size_filter(segmentation, heightmap, min_size):
+    ids, sizes = np.unique(segmentation, return_counts=True)
+    discard_ids = ids[sizes < min_size]
+    mask = segmentation > 0
+    segmentation[np.isin(segmentation, discard_ids)] = 0
+    return 
watershed(heightmap, markers=segmentation, mask=mask) + + +def postproc(image, segmentation, view=False): + # First get rid of small objects. + min_size = 250 + heightmap = vigra.filters.laplacianOfGaussian(image.astype("float32"), 3) + + segmentation = _size_filter(segmentation, heightmap, min_size) + + mask = ~find_boundaries(segmentation) + dist = distance_transform_edt(mask, sampling=(2, 1, 1)) + dist[segmentation == 0] = 0 + dist = gaussian(dist, (0.6, 1.2, 1.2)) + maxima = peak_local_max(dist, min_distance=3, exclude_border=False) + + maxima_image = np.zeros(segmentation.shape, dtype="uint8") + pos = tuple(maxima[:, i] for i in range(3)) + maxima_image[pos] = 1 + maxima_image = label(maxima_image) + + def maxima_ids(seg, im): + ids = np.unique(im[seg]) + return ids[1:] + + seed_maxima_ids, keep_seg_ids, split_seg_ids = [], [], [] + props = regionprops(segmentation, maxima_image, extra_properties=[maxima_ids]) + for prop in props: + this_maxima_ids = prop.maxima_ids + if len(this_maxima_ids) == 1: + keep_seg_ids.append(prop.label) + continue + seed_maxima_ids.extend(this_maxima_ids.tolist()) + split_seg_ids.append(prop.label) + + split_mask = np.isin(segmentation, split_seg_ids) + # segmentation[split_mask] = 0 + + new_seeds = maxima_image.copy() + new_seeds[~np.isin(maxima_image, seed_maxima_ids)] = 0 + new_seg = watershed(heightmap, markers=new_seeds, mask=split_mask) + + segmentation[split_mask] = 0 + offset = segmentation.max() + new_seg[new_seg != 0] += offset + segmentation[split_mask] = new_seg[split_mask] + segmentation = label(segmentation) + segmentation = _size_filter(segmentation, heightmap, min_size) + + if view: + v = napari.Viewer() + v.add_image(image) + v.add_labels(segmentation) + # v.add_labels(new_seg) + # v.add_image(heightmap) + # v.add_image(dist) + # v.add_points(maxima) + # v.add_labels(split_mask) + napari.run() + + return segmentation + + +def postprocess_volume(im_path, seg_path, out_root): + image = imageio.imread(im_path) + segmentation = imageio.imread(seg_path) + segmentation = postproc(image, segmentation, view=True) + + os.makedirs(out_root, exist_ok=True) + fname = os.path.basename(im_path) + imageio.imwrite(os.path.join(out_root, fname), segmentation, compression="zlib") + + +def postprocess_volume_scalable(im_path, seg_path, out_root): + from flamingo_tools.segmentation.postprocessing import split_nonconvex_objects, compute_table_on_the_fly + + image = imageio.imread(im_path) + segmentation = imageio.imread(seg_path) + + # TODO aniso resolution + resolution = 0.38 + table = compute_table_on_the_fly(segmentation, resolution) + + out = np.zeros_like(segmentation) + id_mapping = split_nonconvex_objects(segmentation, out, table, n_threads=1, resolution=resolution, min_size=250) + n_prev = len(id_mapping) + n_after = sum([len(v) for v in id_mapping.values()]) + print("Before splitting:", n_prev) + print("After splitting:", n_after) + + v = napari.Viewer() + v.add_image(image) + v.add_labels(segmentation, visible=False) + v.add_labels(out) + napari.run() + + +def main(): + im_paths = sorted(glob("la-vision-sgn-new/images/*.tif")) + seg_paths = sorted(glob("la-vision-sgn-new/segmentation/*.tif")) + out_root = "la-vision-sgn-new/segmentation-postprocessed" + for im_path, seg_path in zip(im_paths, seg_paths): + # postprocess_volume(im_path, seg_path, out_root) + postprocess_volume_scalable(im_path, seg_path, out_root) + break + + +if __name__ == "__main__": + main() From 663d82ee427b0445d7c6961342ef70767deb86a9 Mon Sep 17 00:00:00 2001 From: Constantin Pape 
Date: Mon, 8 Sep 2025 22:26:02 +0200 Subject: [PATCH 05/13] Improve train-val splits --- flamingo_tools/training/util.py | 4 ++- scripts/training/train_distance_unet.py | 36 ++++++++++++++++++++----- 2 files changed, 32 insertions(+), 8 deletions(-) diff --git a/flamingo_tools/training/util.py b/flamingo_tools/training/util.py index 71f7a32..5c326d8 100644 --- a/flamingo_tools/training/util.py +++ b/flamingo_tools/training/util.py @@ -28,6 +28,7 @@ def get_supervised_loader( image_key: Optional[str] = None, label_key: Optional[str] = None, n_samples: Optional[int] = None, + raw_transform: Optional[callable] = None, ) -> DataLoader: """Get a data loader for a supervised segmentation task. @@ -39,6 +40,7 @@ def get_supervised_loader( image_key: Internal path for the image data. This is only required for hdf5/zarr/n5 data. image_key: Internal path for the label masks. This is only required for hdf5/zarr/n5 data. n_samples: The number of samples to use for training. + raw_transform: Optional transformation for the raw data. Returns: The data loader. @@ -52,6 +54,6 @@ def get_supervised_loader( loader = torch_em.default_segmentation_loader( raw_paths=image_paths, raw_key=image_key, label_paths=label_paths, label_key=label_key, batch_size=batch_size, patch_shape=patch_shape, label_transform=label_transform, - n_samples=n_samples, num_workers=4, shuffle=True, sampler=sampler + n_samples=n_samples, num_workers=4, shuffle=True, sampler=sampler, raw_transform=raw_transform, ) return loader diff --git a/scripts/training/train_distance_unet.py b/scripts/training/train_distance_unet.py index 178bc81..4a109f0 100644 --- a/scripts/training/train_distance_unet.py +++ b/scripts/training/train_distance_unet.py @@ -1,10 +1,12 @@ import argparse +import json import os from datetime import datetime from glob import glob import torch_em from flamingo_tools.training import get_supervised_loader, get_3d_model +from sklearn.model_selection import train_test_split ROOT_CLUSTER = "/scratch-grete/usr/nimcpape/data/moser/lightsheet/training" @@ -54,7 +56,7 @@ def get_image_and_label_paths_sep_folders(root): return image_paths, label_paths -def select_paths(image_paths, label_paths, split, filter_empty): +def select_paths(image_paths, label_paths, split, filter_empty, random_split=True): if filter_empty: image_paths = [imp for imp in image_paths if "empty" not in imp] label_paths = [imp for imp in label_paths if "empty" not in imp] @@ -64,10 +66,13 @@ def select_paths(image_paths, label_paths, split, filter_empty): train_fraction = 0.85 n_train = int(train_fraction * n_files) - if split == "train": + if split == "train" and random_split: + image_paths, _, label_paths, _ = train_test_split(image_paths, label_paths, train_size=n_train, random_state=42) + elif split == "train": image_paths = image_paths[:n_train] label_paths = label_paths[:n_train] - + elif split == "val" and random_split: + _, image_paths, _, label_paths = train_test_split(image_paths, label_paths, train_size=n_train, random_state=42) elif split == "val": image_paths = image_paths[n_train:] label_paths = label_paths[n_train:] @@ -90,7 +95,11 @@ def get_loader(root, split, patch_shape, batch_size, filter_empty, separate_fold elif split == "val": n_samples = 16 * batch_size - return get_supervised_loader(this_image_paths, this_label_paths, patch_shape, batch_size, n_samples=n_samples) + return ( + get_supervised_loader(this_image_paths, this_label_paths, patch_shape, batch_size, n_samples=n_samples), + this_image_paths, + this_label_paths + ) def main(): 
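# Sketch (not from the patch, paths illustrative): the random split introduced
# above is reproducible because train_test_split is seeded, so the separate
# "train" and "val" calls in select_paths partition the same file lists
# consistently.
from sklearn.model_selection import train_test_split

image_paths = [f"im_{i}.tif" for i in range(20)]
label_paths = [f"lab_{i}.tif" for i in range(20)]
n_train = int(0.85 * len(image_paths))

train_im, val_im, train_lab, val_lab = train_test_split(
    image_paths, label_paths, train_size=n_train, random_state=42
)
assert set(train_im).isdisjoint(val_im)  # train and val stay disjoint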
@@ -131,10 +140,10 @@ def main(): model = get_3d_model() # Create the training loader with train and val set. - train_loader = get_loader( + train_loader, train_images, train_labels = get_loader( root, "train", patch_shape, batch_size, filter_empty=filter_empty, separate_folders=args.separate_folders ) - val_loader = get_loader( + val_loader, val_images, val_labels = get_loader( root, "val", patch_shape, batch_size, filter_empty=filter_empty, separate_folders=args.separate_folders ) @@ -146,8 +155,21 @@ def main(): loss = torch_em.loss.distance_based.DiceBasedDistanceLoss(mask_distances_in_bg=True) - # Create the trainer. + # Serialize the train test split. name = f"cochlea_distance_unet_{run_name}" + ckpt_folder = os.path.join("checkpoints", name) + os.makedirs(ckpt_folder, exist_ok=True) + split_file = os.path.join(ckpt_folder, "split.json") + with open(split_file, "w") as f: + json.dump( + { + "train": {"images": train_images, "labels": train_labels}, + "val": {"images": val_images, "labels": val_labels}, + }, + f, sort_keys=True, indent=2 + ) + + # Create the trainer. trainer = torch_em.default_segmentation_trainer( name=name, model=model, From e698bb1f355a28a2da2be2936f03e8c4244a2006 Mon Sep 17 00:00:00 2001 From: Constantin Pape Date: Tue, 9 Sep 2025 17:35:51 +0200 Subject: [PATCH 06/13] Implement SGN detection --- .../la-vision/la_vision_point_annotations.py | 33 ++++++ scripts/la-vision/train_sgn_detection.py | 111 ++++++++++++++++++ .../detection_dataset.py | 36 +++++- 3 files changed, 174 insertions(+), 6 deletions(-) create mode 100644 scripts/la-vision/la_vision_point_annotations.py create mode 100644 scripts/la-vision/train_sgn_detection.py diff --git a/scripts/la-vision/la_vision_point_annotations.py b/scripts/la-vision/la_vision_point_annotations.py new file mode 100644 index 0000000..4a89e3c --- /dev/null +++ b/scripts/la-vision/la_vision_point_annotations.py @@ -0,0 +1,33 @@ +import os +from glob import glob + +import imageio.v3 as imageio +import napari +import numpy as np +from skimage.measure import regionprops + + +def main(): + image_files = sorted(glob("la-vision-sgn-new/images/*.tif")) + label_files = sorted(glob("la-vision-sgn-new/segmentation-postprocessed/*.tif")) + + for imf, lf in zip(image_files, label_files): + im = imageio.imread(imf) + labels = imageio.imread(lf) + + props = regionprops(labels) + centers = np.array([prop.centroid for prop in props]) + + name = os.path.basename(imf) + print(name) + + v = napari.Viewer() + v.add_image(im) + v.add_labels(labels) + v.add_points(centers, size=5, out_of_slice_display=True) + v.title = name + napari.run() + + +if __name__ == "__main__": + main() diff --git a/scripts/la-vision/train_sgn_detection.py b/scripts/la-vision/train_sgn_detection.py new file mode 100644 index 0000000..79bf4f8 --- /dev/null +++ b/scripts/la-vision/train_sgn_detection.py @@ -0,0 +1,111 @@ +import os +import sys +import json +from glob import glob + +from sklearn.model_selection import train_test_split + +sys.path.append("/home/pape/Work/my_projects/czii-protein-challenge") +sys.path.append("/user/pape41/u12086/Work/my_projects/czii-protein-challenge") +sys.path.append("../synapse_marker_detection") + +from utils.training.training import supervised_training # noqa +from detection_dataset import DetectionDataset, MinPointSampler # noqa + +ROOT = "./la-vision-sgn-new" # noqa + +TRAIN = os.path.join(ROOT, "images") +TRAIN_EMPTY = os.path.join(ROOT, "empty_images") + +LABEL = os.path.join(ROOT, "centroids") +LABEL_EMPTY = os.path.join(ROOT, 
"empty_centroids") + + +def _get_paths(split, train_folder, label_folder, n=None): + image_paths = sorted(glob(os.path.join(train_folder, "*.tif"))) + label_paths = sorted(glob(os.path.join(label_folder, "*.csv"))) + assert len(image_paths) == len(label_paths) + if n is not None: + image_paths, label_paths = image_paths[:n], label_paths[:n] + + train_images, val_images, train_labels, val_labels = train_test_split( + image_paths, label_paths, test_size=1, random_state=42 + ) + + if split == "train": + image_paths = train_images + label_paths = train_labels + else: + image_paths = val_images + label_paths = val_labels + + return image_paths, label_paths + + +def get_paths(split): + image_paths, label_paths = _get_paths(split, TRAIN, LABEL) + empty_image_paths, empty_label_paths = _get_paths(split, TRAIN_EMPTY, LABEL_EMPTY, n=4) + return image_paths + empty_image_paths, label_paths + empty_label_paths + + +def train(): + + model_name = "sgn-low-res-detection-v1" + + train_paths, train_label_paths = get_paths("train") + val_paths, val_label_paths = get_paths("val") + # We need to give the paths for the test loader, although it's never used. + test_paths, test_label_paths = val_paths, val_label_paths + + print("Start training with:") + print(len(train_paths), "tomograms for training") + print(len(val_paths), "tomograms for validation") + + patch_shape = [48, 256, 256] + batch_size = 8 + check = False + + checkpoint_path = f"./checkpoints/{model_name}" + os.makedirs(checkpoint_path, exist_ok=True) + with open(os.path.join(checkpoint_path, "splits.json"), "w") as f: + json.dump( + { + "train": {"images": train_paths, "labels": train_label_paths}, + "val": {"images": val_paths, "labels": val_label_paths}, + }, + f, indent=2, sort_keys=True + ) + + supervised_training( + name=model_name, + train_paths=train_paths, + train_label_paths=train_label_paths, + val_paths=val_paths, + val_label_paths=val_label_paths, + raw_key=None, + patch_shape=patch_shape, batch_size=batch_size, + check=check, + lr=1e-4, + n_iterations=int(1e5), + out_channels=1, + augmentations=None, + eps=1e-5, + sigma=4, + lower_bound=None, + upper_bound=None, + test_paths=test_paths, + test_label_paths=test_label_paths, + # save_root="", + dataset_class=DetectionDataset, + n_samples_train=3200, + n_samples_val=160, + sampler=MinPointSampler(min_points=1, p_reject=0.5), + ) + + +def main(): + train() + + +if __name__ == "__main__": + main() diff --git a/scripts/synapse_marker_detection/detection_dataset.py b/scripts/synapse_marker_detection/detection_dataset.py index ae9b361..ab4fca7 100644 --- a/scripts/synapse_marker_detection/detection_dataset.py +++ b/scripts/synapse_marker_detection/detection_dataset.py @@ -1,3 +1,4 @@ +import imageio.v3 as imageio import numpy as np import pandas as pd import torch @@ -38,7 +39,6 @@ def __call__(self, x: np.ndarray, n_points: int) -> bool: def load_labels(label_path, shape, bb): points = pd.read_csv(label_path) - assert len(points.columns) == len(shape) z_coords, y_coords, x_coords = points["axis-0"].values, points["axis-1"].values, points["axis-2"].values if bb is not None: @@ -85,6 +85,25 @@ def process_labels(coords, shape, sigma, eps, bb=None): return labels +def process_labels_hacky(coords, shape, sigma, eps, bb=None): + + if bb: + (z_min, z_max), (y_min, y_max), (x_min, x_max) = [(s.start, s.stop) for s in bb] + restricted_shape = (z_max - z_min, y_max - y_min, x_max - x_min) + labels = np.zeros(restricted_shape, dtype="float32") + shape = restricted_shape + else: + labels = 
np.zeros(shape, dtype="float32")
+
+    labels[coords] = 1
+    labels = gaussian(labels, sigma)
+    labels = labels.clip(0, 0.0075)
+    labels /= (labels.max() + 1e-7)
+    labels *= 4
+    labels = labels.clip(0, 1)
+    return labels
+
+
 class DetectionDataset(torch.utils.data.Dataset):
     max_sampling_attempts = 500
 
@@ -132,8 +151,8 @@ def __init__(
         self.eps = eps
         self.sigma = sigma
 
-        with zarr.open(self.raw_path, "r") as f:
-            self.shape = f[self.raw_key].shape
+        self.raw = imageio.imread(self.raw_path) if raw_key is None else zarr.open(self.raw_path, "r")[raw_key][:]
+        self.shape = self.raw.shape
 
         if n_samples is None:
             self._len = self.compute_len(self.shape, self.patch_shape) if n_samples is None else n_samples
@@ -159,9 +178,8 @@ def _sample_bounding_box(self, shape):
         return tuple(slice(start, start + psh) for start, psh in zip(bb_start, self.patch_shape))
 
     def _get_sample(self, index):
-        raw, label_path = self.raw_path, self.label_path
+        raw, label_path = self.raw, self.label_path
 
-        raw = zarr.open(raw)[self.raw_key]
         have_raw_channels = raw.ndim == 4  # 3D with channels
         shape = raw.shape
 
@@ -187,7 +205,13 @@
         if sample_id > self.max_sampling_attempts:
             raise RuntimeError(f"Could not sample a valid batch in {self.max_sampling_attempts} attempts")
 
-        label = process_labels(coords, shape, self.sigma, self.eps, bb=bb)
+        # For synapse detection.
+        # label = process_labels(coords, shape, self.sigma, self.eps, bb=bb)
+
+        # For SGN detection with data-specific hacks
+        label = process_labels_hacky(coords, shape, self.sigma, self.eps, bb=bb)
+        gap = 6
+        raw_patch, label = raw_patch[gap:-gap], label[gap:-gap]
 
         have_label_channels = label.ndim == 4
         if have_label_channels:

From 5d6955e6f32780f32e8391714e9bb97d6e17cd89 Mon Sep 17 00:00:00 2001
From: Constantin Pape
Date: Thu, 11 Sep 2025 10:15:39 +0200
Subject: [PATCH 07/13] Implement la-vision WS prototype

---
 scripts/la-vision/watershed_prototype.py | 70 ++++++++++++++++++++++++
 1 file changed, 70 insertions(+)
 create mode 100644 scripts/la-vision/watershed_prototype.py

diff --git a/scripts/la-vision/watershed_prototype.py b/scripts/la-vision/watershed_prototype.py
new file mode 100644
index 0000000..2d19e49
--- /dev/null
+++ b/scripts/la-vision/watershed_prototype.py
@@ -0,0 +1,70 @@
+import os
+
+import imageio.v3 as imageio
+import napari
+import numpy as np
+import pandas as pd
+
+from scipy.ndimage import distance_transform_edt
+from skimage.measure import label
+from skimage.segmentation import watershed
+
+
+def simple_watershed(im, det, radius=8):
+    """Use a simple watershed to create spheres.
+    """
+
+    # Compute the distance to the detections.
+    seeds = np.zeros(im.shape, dtype="uint8")
+    det_idx = tuple(det[ax].values for ax in ["axis-0", "axis-1", "axis-2"])
+    seeds[det_idx] = 1
+    distances = distance_transform_edt(seeds == 0, sampling=(3.0, 1.887779, 1.887779))
+    seeds = label(seeds)
+
+    mask = distances < radius
+    return watershed(distances, seeds, mask=mask), distances, seeds
+
+
+def complex_watershed(im, det, pred, radius=8):
+    """More complex watershed in combination with network predictions.
+
+    WIP: this does not work well yet.
+    """
+    fg_pred = pred[0]
+    # bd_pred = pred[2]
+
+    _, distances, seeds = simple_watershed(im, det, radius=radius)
+
+    # Ensure everything within 8 microns of a center is foreground.
+    fg = np.logical_or(fg_pred > 0.5, distances < radius)
+
+    # TODO find a good hmap!
+    hmap = distances
+
+    # Watershed.
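#    Note on the step below (sketch): skimage's `compactness` argument switches
#    to the compact watershed of Neubert & Protzel, which penalizes flooding far
#    away from each seed. With `hmap = distances` this biases the basins towards
#    round, similarly sized regions, i.e. the sphere-like SGN shapes targeted here.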
+ seg = watershed(hmap, markers=seeds, mask=fg, compactness=5) + return seg, distances, seeds + + +def main(): + root = "la-vision-sgn-new/detections-v1" + im = imageio.imread(os.path.join(root, "LaVision-M04_crop_2580-2266-0533_PV.tif")) + det = pd.read_csv(os.path.join(root, "LaVision-M04_crop_2580-2266-0533_PV.csv")) + # pred = imageio.imread(os.path.join(root, "LaVision-M04_crop_2580-2266-0533_PRED.tif")) + + seg, distances, seeds = simple_watershed(im, det, radius=12) + # This does not yet work well. + # seg, distances, seeds = complex_watershed(im, det, pred) + + v = napari.Viewer() + v.add_image(im) + v.add_image(distances, visible=False) + v.add_labels(seeds, visible=False) + # v.add_image(pred, visible=False) + v.add_points(det, visible=False) + v.add_labels(seg) + napari.run() + + +if __name__ == "__main__": + main() From 66ed39bd60ecbd67850b033fe5bc3935d2df9d37 Mon Sep 17 00:00:00 2001 From: Constantin Pape Date: Thu, 11 Sep 2025 10:19:27 +0200 Subject: [PATCH 08/13] Update to sgn detection training --- flamingo_tools/segmentation/postprocessing.py | 26 ++++++++++++++++--- scripts/la-vision/train_sgn_detection.py | 4 ++- .../detection_dataset.py | 8 +++--- 3 files changed, 29 insertions(+), 9 deletions(-) diff --git a/flamingo_tools/segmentation/postprocessing.py b/flamingo_tools/segmentation/postprocessing.py index edf8b5a..b363075 100644 --- a/flamingo_tools/segmentation/postprocessing.py +++ b/flamingo_tools/segmentation/postprocessing.py @@ -774,7 +774,8 @@ def split_object(object_id): nonlocal offset row = segmentation_table[segmentation_table.label_id == object_id] - if min_size and row.n_pixels.values[0] < min_size: + if row.n_pixels.values[0] < min_size: + # print(object_id, ": min-size") return [object_id] bb_min = np.array([ @@ -788,6 +789,12 @@ def split_object(object_id): bb_max = np.minimum(bb_max.astype(int) + 1, np.array(list(segmentation.shape))) bb = tuple(slice(mi, ma) for mi, ma in zip(bb_min, bb_max)) + # This is due to segmentation artifacts. 
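#    Sketch of why the guard below matters: a merged segmentation artifact with a
#    bounding box of 500 voxels per axis would require a distance transform over
#    roughly 500**3 = 1.25e8 voxels for a single object, so such objects are
#    returned unsplit. At the 0.38 micron resolution used in the scripts above,
#    500 voxels correspond to about 190 micron.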
+ bb_shape = bb_max - bb_min + if (bb_shape > 500).any(): + print(object_id, "has a too large shape:", bb_shape) + return [object_id] + seg = segmentation[bb] mask = ~find_boundaries(seg) dist = distance_transform_edt(mask, sampling=resolution) @@ -798,6 +805,7 @@ def split_object(object_id): maxima = peak_local_max(dist, min_distance=3, exclude_border=True) if len(maxima) == 1: + # print(object_id, ": max len") return [object_id] with lock: @@ -819,14 +827,20 @@ def split_object(object_id): keep_ids = seg_ids[sizes > min_size] if len(keep_ids) < 2: + # print(object_id, ": keep-id") return [object_id] elif len(keep_ids) != len(seg_ids): new_seg[~np.isin(new_seg, keep_ids)] = 0 new_seg = watershed(hmap, markers=new_seg, mask=seg_mask) - output[bb][seg_mask] = new_seg[seg_mask] - return seg_ids.tolist() + with lock: + out = output[bb] + out[seg_mask] = new_seg[seg_mask] + output[bb] = out + + # print(object_id, ":", len(keep_ids)) + return keep_ids.tolist() # import napari # v = napari.Viewer() @@ -839,11 +853,15 @@ def split_object(object_id): if component_labels is None: object_ids = segmentation_table.label_id.values else: - object_ids = segmentation_table[segmentation_table.isin(component_labels)].label_id.values + object_ids = segmentation_table[segmentation_table.component_labels.isin(component_labels)].label_id.values if n_threads is None: n_threads = mp.cpu_count() + # new_id_mapping = [] + # for object_id in tqdm(object_ids, desc="Split non-convex objects"): + # new_id_mapping.append(split_object(object_id)) + with futures.ThreadPoolExecutor(n_threads) as tp: new_id_mapping = list( tqdm(tp.map(split_object, object_ids), total=len(object_ids), desc="Split non-convex objects") diff --git a/scripts/la-vision/train_sgn_detection.py b/scripts/la-vision/train_sgn_detection.py index 79bf4f8..39075b7 100644 --- a/scripts/la-vision/train_sgn_detection.py +++ b/scripts/la-vision/train_sgn_detection.py @@ -12,7 +12,8 @@ from utils.training.training import supervised_training # noqa from detection_dataset import DetectionDataset, MinPointSampler # noqa -ROOT = "./la-vision-sgn-new" # noqa +# ROOT = "./la-vision-sgn-new" # noqa +ROOT = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/training_data/SGN/sgn-detection" TRAIN = os.path.join(ROOT, "images") TRAIN_EMPTY = os.path.join(ROOT, "empty_images") @@ -24,6 +25,7 @@ def _get_paths(split, train_folder, label_folder, n=None): image_paths = sorted(glob(os.path.join(train_folder, "*.tif"))) label_paths = sorted(glob(os.path.join(label_folder, "*.csv"))) + assert len(image_paths) > 0 assert len(image_paths) == len(label_paths) if n is not None: image_paths, label_paths = image_paths[:n], label_paths[:n] diff --git a/scripts/synapse_marker_detection/detection_dataset.py b/scripts/synapse_marker_detection/detection_dataset.py index ab4fca7..21b1efc 100644 --- a/scripts/synapse_marker_detection/detection_dataset.py +++ b/scripts/synapse_marker_detection/detection_dataset.py @@ -206,12 +206,12 @@ def _get_sample(self, index): raise RuntimeError(f"Could not sample a valid batch in {self.max_sampling_attempts} attempts") # For synapse detection. 
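# Sketch of the heatmap target that both label transforms above produce: point
# annotations are rasterized and blurred, so the network regresses a soft
# center-probability map. Shapes are illustrative; sigma=4 matches the training
# configuration in train_sgn_detection.py.
import numpy as np
from skimage.filters import gaussian

target = np.zeros((48, 128, 128), dtype="float32")
target[24, 64, 64] = 1  # a single annotated SGN center
target = gaussian(target, sigma=4)
target /= (target.max() + 1e-7)  # rescale the peak back to 1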
-        # label = process_labels(coords, shape, self.sigma, self.eps, bb=bb)
+        label = process_labels(coords, shape, self.sigma, self.eps, bb=bb)
 
         # For SGN detection with data-specific hacks
-        label = process_labels_hacky(coords, shape, self.sigma, self.eps, bb=bb)
-        gap = 6
-        raw_patch, label = raw_patch[gap:-gap], label[gap:-gap]
+        # label = process_labels_hacky(coords, shape, self.sigma, self.eps, bb=bb)
+        # gap = 8
+        # raw_patch, label = raw_patch[gap:-gap], label[gap:-gap]
 
         have_label_channels = label.ndim == 4
         if have_label_channels:

From 7507d9146a288a7fb867c1ded6807a204e9526ad Mon Sep 17 00:00:00 2001
From: Constantin Pape
Date: Thu, 11 Sep 2025 21:21:35 +0200
Subject: [PATCH 09/13] Implement more debugging for detection model

---
 scripts/la-vision/check_detections.py | 61 ++++++++++++++++++++++++
 scripts/la-vision/debug_prediction.py | 62 +++++++++++++++++++++++++++
 scripts/la-vision/detect_blocks.py    | 34 +++++++++++++++
 scripts/la-vision/export_model.py     | 26 +++++++++++
 4 files changed, 183 insertions(+)
 create mode 100644 scripts/la-vision/check_detections.py
 create mode 100644 scripts/la-vision/debug_prediction.py
 create mode 100644 scripts/la-vision/detect_blocks.py
 create mode 100644 scripts/la-vision/export_model.py

diff --git a/scripts/la-vision/check_detections.py b/scripts/la-vision/check_detections.py
new file mode 100644
index 0000000..a7b3d44
--- /dev/null
+++ b/scripts/la-vision/check_detections.py
@@ -0,0 +1,61 @@
+import napari
+import zarr
+
+
+resolution = [3.0, 1.887779, 1.887779]
+positions = [
+    [2002.95539395823, 1899.9032205156411, 264.7747008147759]
+]
+
+
+def _load_from_mobie(bb):
+    path = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/mobie_project/cochlea-lightsheet/LaVision-M04/images/ome-zarr/PV.ome.zarr"
+    f = zarr.open(path, mode="r")
+    data = f["s0"][bb]
+    print(bb)
+
+    path = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/mobie_project/cochlea-lightsheet/LaVision-M04/images/ome-zarr/SGN_detect-v1.ome.zarr"
+    f = zarr.open(path, mode="r")
+    seg = f["s0"][bb]
+
+    return data, seg
+
+
+def _load_prediction(bb):
+    path = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/predictions/LaVision-M04/SGN_detect-v1/predictions.zarr"
+    f = zarr.open(path, mode="r")
+    data = f["prediction"][bb]
+    return data
+
+
+def _load_prediction_debug():
+    path = "./debug-pred/pred-v5.h5"
+    f = zarr.open(path, mode="r")  # zarr groups are not context managers
+    pred = f["pred"][:]
+    return pred
+
+
+def check_detection(position, halo=[32, 384, 384]):
+
+    bb = tuple(
+        slice(int(pos / re) - ha, int(pos / re) + ha) for pos, re, ha in zip(position[::-1], resolution, halo)
+    )
+
+    pv, detections_mobie = _load_from_mobie(bb)
+    # pred = _load_prediction(bb)
+    pred = _load_prediction_debug()
+
+    v = napari.Viewer()
+    v.add_image(pv)
+    v.add_image(pred)
+    v.add_labels(detections_mobie)
+    napari.run()
+
+
+def main():
+    position = positions[0]
+    check_detection(position)
+
+
+if __name__ == "__main__":
+    main()

diff --git a/scripts/la-vision/debug_prediction.py b/scripts/la-vision/debug_prediction.py
new file mode 100644
index 0000000..f84442b
--- /dev/null
+++ b/scripts/la-vision/debug_prediction.py
@@ -0,0 +1,62 @@
+import os
+from functools import partial
+
+import numpy as np
+import torch
+import zarr
+from torch_em.transform.raw import standardize
+from torch_em.util.prediction import predict_with_halo
+
+
+resolution = [3.0, 1.887779, 1.887779]
+positions = [
+    [2002.95539395823, 1899.9032205156411, 264.7747008147759]
+]
+
+
+def _load_from_mobie(bb):
+    path = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/mobie_project/cochlea-lightsheet/LaVision-M04/images/ome-zarr/PV.ome.zarr"
+    f = zarr.open(path, mode="r")
+    data = f["s0"][bb]
+    return data
+
+
+def run_prediction(position, halo=[32, 384, 384]):
+    bb = tuple(
+        slice(int(pos / re) - ha, int(pos / re) + ha) for pos, re, ha in zip(position[::-1], resolution, halo)
+    )
+    pv = _load_from_mobie(bb)
+    mean, std = np.mean(pv), np.std(pv)
+    print(mean, std)
+    preproc = partial(standardize, mean=mean, std=std)
+
+    block_shape = (24, 256, 256)
+    halo = (8, 64, 64)
+
+    model_path = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/trained_models/SGN/sgn-detection-v1.pt"
+    model = torch.load(model_path, weights_only=False)
+
+    def postproc(x):
+        x = np.clip(x, 0, 1)
+        max_ = np.percentile(x, 99)
+        x = x / max_
+        return x
+
+    pred = predict_with_halo(pv, model, [0], block_shape, halo, preprocess=preproc, postprocess=postproc).squeeze()
+
+    pred_name = "pred-v5"
+    out_folder = "./debug-pred"
+    os.makedirs(out_folder, exist_ok=True)
+
+    out_path = os.path.join(out_folder, f"{pred_name}.h5")
+    f = zarr.open(out_path, mode="w")  # zarr groups are not context managers
+    f.create_dataset("pred", data=pred)
+
+
+def main():
+    position = positions[0]
+    run_prediction(position)
+
+
+if __name__ == "__main__":
+    main()

diff --git a/scripts/la-vision/detect_blocks.py b/scripts/la-vision/detect_blocks.py
new file mode 100644
index 0000000..2e947cb
--- /dev/null
+++ b/scripts/la-vision/detect_blocks.py
@@ -0,0 +1,34 @@
+import os
+import imageio.v3 as imageio
+from pathlib import Path
+
+import pandas as pd
+import torch
+from skimage.feature import peak_local_max
+from torch_em.util.prediction import predict_with_halo
+
+ims = [
+    "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/training_data/SGN/sgn-detection/images/LaVision-M04_crop_2580-2266-0533_PV.tif",
+    "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/training_data/SGN/sgn-detection/empty_images/LaVision-M04_crop_0400-2500-0840_PV_empty.tif"
+]
+
+model_path = "checkpoints/sgn-detection.pt"
+model = torch.load(model_path, weights_only=False)
+
+block_shape = [24, 256, 256]
+halo = (8, 64, 64)
+
+out = "./detections-v1"
+os.makedirs(out, exist_ok=True)
+for im in ims:
+    data = imageio.imread(im)
+    pred = predict_with_halo(data, model, [0], block_shape, halo).squeeze()
+
+    coords = peak_local_max(pred, min_distance=4, threshold_abs=0.5)
+
+    # coords = np.concatenate([np.arange(0, len(coords))[:, None], coords], axis=1)
+    coords = pd.DataFrame(coords, columns=["axis-0", "axis-1", "axis-2"])
+
+    name = Path(im).stem
+    imageio.imwrite(os.path.join(out, f"{name}.tif"), pred)
+    coords.to_csv(os.path.join(out, f"{name}.csv"), index=False)

diff --git a/scripts/la-vision/export_model.py b/scripts/la-vision/export_model.py
new file mode 100644
index 0000000..edd6c53
--- /dev/null
+++ b/scripts/la-vision/export_model.py
@@ -0,0 +1,26 @@
+import argparse
+import sys
+
+import torch
+from torch_em.util import load_model
+
+sys.path.append("/home/pape/Work/my_projects/czii-protein-challenge")
+sys.path.append("/user/pape41/u12086/Work/my_projects/czii-protein-challenge")
+sys.path.append("../synapse_marker_detection")
+
+
+def export_model(input_, output):
+    model = load_model(input_, device="cpu")
+    torch.save(model, output)
+
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("-i", "--input", required=True)
+    parser.add_argument("-o", "--output", required=True)
+    args = parser.parse_args()
+    export_model(args.input, args.output)
+ + +if __name__ == "__main__": + main() From 8c68d190432de4203b9b560f055795deaf4b1723 Mon Sep 17 00:00:00 2001 From: Constantin Pape Date: Fri, 12 Sep 2025 08:11:59 +0200 Subject: [PATCH 10/13] Add otof import scripts --- scripts/la-vision/otof/import_cochlea.py | 46 +++++++++++++++++++++++ scripts/la-vision/otof/upload_otof_lsm.sh | 11 ++++++ 2 files changed, 57 insertions(+) create mode 100644 scripts/la-vision/otof/import_cochlea.py create mode 100755 scripts/la-vision/otof/upload_otof_lsm.sh diff --git a/scripts/la-vision/otof/import_cochlea.py b/scripts/la-vision/otof/import_cochlea.py new file mode 100644 index 0000000..a338d99 --- /dev/null +++ b/scripts/la-vision/otof/import_cochlea.py @@ -0,0 +1,46 @@ +import os +import h5py +from mobie import add_image +from mobie.metadata import read_dataset_metadata + +INPUT_PATH = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/LA_VISION_OTOF/Test_FreeRotate_0-40-59_PRO82_OtofKO-23R_p24_chCR-488_rbOtof-647_UltraII_C00_xyz.ims" # noqa +MOBIE_ROOT = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/mobie_project/cochlea-lightsheet" +DS_NAME = "LaVision-OTOF" +RESOLUTION = (3.0, 1.887779, 1.887779) + + +# Channels: "chCR-488_rbOtof-647" +def add_otof(): + channel_names = ("CR", "rbOtof") + channel_keys = [ + "/DataSet/ResolutionLevel 0/TimePoint 0/Channel 0/Data", + "/DataSet/ResolutionLevel 0/TimePoint 0/Channel 1/Data" + ] + + scale_factors = 4 * [[2, 2, 2]] + chunks = (96, 96, 96) + + for channel_key, channel_name in zip(channel_keys, channel_names): + mobie_ds_folder = os.path.join(MOBIE_ROOT, DS_NAME) + ds_metadata = read_dataset_metadata(mobie_ds_folder) + if channel_name in ds_metadata.get("sources", {}): + print(channel_name, "is already in MoBIE") + continue + + print("Load image data ...") + with h5py.File(INPUT_PATH, "r") as f: + input_data = f[channel_key][:] + print(input_data.shape) + add_image( + input_path=input_data, input_key=None, root=MOBIE_ROOT, + dataset_name=DS_NAME, image_name=channel_name, resolution=RESOLUTION, + scale_factors=scale_factors, chunks=chunks, unit="micrometer", use_memmap=False, + ) + + +def main(): + add_otof() + + +if __name__ == "__main__": + main() diff --git a/scripts/la-vision/otof/upload_otof_lsm.sh b/scripts/la-vision/otof/upload_otof_lsm.sh new file mode 100755 index 0000000..f71ca95 --- /dev/null +++ b/scripts/la-vision/otof/upload_otof_lsm.sh @@ -0,0 +1,11 @@ +#!/bin/bash + +MOBIE_DIR=/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/mobie_project/cochlea-lightsheet +COCHLEA=LaVision-OTOF + +export BUCKET_NAME="cochlea-lightsheet" +export SERVICE_ENDPOINT="https://s3.fs.gwdg.de" +mobie.add_remote_metadata -i $MOBIE_DIR -s $SERVICE_ENDPOINT -b $BUCKET_NAME + +# rclone --progress copyto "$MOBIE_DIR"/"$COCHLEA" cochlea-lightsheet:cochlea-lightsheet/"$COCHLEA" +# rclone --progress copyto "$MOBIE_DIR"/project.json cochlea-lightsheet:cochlea-lightsheet/project.json From 49fb49f4219e00b11c1823752b697952bd83c988 Mon Sep 17 00:00:00 2001 From: Constantin Pape Date: Tue, 16 Sep 2025 16:17:34 +0200 Subject: [PATCH 11/13] Update SGN Subtype analysis --- scripts/figures/util.py | 15 ++- scripts/measurements/sgn_subtypes.py | 176 +++++++++++++++++++-------- 2 files changed, 136 insertions(+), 55 deletions(-) diff --git a/scripts/figures/util.py b/scripts/figures/util.py index 489fa95..42e5532 100644 --- a/scripts/figures/util.py +++ b/scripts/figures/util.py @@ -24,7 +24,7 @@ def _get_mapping(animal): return bin_edges, bin_labels -def frequency_mapping(frequencies, 
values, animal="mouse", transduction_efficiency=False): +def frequency_mapping(frequencies, values, animal="mouse", transduction_efficiency=False, categorical=False): # Get the mapping of frequencies to octave bands for the given species. bin_edges, bin_labels = _get_mapping(animal) @@ -34,7 +34,18 @@ def frequency_mapping(frequencies, values, animal="mouse", transduction_efficien df["freq_khz"], bins=bin_edges, labels=bin_labels, right=False ) - if transduction_efficiency: # We compute the transduction efficiency per band. + if categorical: + assert not transduction_efficiency + categories = pd.unique(df.value) + num_tot = df.groupby("octave_band", observed=False).size() + value_by_band = {} + for cat in categories: + pos_cat = df[df.value == cat].groupby("octave_band", observed=False).size() + cat_by_band = (pos_cat / num_tot).reindex(bin_labels) + cat_by_band = cat_by_band.reset_index() + cat_by_band.columns = ["octave_band", "value"] + value_by_band[cat] = cat_by_band + elif transduction_efficiency: # We compute the transduction efficiency per band. num_pos = df[df["value"] == 1].groupby("octave_band", observed=False).size() num_tot = df[df["value"].isin([1, 2])].groupby("octave_band", observed=False).size() value_by_band = (num_pos / num_tot).reindex(bin_labels) diff --git a/scripts/measurements/sgn_subtypes.py b/scripts/measurements/sgn_subtypes.py index 9449d25..90c9dfc 100644 --- a/scripts/measurements/sgn_subtypes.py +++ b/scripts/measurements/sgn_subtypes.py @@ -1,5 +1,6 @@ import json import os +import sys from glob import glob from subprocess import run @@ -10,6 +11,8 @@ from flamingo_tools.s3_utils import BUCKET_NAME, create_s3_target, get_s3_path from flamingo_tools.measurements import compute_object_measures +sys.path.append("../figures") + # Map from cochlea names to channels COCHLEAE_FOR_SUBTYPES = { "M_LR_000099_L": ["PV", "Calb1", "Lypd1"], @@ -17,11 +20,16 @@ "M_AMD_N62_L": ["PV", "CR", "Calb1"], "M_AMD_N180_R": ["CR", "Ntng1", "CTBP2"], "M_AMD_N180_L": ["CR", "Ntng1", "Lypd1"], + "M_LR_000184_R": ["PV", "Prph"], + "M_LR_000184_L": ["PV", "Prph"], # Mutant / some stuff is weird. # "M_AMD_Runx1_L": ["PV", "Lypd1", "Calb1"], # This one still has to be stitched: # "M_LR_000184_R": {"PV", "Prph"}, } +REGULAR_COCHLEAE = [ + "M_LR_000099_L", "M_LR_000214_L", "M_AMD_N62_L", "M_LR_000184_R", "M_LR_000184_L" +] # Map from channels to subtypes. # Comment Aleyna: @@ -46,8 +54,31 @@ }, } +# For consistent colors. +ALL_COLORS = ["red", "blue", "orange", "yellow", "cyan", "magenta", "green", "purple"] +COLORS = {} + PLOT_OUT = "./subtype_plots" +# TODO: updates based on Aleyna's feedback. +# Subtype mapping + +# Combined visualization for the cochleae +# Can we visualize the tonotopy in subtypes and not stainings? +# It would also be good to have subtype percentages per cochlea and pooled together as a diagram and tonotopy? +# This would help to see if different staining gives same/similar results. +# Type Ia ; CR+ / Calb1- or Calb1- / Lypd1- +# Type Ib: CR+ / Calb1+ or Calb1+ / Lypd1+ +# Type Ic: CR-/Calb1+ - or Calb1- / Lypd1+ +# Type II: CR-/Calb1- or Calb1- / Lypd1- or Prph+ + +# > It's good to see that for the N mice the Ntng1C and Lypd1 separate from CR so well on the thresholds. Can I visualize these samples ones segmentation masks are done to verify the Ntng1C thresholds? As this is a quite clear signal I'm not sure if taking the middle of the histogram would be the best choice. +# The segmentations are in MoBIE already. 
I need to send you the tables for analyzing the signals. Will send them later. + +# > Where are we at PV-Prph segmentation results from MLR184_L and R for SGN type II analysis? This would hopefully give <5% Prph+ cells. +# The cochleae are in MoBIE. Segmentation and Prph signal look good! I will include it in the next analysis. +# Need tonotopic mapping from Martin and then compute the intensities. + def check_processing_status(): s3 = create_s3_target() @@ -63,6 +94,8 @@ def check_processing_status(): missing_tables = {} for cochlea, channels in COCHLEAE_FOR_SUBTYPES.items(): + if cochlea not in REGULAR_COCHLEAE: + continue try: content = s3.open(f"{BUCKET_NAME}/{cochlea}/dataset.json", mode="r", encoding="utf-8") except FileNotFoundError: @@ -125,6 +158,8 @@ def require_missing_tables(missing_tables): output_root = "./object_measurements" for cochlea, missing_tabs in missing_tables.items(): + if cochlea not in REGULAR_COCHLEAE: + continue for missing in missing_tabs: channel = missing.split("_")[0] seg_name = missing.split("_")[1].replace("-", "_") @@ -165,6 +200,8 @@ def compile_data_for_subtype_analysis(): os.makedirs(output_folder, exist_ok=True) for cochlea, channels in COCHLEAE_FOR_SUBTYPES.items(): + if cochlea not in REGULAR_COCHLEAE: + continue if "PV" in channels: reference_channel = "PV" seg_name = "PV_SGN_v2" @@ -217,14 +254,17 @@ def compile_data_for_subtype_analysis(): output_table.to_csv(out_path, sep="\t", index=False) -def _plot_histogram(table, column, name, show_plots, subtype=None): +def _plot_histogram(table, column, name, show_plots, class_names=None, apply_threshold=True): data = table[column].values threshold = threshold_otsu(data) fig, ax = plt.subplots(1) ax.hist(data, bins=24) - ax.axvline(x=threshold, color='red', linestyle='--') - ax.set_title(f"{name}\n threshold: {threshold}") + if apply_threshold: + ax.axvline(x=threshold, color='red', linestyle='--') + ax.set_title(f"{name}\n threshold: {threshold}") + else: + ax.set_title(name) if show_plots: plt.show() @@ -232,12 +272,14 @@ def _plot_histogram(table, column, name, show_plots, subtype=None): os.makedirs(PLOT_OUT, exist_ok=True) plt.savefig(f"{PLOT_OUT}/{name}.png") - if subtype is not None: - subtype_classification = [None if datum < threshold else subtype for datum in data] + if class_names is not None: + assert len(class_names) == 2 + c0, c1 = class_names + subtype_classification = [c0 if datum < threshold else c1 for datum in data] return subtype_classification -def _plot_2d(ratios, name, show_plots, classification=None): +def _plot_2d(ratios, name, show_plots, classification=None, colors=None): fig, ax = plt.subplots(1) assert len(ratios) == 2 keys = list(ratios.keys()) @@ -247,42 +289,16 @@ def _plot_2d(ratios, name, show_plots, classification=None): ax.scatter(ratios[k1, k2]) else: - def _combine(a, b): - if a is None and b is None: - return None - elif a is None and b is not None: - return b - elif a is not None and b is None: - return a - else: - return f"{a}-{b}" - - classification = [cls for cls in classification if cls is not None] - labels = classification[0].copy() - for cls in classification[1:]: - if cls is None: - continue - labels = [_combine(a, b) for a, b in zip(labels, cls)] - - unique_labels = set(ll for ll in labels if ll is not None) - all_colors = ["red", "blue", "orange", "yellow"] - colors = {ll: color for ll, color in zip(unique_labels, all_colors[:len(unique_labels)])} - + assert colors is not None + unique_labels = set(classification) for lbl in unique_labels: - mask = [ll == 
lbl for ll in labels] + mask = [ll == lbl for ll in classification] ax.scatter( - [ratios[k1][i] for i in range(len(labels)) if mask[i]], - [ratios[k2][i] for i in range(len(labels)) if mask[i]], + [ratios[k1][i] for i in range(len(classification)) if mask[i]], + [ratios[k2][i] for i in range(len(classification)) if mask[i]], c=colors[lbl], label=lbl ) - mask_none = [ll is None for ll in labels] - ax.scatter( - [ratios[k1][i] for i in range(len(labels)) if mask_none[i]], - [ratios[k2][i] for i in range(len(labels)) if mask_none[i]], - facecolors="none", edgecolors="black", label="None" - ) - ax.legend() ax.set_xlabel(k1) @@ -296,41 +312,90 @@ def _combine(a, b): plt.savefig(f"{PLOT_OUT}/{name}.png") -# TODO enable over-writing by manual thresholds -def analyze_subtype_data(show_plots=True): +def _plot_tonotopic_mapping(freq, classification, name, colors, show_plots): + from util import frequency_mapping + + frequency_mapped = frequency_mapping(freq, classification, categorical=True) + result = next(iter(frequency_mapped.values())) + bin_labels = pd.unique(result["octave_band"]) + band_to_x = {band: i for i, band in enumerate(bin_labels)} + x_positions = result["octave_band"].map(band_to_x) + + fig, ax = plt.subplots(figsize=(8, 4)) + for cat, vals in frequency_mapped.items(): + ax.scatter(x_positions, vals.value, label=cat, color=colors[cat]) + ax.legend() + ax.set_title(name) + + if show_plots: + plt.show() + else: + os.makedirs(PLOT_OUT, exist_ok=True) + plt.savefig(f"{PLOT_OUT}/{name}.png") + plt.close() + + +def analyze_subtype_data_regular(show_plots=True): + global PLOT_OUT, COLORS # noqa + PLOT_OUT = "subtype_plots/regular_mice" + files = sorted(glob("./subtype_analysis/*.tsv")) for ff in files: cochlea = os.path.basename(ff)[:-len("_subtype_analysis.tsv")] + if cochlea not in REGULAR_COCHLEAE: + continue print(cochlea) channels = COCHLEAE_FOR_SUBTYPES[cochlea] - reference_channel = "PV" if "PV" in channels else "CR" + + reference_channel = "PV" assert channels[0] == reference_channel tab = pd.read_csv(ff, sep="\t") # 1.) Plot simple intensity histograms, including otsu threshold. - for chan in channels: - column = f"{chan}_median" - name = f"{cochlea}_{chan}_histogram" - _plot_histogram(tab, column, name, show_plots) + # for chan in channels: + # column = f"{chan}_median" + # name = f"{cochlea}_{chan}_histogram" + # _plot_histogram(tab, column, name, show_plots, apply_threshold=chan != reference_channel) # 2.) Plot ratio histograms, including otsu threshold. - # TODO ratio based classification and overlay in 2d plot? ratios = {} - subtype_classification = [] + classification = [] for chan in channels[1:]: column = f"{chan}_ratio_{reference_channel}" name = f"{cochlea}_{chan}_histogram_ratio_{reference_channel}" - classification = _plot_histogram( - tab, column, name, subtype=CHANNEL_TO_TYPE.get(chan, None), show_plots=show_plots + chan_classification = _plot_histogram( + tab, column, name, class_names=[f"{chan}-", f"{chan}+"], show_plots=show_plots ) - subtype_classification.append(classification) + classification.append(chan_classification) ratios[f"{chan}_{reference_channel}"] = tab[column].values - # 3.) Plot 2D space of ratios. 
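# Sketch of the classification rule used above: Otsu-threshold the per-cell
# intensity ratio and attach marker+/- labels. The bimodal toy data below is
# illustrative.
import numpy as np
from skimage.filters import threshold_otsu

rng = np.random.default_rng(0)
ratios = np.concatenate([rng.normal(0.5, 0.1, 200), rng.normal(2.0, 0.3, 100)])
threshold = threshold_otsu(ratios)
classes = ["Calb1-" if r < threshold else "Calb1+" for r in ratios]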
+ # Unify the classification and assign colors + cls1, cls2 = classification[0], classification[1] + assert len(cls1) == len(cls2) + classification = [f"{c1} / {c2}" for c1, c2 in zip(cls1, cls2)] + + unique_labels = set(classification) + for label in unique_labels: + if label in COLORS: + continue + if COLORS: + last_color = list(COLORS.values())[-1] + next_color = ALL_COLORS[ALL_COLORS.index(last_color) + 1] + COLORS[label] = next_color + else: + COLORS[label] = ALL_COLORS[0] + + # 3.) Plot tonotopic mapping. + freq = tab["frequency[kHz]"].values + assert len(freq) == len(classification) + name = f"{cochlea}_tonotopic_mapping" + _plot_tonotopic_mapping(freq, classification, name=name, colors=COLORS, show_plots=show_plots) + + # 4.) Plot 2D space of ratios. name = f"{cochlea}_2d" - _plot_2d(ratios, name, show_plots, classification=subtype_classification) + _plot_2d(ratios, name, show_plots, classification=classification, colors=COLORS) # General notes: @@ -338,10 +403,15 @@ def analyze_subtype_data(show_plots=True): def main(): missing_tables = check_processing_status() require_missing_tables(missing_tables) + compile_data_for_subtype_analysis() + + # analyze_subtype_data_regular(show_plots=False) - # compile_data_for_subtype_analysis() + # TODO + # analyze_subtype_data_N_mice() - # analyze_subtype_data(show_plots=False) + # CTBP2 stain + # analyze_subtype_data_syn_mice() if __name__ == "__main__": From 94774a0d059e94d670a19a0e333481d0e256a7e6 Mon Sep 17 00:00:00 2001 From: Constantin Pape Date: Wed, 17 Sep 2025 12:42:08 +0200 Subject: [PATCH 12/13] Update subtype analysis --- scripts/measurements/sgn_subtypes.py | 256 +++++++++++++++++++++------ 1 file changed, 199 insertions(+), 57 deletions(-) diff --git a/scripts/measurements/sgn_subtypes.py b/scripts/measurements/sgn_subtypes.py index 90c9dfc..64c2e7f 100644 --- a/scripts/measurements/sgn_subtypes.py +++ b/scripts/measurements/sgn_subtypes.py @@ -2,7 +2,6 @@ import os import sys from glob import glob -from subprocess import run import matplotlib.pyplot as plt import pandas as pd @@ -31,21 +30,6 @@ "M_LR_000099_L", "M_LR_000214_L", "M_AMD_N62_L", "M_LR_000184_R", "M_LR_000184_L" ] -# Map from channels to subtypes. -# Comment Aleyna: -# The signal will be a gradient between different subtypes: -# For example CR is expressed more, is brigther, -# in type 1a SGNs but exist in type Ib SGNs and to a lesser extent in type 1c. -# Same is also true for other markers so we will need to set a threshold for each. -# Luckily the signal seems less variable compared to GFP. -CHANNEL_TO_TYPE = { - "CR": "Type-Ia", - "Calb1": "Type-Ib", - "Lypd1": "Type-Ic", - "Prph": "Type-II", - "Ntng1": "Type-Ib/c", -} - # For custom thresholds. THRESHOLDS = { "M_LR_000214_L": { @@ -55,29 +39,50 @@ } # For consistent colors. -ALL_COLORS = ["red", "blue", "orange", "yellow", "cyan", "magenta", "green", "purple"] +ALL_COLORS = ["red", "blue", "orange", "yellow", "cyan", "magenta", "green", "purple", "gray", "black"] COLORS = {} PLOT_OUT = "./subtype_plots" -# TODO: updates based on Aleyna's feedback. -# Subtype mapping -# Combined visualization for the cochleae -# Can we visualize the tonotopy in subtypes and not stainings? -# It would also be good to have subtype percentages per cochlea and pooled together as a diagram and tonotopy? -# This would help to see if different staining gives same/similar results. 
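# Sketch of what the tonotopic plot above consumes: with categorical=True,
# frequency_mapping (scripts/figures/util.py) returns one table per class label
# with the fraction of SGNs carrying that label per octave band. The inputs
# below are toy values; the import assumes scripts/figures is on sys.path.
import numpy as np
from util import frequency_mapping

freqs = np.array([2.0, 4.5, 8.0, 16.0, 32.0])  # frequency in kHz per SGN
labels = ["Calb1+ / Lypd1-", "Calb1- / Lypd1+", "Calb1+ / Lypd1-",
          "Calb1+ / Lypd1+", "Calb1- / Lypd1-"]
per_band = frequency_mapping(freqs, labels, categorical=True)
for label, tab in per_band.items():
    print(label, tab.value.tolist())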
From 94774a0d059e94d670a19a0e333481d0e256a7e6 Mon Sep 17 00:00:00 2001
From: Constantin Pape
Date: Wed, 17 Sep 2025 12:42:08 +0200
Subject: [PATCH 12/13] Update subtype analysis

---
 scripts/measurements/sgn_subtypes.py | 256 +++++++++++++++++++++------
 1 file changed, 199 insertions(+), 57 deletions(-)

diff --git a/scripts/measurements/sgn_subtypes.py b/scripts/measurements/sgn_subtypes.py
index 90c9dfc..64c2e7f 100644
--- a/scripts/measurements/sgn_subtypes.py
+++ b/scripts/measurements/sgn_subtypes.py
@@ -2,7 +2,6 @@
 import os
 import sys
 from glob import glob
-from subprocess import run

 import matplotlib.pyplot as plt
 import pandas as pd
@@ -31,21 +30,6 @@
     "M_LR_000099_L", "M_LR_000214_L", "M_AMD_N62_L", "M_LR_000184_R", "M_LR_000184_L"
 ]

-# Map from channels to subtypes.
-# Comment Aleyna:
-# The signal will be a gradient between different subtypes:
-# For example CR is expressed more, is brigther,
-# in type 1a SGNs but exist in type Ib SGNs and to a lesser extent in type 1c.
-# Same is also true for other markers so we will need to set a threshold for each.
-# Luckily the signal seems less variable compared to GFP.
-CHANNEL_TO_TYPE = {
-    "CR": "Type-Ia",
-    "Calb1": "Type-Ib",
-    "Lypd1": "Type-Ic",
-    "Prph": "Type-II",
-    "Ntng1": "Type-Ib/c",
-}
-
 # For custom thresholds.
 THRESHOLDS = {
     "M_LR_000214_L": {
@@ -55,29 +39,50 @@
     }
 }

 # For consistent colors.
-ALL_COLORS = ["red", "blue", "orange", "yellow", "cyan", "magenta", "green", "purple"]
+ALL_COLORS = ["red", "blue", "orange", "yellow", "cyan", "magenta", "green", "purple", "gray", "black"]
 COLORS = {}

 PLOT_OUT = "./subtype_plots"

-# TODO: updates based on Aleyna's feedback.
-# Subtype mapping
-# Combined visualization for the cochleae
-# Can we visualize the tonotopy in subtypes and not stainings?
-# It would also be good to have subtype percentages per cochlea and pooled together as a diagram and tonotopy?
-# This would help to see if different staining gives same/similar results.

 # Type Ia: CR+ / Calb1- or Calb1- / Lypd1-
 # Type Ib: CR+ / Calb1+ or Calb1+ / Lypd1+
 # Type Ic: CR- / Calb1+ or Calb1- / Lypd1+
 # Type II: CR- / Calb1- or Calb1- / Lypd1- or Prph+
+def stain_to_type(stain):
+    # Normalize the staining string.
+    stains = stain.replace(" ", "").split("/")
+    assert len(stains) in (1, 2)
-# > It's good to see that for the N mice the Ntng1C and Lypd1 separate from CR so well on the thresholds. Can I visualize these samples ones segmentation masks are done to verify the Ntng1C thresholds? As this is a quite clear signal I'm not sure if taking the middle of the histogram would be the best choice.
-# The segmentations are in MoBIE already. I need to send you the tables for analyzing the signals. Will send them later.
+    if len(stains) == 1:
+        stain_norm = stain
+    else:
+        s1, s2 = sorted(stains)
+        stain_norm = f"{s1}/{s2}"
-# > Where are we at PV-Prph segmentation results from MLR184_L and R for SGN type II analysis? This would hopefully give <5% Prph+ cells.
-# The cochleae are in MoBIE. Segmentation and Prph signal look good! I will include it in the next analysis.
+    stain_to_type = {
-# Need tonotopic mapping from Martin and then compute the intensities.
+        # Combinations of Calb1 and CR:
+        "CR+/Calb1+": "Type Ib",
+        "CR-/Calb1+": "Type Ib/Ic",  # Calb1 is expressed in Ic less than Lypd1, but more than CR.
+        "CR+/Calb1-": "Type Ia",
+        "CR-/Calb1-": "Type II",
+
+        # Combinations of Calb1 and Lypd1:
+        "Calb1+/Lypd1+": "Type Ib/Ic",
+        "Calb1+/Lypd1-": "Type Ib",
+        "Calb1-/Lypd1+": "Type Ic",
+        "Calb1-/Lypd1-": "inconclusive",  # Can be Type Ia or Type II.
+
+        # Prph is isolated.
+        "Prph+": "Type II",
+        "Prph-": "Type I",
+    }
+
+    if stain_norm not in stain_to_type:
+        raise ValueError(f"Invalid stain combination: {stain_norm}")
+
+    return stain_to_type[stain_norm], stain_norm
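[Editor's note: because the two-marker keys are sorted ASCII-wise ("CR-" sorts before "Calb1+", since uppercase letters precede lowercase ones), it is easy to misread which combinations are covered. A quick sanity check, assuming `stain_to_type` from the patch above is in scope:]

```python
# Expected outputs per the mapping table above.
print(stain_to_type("CR+ / Calb1-"))     # ("Type Ia", "CR+/Calb1-")
print(stain_to_type("Calb1+ / CR-"))     # ("Type Ib/Ic", "CR-/Calb1+")
print(stain_to_type("Calb1- / Lypd1-"))  # ("inconclusive", "Calb1-/Lypd1-")
print(stain_to_type("Prph+"))            # ("Type II", "Prph+")
```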


 def check_processing_status():
@@ -189,8 +194,9 @@ def require_missing_tables(missing_tables):
     )

     # S3 upload
-    run(["rclone", "--progress", "copyto", output_folder,
-         f"cochlea-lightsheet:cochlea-lightsheet/{cochlea}/tables/{seg_name}"])
+    # from subprocess import run
+    # run(["rclone", "--progress", "copyto", output_folder,
+    #      f"cochlea-lightsheet:cochlea-lightsheet/{cochlea}/tables/{seg_name}"])


 def compile_data_for_subtype_analysis():
@@ -209,14 +215,20 @@ def compile_data_for_subtype_analysis():
             assert "CR" in channels
             reference_channel = "CR"
             seg_name = "CR_SGN_v2"
-            reference_channel, seg_name

         content = s3.open(f"{BUCKET_NAME}/{cochlea}/dataset.json", mode="r", encoding="utf-8")
         info = json.loads(content.read())
         sources = info["sources"]

         # Load the segmentation table.
-        seg_source = sources[seg_name]
+        try:
+            seg_source = sources[seg_name]
+        except KeyError as e:
+            if seg_name == "PV_SGN_v2":
+                seg_source = sources["SGN_v2"]
+                seg_name = "SGN_v2"
+            else:
+                raise e
         table_folder = os.path.join(
             BUCKET_NAME, cochlea, seg_source["segmentation"]["tableData"]["tsv"]["relativePath"]
         )
@@ -232,12 +244,19 @@ def compile_data_for_subtype_analysis():
         # Analyze the different channels (= different subtypes).
         reference_intensity = None
         for channel in channels:
-            # Load the intensity table.
-            intensity_path = os.path.join(table_folder, f"{channel}_{seg_name.replace('_', '-')}_object-measures.tsv")
-            table_content = s3.open(intensity_path, mode="rb")
+            # Load the intensity table, prefer local.
+            table_name = f"{channel}_{seg_name.replace('_', '-')}_object-measures.tsv"
+            intensity_path = os.path.join("object_measurements", cochlea, table_name)
+
+            if os.path.exists(intensity_path):
+                intensities = pd.read_csv(intensity_path, sep="\t")
+            else:
+                intensity_path = os.path.join(table_folder, table_name)
+                table_content = s3.open(intensity_path, mode="rb")
+
+                intensities = pd.read_csv(table_content, sep="\t")
+                intensities = intensities[intensities.label_id.isin(valid_sgns)]
-            intensities = pd.read_csv(table_content, sep="\t")
-            intensities = intensities[intensities.label_id.isin(valid_sgns)]

             assert len(table) == len(intensities)
             assert (intensities.label_id.values == table.label_id.values).all()
@@ -258,11 +277,20 @@ def _plot_histogram(table, column, name, show_plots, class_names=None, apply_thr
     data = table[column].values
     threshold = threshold_otsu(data)

+    if class_names is not None:
+        assert len(class_names) == 2
+        c0, c1 = class_names
+        subtype_classification = [c0 if datum < threshold else c1 for datum in data]
+
     fig, ax = plt.subplots(1)
     ax.hist(data, bins=24)
     if apply_threshold:
         ax.axvline(x=threshold, color='red', linestyle='--')
-        ax.set_title(f"{name}\n threshold: {threshold}")
+        if class_names is None:
+            ax.set_title(f"{name}\n threshold: {threshold}")
+        else:
+            pos_perc = len([st for st in subtype_classification if st == c1]) / float(len(subtype_classification))
+            ax.set_title(f"{name}\n threshold: {threshold}\n %{c1}: {pos_perc * 100}")
     else:
         ax.set_title(name)
@@ -271,11 +299,9 @@ def _plot_histogram(table, column, name, show_plots, class_names=None, apply_thr
     else:
         os.makedirs(PLOT_OUT, exist_ok=True)
         plt.savefig(f"{PLOT_OUT}/{name}.png")
+    plt.close()

     if class_names is not None:
-        assert len(class_names) == 2
-        c0, c1 = class_names
-        subtype_classification = [c0 if datum < threshold else c1 for datum in data]
         return subtype_classification
@@ -310,6 +336,7 @@ def _plot_2d(ratios, name, show_plots, classification=None, colors=None):
     else:
         os.makedirs(PLOT_OUT, exist_ok=True)
         plt.savefig(f"{PLOT_OUT}/{name}.png")
+    plt.close()


 def _plot_tonotopic_mapping(freq, classification, name, colors, show_plots):
@@ -324,6 +351,11 @@ def _plot_tonotopic_mapping(freq, classification, name, colors, show_plots):
     fig, ax = plt.subplots(figsize=(8, 4))
     for cat, vals in frequency_mapped.items():
         ax.scatter(x_positions, vals.value, label=cat, color=colors[cat])
+
+    main_ticks = range(len(bin_labels))
+    ax.set_xticks(main_ticks)
+    ax.set_xticklabels(bin_labels)
+    ax.set_xlabel("Octave band (kHz)")
     ax.legend()
     ax.set_title(name)
@@ -334,12 +366,103 @@ def _plot_tonotopic_mapping(freq, classification, name, colors, show_plots):
         plt.savefig(f"{PLOT_OUT}/{name}.png")
     plt.close()

+    return frequency_mapped
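[Editor's note: the compile step above now prefers measurement tables cached under `object_measurements/` and only falls back to S3. Note that the `valid_sgns` filter is applied only on the S3 branch, so locally cached tables are implicitly assumed to be pre-filtered. Factored out, the pattern looks roughly like this; the bucket layout and the helper name `load_table` are illustrative, not part of the patch.]

```python
import os

import pandas as pd
import s3fs


def load_table(local_root, bucket_folder, table_name, s3=None):
    """Prefer a locally cached TSV, fall back to reading it from S3."""
    local_path = os.path.join(local_root, table_name)
    if os.path.exists(local_path):
        return pd.read_csv(local_path, sep="\t")
    s3 = s3 or s3fs.S3FileSystem(anon=True)
    with s3.open(f"{bucket_folder}/{table_name}", mode="rb") as f:
        return pd.read_csv(f, sep="\t")
```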
+
+
+# Combined visualization for the cochleae
+# Can we visualize the tonotopy in subtypes and not stainings?
+# It would also be good to have subtype percentages per cochlea and pooled together as a diagram and tonotopy?
+# This would help to see if different staining gives same/similar results.
+def combined_analysis(results, show_plots):
+    #
+    # Create the tonotopic mapping.
+    #
+    summary = {}
+    for cochlea, result in results.items():
+        if cochlea == "M_LR_000214_L":  # One of the signals cannot be analyzed.
+            continue
+        mapping = result["tonotopic_mapping"]
+        summary[cochlea] = mapping
+
+    colors = {}
+
+    fig, axes = plt.subplots(len(summary), sharey=True, figsize=(8, 8))
+    for i, (cochlea, frequency_mapped) in enumerate(summary.items()):
+        ax = axes[i]
+
+        result = next(iter(frequency_mapped.values()))
+        bin_labels = pd.unique(result["octave_band"])
+        band_to_x = {band: i for i, band in enumerate(bin_labels)}
+        x_positions = result["octave_band"].map(band_to_x)
+
+        for cat, vals in frequency_mapped.items():
+            values = vals.value
+            cat = cat[:cat.find(" (")]
+            if cat not in colors:
+                current_colors = list(colors.values())
+                next_color = ALL_COLORS[len(current_colors)]
+                colors[cat] = next_color
+            ax.scatter(x_positions, values, label=cat, color=colors[cat])
+
+        main_ticks = range(len(bin_labels))
+        ax.set_xticks(main_ticks)
+        ax.set_xticklabels(bin_labels)
+        ax.set_title(cochlea)
+        ax.legend()
+
+    ax.set_xlabel("Octave band (kHz)")
+    plt.tight_layout()
+    if show_plots:
+        plt.show()
+    else:
+        plt.savefig("./subtype_plots/overview_tonotopic_mapping.png")
+        plt.close()
+
+    #
+    # Create the overview figure.
+    #
+    summary, types = {}, []
+    for cochlea, result in results.items():
+        if cochlea == "M_LR_000214_L":  # One of the signals cannot be analyzed.
+            continue
+
+        classification = result["classification"]
+        classification = [cls[:cls.find(" (")] for cls in classification]
+        n_tot = len(classification)
+
+        this_types = list(set(classification))
+        types.extend(this_types)
+        summary[cochlea] = {}
+        for stype in types:
+            n_type = len([cls for cls in classification if cls == stype])
+            type_ratio = float(n_type) / n_tot
+            summary[cochlea][stype] = type_ratio
+
+    types = list(set(types))
+    df = pd.DataFrame(summary).fillna(0)  # missing values → 0
+
+    # Transpose → cochleae on x-axis, subtypes stacked
+    ax = df.T.plot(kind="bar", stacked=True, figsize=(8, 5))
+
+    ax.set_ylabel("Fraction")
+    ax.set_xlabel("Cochlea")
+    ax.set_title("Subtype Fractions per Cochlea")
+    plt.xticks(rotation=0)
+    plt.tight_layout()
+
+    if show_plots:
+        plt.show()
+    else:
+        plt.savefig("./subtype_plots/overview.png")
+        plt.close()
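[Editor's note: two details of `combined_analysis` worth knowing. The inner `for stype in types` loop iterates over the cumulative type list, so a cochlea never gets entries for subtypes that first appear in a later cochlea; `pd.DataFrame(summary).fillna(0)` then fills exactly those missing fractions with 0. The stacked-bar overview reduces to this standalone miniature, with made-up numbers:]

```python
import matplotlib.pyplot as plt
import pandas as pd

summary = {
    "cochlea_A": {"Type Ia": 0.50, "Type Ib": 0.30, "Type Ic": 0.20},
    "cochlea_B": {"Type Ia": 0.45, "Type Ib": 0.35, "Type Ic": 0.20},
}
df = pd.DataFrame(summary).fillna(0)  # columns are cochleae, rows are subtypes

# Transposing puts cochleae on the x-axis and stacks the subtype fractions.
ax = df.T.plot(kind="bar", stacked=True, figsize=(6, 4))
ax.set_ylabel("Fraction")
plt.tight_layout()
plt.show()
```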


 def analyze_subtype_data_regular(show_plots=True):
     global PLOT_OUT, COLORS  # noqa
     PLOT_OUT = "subtype_plots/regular_mice"

     files = sorted(glob("./subtype_analysis/*.tsv"))
+    results = {}
     for ff in files:
         cochlea = os.path.basename(ff)[:-len("_subtype_analysis.tsv")]
@@ -354,10 +477,10 @@ def analyze_subtype_data_regular(show_plots=True):
         tab = pd.read_csv(ff, sep="\t")

         # 1.) Plot simple intensity histograms, including otsu threshold.
-        # for chan in channels:
-        #     column = f"{chan}_median"
-        #     name = f"{cochlea}_{chan}_histogram"
-        #     _plot_histogram(tab, column, name, show_plots, apply_threshold=chan != reference_channel)
+        for chan in channels:
+            column = f"{chan}_median"
+            name = f"{cochlea}_{chan}_histogram"
+            _plot_histogram(tab, column, name, show_plots, apply_threshold=chan != reference_channel)

         # 2.) Plot ratio histograms, including otsu threshold.
         ratios = {}
@@ -372,9 +495,18 @@
             ratios[f"{chan}_{reference_channel}"] = tab[column].values

         # Unify the classification and assign colors
-        cls1, cls2 = classification[0], classification[1]
-        assert len(cls1) == len(cls2)
-        classification = [f"{c1} / {c2}" for c1, c2 in zip(cls1, cls2)]
+        assert len(classification) in (1, 2)
+        if len(classification) == 2:
+            cls1, cls2 = classification[0], classification[1]
+            assert len(cls1) == len(cls2)
+            classification = [f"{c1} / {c2}" for c1, c2 in zip(cls1, cls2)]
+            show_2d = True
+        else:
+            classification = classification[0]
+            show_2d = False
+
+        classification = [stain_to_type(cls) for cls in classification]
+        classification = [f"{stype} ({stain})" for stype, stain in classification]

         unique_labels = set(classification)
         for label in unique_labels:
@@ -391,21 +523,31 @@
         freq = tab["frequency[kHz]"].values
         assert len(freq) == len(classification)
         name = f"{cochlea}_tonotopic_mapping"
-        _plot_tonotopic_mapping(freq, classification, name=name, colors=COLORS, show_plots=show_plots)
+        tonotopic_mapping = _plot_tonotopic_mapping(
+            freq, classification, name=name, colors=COLORS, show_plots=show_plots
+        )

         # 4.) Plot 2D space of ratios.
-        name = f"{cochlea}_2d"
-        _plot_2d(ratios, name, show_plots, classification=classification, colors=COLORS)
+        if show_2d:
+            name = f"{cochlea}_2d"
+            _plot_2d(ratios, name, show_plots, classification=classification, colors=COLORS)

+        results[cochlea] = {"classification": classification, "tonotopic_mapping": tonotopic_mapping}

-# General notes:
-# See:
+    combined_analysis(results, show_plots=show_plots)
+
+
+# More TODO:
+# > It's good to see that for the N mice the Ntng1C and Lypd1 separate from CR so well on the thresholds.
+# Can I visualize these samples once segmentation masks are done to verify the Ntng1C thresholds?
+# As this is a quite clear signal I'm not sure if taking the middle of the histogram would be the best choice.
+# The segmentations are in MoBIE already. I need to send you the tables for analyzing the signals. Will send them later.
 def main():
-    missing_tables = check_processing_status()
-    require_missing_tables(missing_tables)
-    compile_data_for_subtype_analysis()
+    # missing_tables = check_processing_status()
+    # require_missing_tables(missing_tables)
+    # compile_data_for_subtype_analysis()

-    # analyze_subtype_data_regular(show_plots=False)
+    analyze_subtype_data_regular(show_plots=False)

     # TODO
     # analyze_subtype_data_N_mice()

From ea2f581ddba3956867bf376042a75db23c6c1766 Mon Sep 17 00:00:00 2001
From: Constantin Pape
Date: Thu, 18 Sep 2025 19:25:58 +0200
Subject: [PATCH 13/13] Implement support for anisotropic training

---
 flamingo_tools/training/util.py         |  5 ++++-
 scripts/training/train_distance_unet.py | 16 ++++++++++++----
 2 files changed, 16 insertions(+), 5 deletions(-)

diff --git a/flamingo_tools/training/util.py b/flamingo_tools/training/util.py
index 5c326d8..ca7fa3d 100644
--- a/flamingo_tools/training/util.py
+++ b/flamingo_tools/training/util.py
@@ -29,6 +29,7 @@ def get_supervised_loader(
     label_key: Optional[str] = None,
     n_samples: Optional[int] = None,
     raw_transform: Optional[callable] = None,
+    anisotropy: Optional[float] = None,
 ) -> DataLoader:
     """Get a data loader for a supervised segmentation task.

     Args:
         image_paths: The filepaths to the image data.
         label_paths: The filepaths to the label data.
         patch_shape: The patch shape for training.
         batch_size: The batch size for training.
         image_key: Internal path for the image data. This is only required for hdf5/zarr/n5 data.
         label_key: Internal path for the label masks. This is only required for hdf5/zarr/n5 data.
         n_samples: The number of samples to use for training.
         raw_transform: Optional transformation for the raw data.
+        anisotropy: The anisotropy factor for distance target computation.

     Returns:
         The data loader.
     """
     assert len(image_paths) == len(label_paths)
     assert len(image_paths) > 0
+    sampling = None if anisotropy is None else (anisotropy, 1.0, 1.0)
     label_transform = torch_em.transform.label.PerObjectDistanceTransform(
-        distances=True, boundary_distances=True, foreground=True,
+        distances=True, boundary_distances=True, foreground=True, sampling=sampling,
     )
     sampler = torch_em.data.sampler.MinInstanceSampler(p_reject=0.8)
     loader = torch_em.default_segmentation_loader(
diff --git a/scripts/training/train_distance_unet.py b/scripts/training/train_distance_unet.py
index 4a109f0..f38e8ec 100644
--- a/scripts/training/train_distance_unet.py
+++ b/scripts/training/train_distance_unet.py
@@ -80,7 +80,7 @@ def select_paths(image_paths, label_paths, split, filter_empty, random_split=Tru
     return image_paths, label_paths


-def get_loader(root, split, patch_shape, batch_size, filter_empty, separate_folders):
+def get_loader(root, split, patch_shape, batch_size, filter_empty, separate_folders, anisotropy):
     if separate_folders:
         image_paths, label_paths = get_image_and_label_paths_sep_folders(root)
     else:
@@ -96,7 +96,9 @@ def get_loader(root, split, patch_shape, batch_size, filter_empty, separate_fold
     n_samples = 16 * batch_size
     return (
-        get_supervised_loader(this_image_paths, this_label_paths, patch_shape, batch_size, n_samples=n_samples),
+        get_supervised_loader(
+            this_image_paths, this_label_paths, patch_shape, batch_size, n_samples=n_samples, anisotropy=anisotropy
+        ),
         this_image_paths, this_label_paths
     )
@@ -124,6 +126,10 @@ def main():
     parser.add_argument(
         "--name", help="Optional name for the model to be trained. If not given the current date is used."
     )
+    parser.add_argument(
+        "--anisotropy", help="Anisotropy factor of the Z-Axis (Depth). Will be used to scale distance targets.",
+        type=float,
+    )
     parser.add_argument("--separate_folders", action="store_true")
     args = parser.parse_args()
     root = args.root
@@ -141,10 +147,12 @@ def main():

     # Create the training loader with train and val set.
     train_loader, train_images, train_labels = get_loader(
-        root, "train", patch_shape, batch_size, filter_empty=filter_empty, separate_folders=args.separate_folders
+        root, "train", patch_shape, batch_size, filter_empty=filter_empty, separate_folders=args.separate_folders,
+        anisotropy=args.anisotropy,
     )
     val_loader, val_images, val_labels = get_loader(
-        root, "val", patch_shape, batch_size, filter_empty=filter_empty, separate_folders=args.separate_folders
+        root, "val", patch_shape, batch_size, filter_empty=filter_empty, separate_folders=args.separate_folders,
+        anisotropy=args.anisotropy,
     )

     if check_loaders:
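[Editor's note: the `(anisotropy, 1.0, 1.0)` tuple follows the `sampling` convention of scipy's Euclidean distance transform, which torch_em's `PerObjectDistanceTransform` presumably uses under the hood: one voxel step along Z is counted as `anisotropy` physical units when computing distance targets. A small demonstration with scipy directly:]

```python
import numpy as np
from scipy.ndimage import distance_transform_edt

mask = np.ones((5, 5, 5), dtype=bool)
mask[2, 2, 2] = False  # single background voxel in the center

iso = distance_transform_edt(mask)
aniso = distance_transform_edt(mask, sampling=(4.0, 1.0, 1.0))

print(iso[1, 2, 2], aniso[1, 2, 2])  # 1.0 4.0 -> a Z step now costs 4 units
print(iso[2, 1, 2], aniso[2, 1, 2])  # 1.0 1.0 -> in-plane steps unchanged
```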