From 938d4baf24d0c6b072daf630a704df024338a546 Mon Sep 17 00:00:00 2001 From: Constantin Pape Date: Wed, 29 Jan 2025 18:57:00 +0100 Subject: [PATCH 1/9] Training is working --- scripts/synapse_marker_detection/.gitignore | 1 + .../detection_dataset.py | 166 ++++++++++++++++++ .../extract_training_data.py | 79 +++++++++ .../train_synapse_detection.py | 85 +++++++++ 4 files changed, 331 insertions(+) create mode 100644 scripts/synapse_marker_detection/.gitignore create mode 100644 scripts/synapse_marker_detection/detection_dataset.py create mode 100644 scripts/synapse_marker_detection/extract_training_data.py create mode 100644 scripts/synapse_marker_detection/train_synapse_detection.py diff --git a/scripts/synapse_marker_detection/.gitignore b/scripts/synapse_marker_detection/.gitignore new file mode 100644 index 0000000..8fce603 --- /dev/null +++ b/scripts/synapse_marker_detection/.gitignore @@ -0,0 +1 @@ +data/ diff --git a/scripts/synapse_marker_detection/detection_dataset.py b/scripts/synapse_marker_detection/detection_dataset.py new file mode 100644 index 0000000..14b07b6 --- /dev/null +++ b/scripts/synapse_marker_detection/detection_dataset.py @@ -0,0 +1,166 @@ +import numpy as np +import pandas as pd +import torch +import zarr + +from skimage.filters import gaussian +from torch_em.util import ensure_tensor_with_channels + + +# Process labels stored in json napari style. +# I don't actually think that we need the epsilon here, but will leave it for now. +def process_labels(label_path, shape, sigma, eps): + labels = np.zeros(shape, dtype="float32") + points = pd.read_csv(label_path) + assert len(points.columns) == len(shape) + coords = tuple( + np.clip(np.round(points[ax].values).astype("int"), 0, shape[i] - 1) + for i, ax in enumerate(points.columns) + ) + labels[coords] = 1 + labels = gaussian(labels, sigma) + # TODO better normalization? + labels /= labels.max() + return labels + + +class DetectionDataset(torch.utils.data.Dataset): + max_sampling_attempts = 500 + + def __init__( + self, + raw_image_paths, + label_paths, + patch_shape, + raw_transform=None, + label_transform=None, + transform=None, + dtype=torch.float32, + label_dtype=torch.float32, + n_samples=None, + sampler=None, + eps=1e-8, + sigma=None, + **kwargs, + ): + self.raw_images = raw_image_paths + # TODO make this a parameter + self.raw_key = "raw" + self.label_images = label_paths + self._ndim = 3 + + assert len(patch_shape) == self._ndim + self.patch_shape = patch_shape + + self.raw_transform = raw_transform + self.label_transform = label_transform + self.transform = transform + self.sampler = sampler + + self.dtype = dtype + self.label_dtype = label_dtype + + self.eps = eps + self.sigma = sigma + + if n_samples is None: + self._len = len(self.raw_images) + self.sample_random_index = False + else: + self._len = n_samples + self.sample_random_index = True + + def __len__(self): + return self._len + + @property + def ndim(self): + return self._ndim + + def _sample_bounding_box(self, shape): + if any(sh < psh for sh, psh in zip(shape, self.patch_shape)): + raise NotImplementedError( + f"Image padding is not supported yet. Data shape {shape}, patch shape {self.patch_shape}" + ) + bb_start = [ + np.random.randint(0, sh - psh) if sh - psh > 0 else 0 + for sh, psh in zip(shape, self.patch_shape) + ] + return tuple(slice(start, start + psh) for start, psh in zip(bb_start, self.patch_shape)) + + def _get_sample(self, index): + if self.sample_random_index: + index = np.random.randint(0, len(self.raw_images)) + raw, label = self.raw_images[index], self.label_images[index] + + raw = zarr.open(raw)[self.raw_key] + # Note: this is quite inefficient, because we process the full crop rather than + # just the requested bounding box. + label = process_labels(label, raw.shape, self.sigma, self.eps) + + have_raw_channels = raw.ndim == 4 # 3D with channels + have_label_channels = label.ndim == 4 + if have_label_channels: + raise NotImplementedError("Multi-channel labels are not supported.") + + shape = raw.shape + prefix_box = tuple() + if have_raw_channels: + if shape[-1] < 16: + shape = shape[:-1] + else: + shape = shape[1:] + prefix_box = (slice(None), ) + + bb = self._sample_bounding_box(shape) + raw_patch = np.array(raw[prefix_box + bb]) + label_patch = np.array(label[bb]) + + if self.sampler is not None: + sample_id = 0 + while not self.sampler(raw_patch, label_patch): + bb = self._sample_bounding_box(shape) + raw_patch = np.array(raw[prefix_box + bb]) + label_patch = np.array(label[bb]) + sample_id += 1 + if sample_id > self.max_sampling_attempts: + raise RuntimeError(f"Could not sample a valid batch in {self.max_sampling_attempts} attempts") + + if have_raw_channels and len(prefix_box) == 0: + raw_patch = raw_patch.transpose((3, 0, 1, 2)) # Channels, Depth, Height, Width + + return raw_patch, label_patch + + def __getitem__(self, index): + raw, labels = self._get_sample(index) + # initial_label_dtype = labels.dtype + + if self.raw_transform is not None: + raw = self.raw_transform(raw) + + if self.label_transform is not None: + labels = self.label_transform(labels) + + if self.transform is not None: + raw, labels = self.transform(raw, labels) + + raw = ensure_tensor_with_channels(raw, ndim=self._ndim, dtype=self.dtype) + labels = ensure_tensor_with_channels(labels, ndim=self._ndim, dtype=self.label_dtype) + return raw, labels + + +if __name__ == "__main__": + import napari + + raw_path = "training_data/images/10.1L_mid_IHCribboncount_5_Z.zarr" + label_path = "training_data/labels/10.1L_mid_IHCribboncount_5_Z.csv" + + f = zarr.open(raw_path, "r") + raw = f["raw"][:] + + labels = process_labels(label_path, shape=raw.shape, sigma=1, eps=1e-7) + + v = napari.Viewer() + v.add_image(raw) + v.add_image(labels) + napari.run() diff --git a/scripts/synapse_marker_detection/extract_training_data.py b/scripts/synapse_marker_detection/extract_training_data.py new file mode 100644 index 0000000..3017577 --- /dev/null +++ b/scripts/synapse_marker_detection/extract_training_data.py @@ -0,0 +1,79 @@ +import os +from glob import glob +from pathlib import Path + +import h5py +import napari +import numpy as np +import pandas as pd +import zarr + + +def get_voxel_size(imaris_file): + with h5py.File(imaris_file, "r") as f: + info = f["/DataSetInfo/Image"] + ext = [[float(b"".join(info.attrs[f"ExtMin{i}"]).decode()), + float(b"".join(info.attrs[f"ExtMax{i}"]).decode())] for i in range(3)] + size = [int(b"".join(info.attrs[dim]).decode()) for dim in ["X", "Y", "Z"]] + vsize = np.array([(max_-min_)/s for (min_, max_), s in zip(ext, size)]) + return vsize + + +def extract_training_data(imaris_file, output_folder): + with h5py.File(imaris_file, "r") as f: + data = f["/DataSet/ResolutionLevel 0/TimePoint 0/Channel 0/Data"][:] + points = f["/Scene/Content/Points0/CoordsXYZR"][:] + points = points[:, :-1] + points = points[:, ::-1] + + # TODO crop the data to the original shape. + # Can we just crop the zero-padding ?! + crop_box = np.where(data != 0) + crop_box = tuple(slice(0, int(cb.max() + 1)) for cb in crop_box) + data = data[crop_box] + print(data.shape) + + # Scale the points to match the image dimensions. + voxel_size = get_voxel_size(imaris_file) + points /= voxel_size[None] + + if output_folder is None: + v = napari.Viewer() + v.add_image(data) + v.add_points(points) + v.title = os.path.basename(imaris_file) + napari.run() + else: + image_folder = os.path.join(output_folder, "images") + os.makedirs(image_folder, exist_ok=True) + + label_folder = os.path.join(output_folder, "labels") + os.makedirs(label_folder, exist_ok=True) + + fname = Path(imaris_file).stem + image_file = os.path.join(image_folder, f"{fname}.zarr") + label_file = os.path.join(label_folder, f"{fname}.csv") + + coords = pd.DataFrame(points, columns=["axis-0", "axis-1", "axis-2"]) + coords.to_csv(label_file, index=False) + + f = zarr.open(image_file, "a") + f.create_dataset("raw", data=data) + + +# Files that look good for training: +# - 4.1L_apex_IHCribboncount_Z.ims +# - 4.1L_base_IHCribbons_Z.ims +# - 4.1L_mid_IHCribboncount_Z.ims +# - 4.2R_apex_IHCribboncount_Z.ims +# - 4.2R_apex_IHCribboncount_Z.ims +# - 6.2R_apex_IHCribboncount_Z.ims (very small crop) +# - 6.2R_base_IHCribbons_Z.ims +def main(): + files = sorted(glob("./data/synapse_stains/*.ims")) + for ff in files: + extract_training_data(ff, output_folder="./training_data") + + +if __name__ == "__main__": + main() diff --git a/scripts/synapse_marker_detection/train_synapse_detection.py b/scripts/synapse_marker_detection/train_synapse_detection.py new file mode 100644 index 0000000..282c52c --- /dev/null +++ b/scripts/synapse_marker_detection/train_synapse_detection.py @@ -0,0 +1,85 @@ +import os +import sys + +from detection_dataset import DetectionDataset + +# sys.path.append() +sys.path.append("/home/pape/Work/my_projects/czii-protein-challenge") + +from utils.training import supervised_training # noqa + +TRAIN_ROOT = "./training_data/images" +LABEL_ROOT = "./training_data/labels" + + +def get_paths(split): + file_names = [ + "4.1L_apex_IHCribboncount_Z", + "4.1L_base_IHCribbons_Z", + "4.1L_mid_IHCribboncount_Z", + "4.2R_apex_IHCribboncount_Z", + "4.2R_apex_IHCribboncount_Z", + "6.2R_apex_IHCribboncount_Z", + "6.2R_base_IHCribbons_Z", + ] + image_paths = [os.path.join(TRAIN_ROOT, f"{fname}.zarr") for fname in file_names] + label_paths = [os.path.join(LABEL_ROOT, f"{fname}.csv") for fname in file_names] + + if split == "train": + image_paths = image_paths[:-1] + label_paths = label_paths[:-1] + else: + image_paths = image_paths[-1:] + label_paths = label_paths[-1:] + + return image_paths, label_paths + + +# TODO maybe add a sampler for the label data +def train(): + + model_name = "synapse_detection_v1" + + train_paths, train_label_paths = get_paths("train") + val_paths, val_label_paths = get_paths("val") + # We need to give the paths for the test loader, although it's never used. + test_paths, test_label_paths = val_paths, val_label_paths + + print("Start training with:") + print(len(train_paths), "tomograms for training") + print(len(val_paths), "tomograms for validation") + + patch_shape = [32, 96, 96] + + batch_size = 8 + check = False + + supervised_training( + name=model_name, + train_paths=train_paths, + train_label_paths=train_label_paths, + val_paths=val_paths, + val_label_paths=val_label_paths, + patch_shape=patch_shape, batch_size=batch_size, + check=check, + lr=1e-4, + n_iterations=int(2.5e4), + out_channels=1, + augmentations=None, + eps=1e-5, + sigma=1, + lower_bound=None, + upper_bound=None, + test_paths=test_paths, + test_label_paths=test_label_paths, + # save_root="", + dataset_class=DetectionDataset, + ) + + +def main(): + train() + + +if __name__ == "__main__": + main() From 9db6a83dd1f08da801834989750a8f19a3194f20 Mon Sep 17 00:00:00 2001 From: Constantin Pape Date: Thu, 30 Jan 2025 10:40:35 +0100 Subject: [PATCH 2/9] Add sample sizes --- scripts/synapse_marker_detection/train_synapse_detection.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/scripts/synapse_marker_detection/train_synapse_detection.py b/scripts/synapse_marker_detection/train_synapse_detection.py index 282c52c..ebd14b0 100644 --- a/scripts/synapse_marker_detection/train_synapse_detection.py +++ b/scripts/synapse_marker_detection/train_synapse_detection.py @@ -3,8 +3,8 @@ from detection_dataset import DetectionDataset -# sys.path.append() -sys.path.append("/home/pape/Work/my_projects/czii-protein-challenge") +# sys.path.append("/home/pape/Work/my_projects/czii-protein-challenge") +sys.path.append("/user/pape41/u12086/Work/my_projects/czii-protein-challenge") from utils.training import supervised_training # noqa @@ -74,6 +74,8 @@ def train(): test_label_paths=test_label_paths, # save_root="", dataset_class=DetectionDataset, + n_samples_train=800, + n_samples_val=80, ) From 6da20c524ca4502d2522fc64a6f67215ae3d26d2 Mon Sep 17 00:00:00 2001 From: Constantin Pape Date: Sun, 2 Feb 2025 20:46:35 +0100 Subject: [PATCH 3/9] Add script to check synapse prediction --- .../check_synapse_prediction.py | 25 +++++++++++++++++++ 1 file changed, 25 insertions(+) create mode 100644 scripts/synapse_marker_detection/check_synapse_prediction.py diff --git a/scripts/synapse_marker_detection/check_synapse_prediction.py b/scripts/synapse_marker_detection/check_synapse_prediction.py new file mode 100644 index 0000000..d31ddf1 --- /dev/null +++ b/scripts/synapse_marker_detection/check_synapse_prediction.py @@ -0,0 +1,25 @@ +import h5py +import zarr +from torch_em.util import load_model +from torch_em.util.prediction import predict_with_halo +from train_synapse_detection import get_paths + + +def run_prediction(val_image): + model = load_model("./checkpoints/synapse_detection_v1") + block_shape = (32, 384, 384) + halo = (8, 64, 64) + pred = predict_with_halo(val_image, model, [0], block_shape, halo) + return pred.squeeze() + + +def main(): + val_paths, _ = get_paths("val") + val_image = zarr.open(val_paths[0])["raw"][:] + pred = run_prediction(val_image) + with h5py.File("pred.h5", "a") as f: + f.create_dataset("pred", data=pred, compression="gzip") + + +if __name__ == "__main__": + main() From 59f450ac0fd6bce4e1eeffb7fb1f9450c2409d41 Mon Sep 17 00:00:00 2001 From: Constantin Pape Date: Mon, 10 Feb 2025 20:36:27 +0100 Subject: [PATCH 4/9] Update syn marker visualization script --- .../check_synapse_prediction.py | 17 ++++++++++++++--- .../train_synapse_detection.py | 4 ++-- 2 files changed, 16 insertions(+), 5 deletions(-) diff --git a/scripts/synapse_marker_detection/check_synapse_prediction.py b/scripts/synapse_marker_detection/check_synapse_prediction.py index d31ddf1..970c320 100644 --- a/scripts/synapse_marker_detection/check_synapse_prediction.py +++ b/scripts/synapse_marker_detection/check_synapse_prediction.py @@ -1,5 +1,7 @@ import h5py +import napari import zarr + from torch_em.util import load_model from torch_em.util.prediction import predict_with_halo from train_synapse_detection import get_paths @@ -16,9 +18,18 @@ def run_prediction(val_image): def main(): val_paths, _ = get_paths("val") val_image = zarr.open(val_paths[0])["raw"][:] - pred = run_prediction(val_image) - with h5py.File("pred.h5", "a") as f: - f.create_dataset("pred", data=pred, compression="gzip") + + # pred = run_prediction(val_image) + # with h5py.File("pred.h5", "a") as f: + # f.create_dataset("pred", data=pred, compression="gzip") + + with h5py.File("pred.h5", "r") as f: + pred = f["pred"][:] + + v = napari.Viewer() + v.add_image(val_image) + v.add_image(pred) + napari.run() if __name__ == "__main__": diff --git a/scripts/synapse_marker_detection/train_synapse_detection.py b/scripts/synapse_marker_detection/train_synapse_detection.py index ebd14b0..7a42450 100644 --- a/scripts/synapse_marker_detection/train_synapse_detection.py +++ b/scripts/synapse_marker_detection/train_synapse_detection.py @@ -3,8 +3,8 @@ from detection_dataset import DetectionDataset -# sys.path.append("/home/pape/Work/my_projects/czii-protein-challenge") -sys.path.append("/user/pape41/u12086/Work/my_projects/czii-protein-challenge") +sys.path.append("/home/pape/Work/my_projects/czii-protein-challenge") +# sys.path.append("/user/pape41/u12086/Work/my_projects/czii-protein-challenge") from utils.training import supervised_training # noqa From 6a1236d99ee8868e79541f391d64931712b0e2b4 Mon Sep 17 00:00:00 2001 From: Constantin Pape Date: Wed, 12 Feb 2025 22:17:46 +0100 Subject: [PATCH 5/9] Updates to synapse training --- .../detection_dataset.py | 92 ++++++++++++------- .../train_synapse_detection.py | 14 +-- 2 files changed, 68 insertions(+), 38 deletions(-) diff --git a/scripts/synapse_marker_detection/detection_dataset.py b/scripts/synapse_marker_detection/detection_dataset.py index 14b07b6..b194bb1 100644 --- a/scripts/synapse_marker_detection/detection_dataset.py +++ b/scripts/synapse_marker_detection/detection_dataset.py @@ -9,29 +9,61 @@ # Process labels stored in json napari style. # I don't actually think that we need the epsilon here, but will leave it for now. -def process_labels(label_path, shape, sigma, eps): - labels = np.zeros(shape, dtype="float32") +def process_labels(label_path, shape, sigma, eps, bb=None): points = pd.read_csv(label_path) + + if bb: + (z_min, z_max), (y_min, y_max), (x_min, x_max) = [(s.start, s.stop) for s in bb] + restricted_shape = (z_max - z_min, y_max - y_min, x_max - x_min) + labels = np.zeros(restricted_shape, dtype="float32") + shape = restricted_shape + else: + labels = np.zeros(shape, dtype="float32") + assert len(points.columns) == len(shape) + z_coords, y_coords, x_coords = points["axis-0"], points["axis-1"], points["axis-2"] + if bb is not None: + z_coords -= z_min + y_coords -= y_min + x_coords -= x_min + mask = np.logical_and.reduce([ + np.logical_and(z_coords >= 0, z_coords < (z_max - z_min)), + np.logical_and(y_coords >= 0, y_coords < (y_max - y_min)), + np.logical_and(x_coords >= 0, x_coords < (x_max - x_min)), + ]) + z_coords, y_coords, x_coords = z_coords[mask], y_coords[mask], x_coords[mask] + coords = tuple( - np.clip(np.round(points[ax].values).astype("int"), 0, shape[i] - 1) - for i, ax in enumerate(points.columns) + np.clip(np.round(coord).astype("int"), 0, coord_max - 1) for coord, coord_max in zip( + (z_coords, y_coords, x_coords), shape + ) ) + labels[coords] = 1 labels = gaussian(labels, sigma) # TODO better normalization? - labels /= labels.max() + labels /= (labels.max() + 1e-7) + labels *= 4 return labels class DetectionDataset(torch.utils.data.Dataset): max_sampling_attempts = 500 + @staticmethod + def compute_len(shape, patch_shape): + if patch_shape is None: + return 1 + else: + n_samples = int(np.prod([float(sh / csh) for sh, csh in zip(shape, patch_shape)])) + return n_samples + def __init__( self, - raw_image_paths, - label_paths, + raw_path, + label_path, patch_shape, + raw_key, raw_transform=None, label_transform=None, transform=None, @@ -43,10 +75,9 @@ def __init__( sigma=None, **kwargs, ): - self.raw_images = raw_image_paths - # TODO make this a parameter - self.raw_key = "raw" - self.label_images = label_paths + self.raw_path = raw_path + self.label_path = label_path + self.raw_key = raw_key self._ndim = 3 assert len(patch_shape) == self._ndim @@ -63,12 +94,13 @@ def __init__( self.eps = eps self.sigma = sigma + with zarr.open(self.raw_path, "r") as f: + self.shape = f[self.raw_key].shape + if n_samples is None: - self._len = len(self.raw_images) - self.sample_random_index = False + self._len = self.compute_len(self.shape, self.patch_shape) if n_samples is None else n_samples else: self._len = n_samples - self.sample_random_index = True def __len__(self): return self._len @@ -89,21 +121,19 @@ def _sample_bounding_box(self, shape): return tuple(slice(start, start + psh) for start, psh in zip(bb_start, self.patch_shape)) def _get_sample(self, index): - if self.sample_random_index: - index = np.random.randint(0, len(self.raw_images)) - raw, label = self.raw_images[index], self.label_images[index] + raw, label_path = self.raw_path, self.label_path raw = zarr.open(raw)[self.raw_key] - # Note: this is quite inefficient, because we process the full crop rather than - # just the requested bounding box. - label = process_labels(label, raw.shape, self.sigma, self.eps) + shape = raw.shape + + bb = self._sample_bounding_box(shape) + label = process_labels(label_path, shape, self.sigma, self.eps, bb=bb) have_raw_channels = raw.ndim == 4 # 3D with channels have_label_channels = label.ndim == 4 if have_label_channels: raise NotImplementedError("Multi-channel labels are not supported.") - shape = raw.shape prefix_box = tuple() if have_raw_channels: if shape[-1] < 16: @@ -112,19 +142,19 @@ def _get_sample(self, index): shape = shape[1:] prefix_box = (slice(None), ) - bb = self._sample_bounding_box(shape) raw_patch = np.array(raw[prefix_box + bb]) - label_patch = np.array(label[bb]) + label_patch = np.array(label) if self.sampler is not None: - sample_id = 0 - while not self.sampler(raw_patch, label_patch): - bb = self._sample_bounding_box(shape) - raw_patch = np.array(raw[prefix_box + bb]) - label_patch = np.array(label[bb]) - sample_id += 1 - if sample_id > self.max_sampling_attempts: - raise RuntimeError(f"Could not sample a valid batch in {self.max_sampling_attempts} attempts") + assert False, "Sampler not implemented" + # sample_id = 0 + # while not self.sampler(raw_patch, label_patch): + # bb = self._sample_bounding_box(shape) + # raw_patch = np.array(raw[prefix_box + bb]) + # label_patch = np.array(label[bb]) + # sample_id += 1 + # if sample_id > self.max_sampling_attempts: + # raise RuntimeError(f"Could not sample a valid batch in {self.max_sampling_attempts} attempts") if have_raw_channels and len(prefix_box) == 0: raw_patch = raw_patch.transpose((3, 0, 1, 2)) # Channels, Depth, Height, Width diff --git a/scripts/synapse_marker_detection/train_synapse_detection.py b/scripts/synapse_marker_detection/train_synapse_detection.py index ebd14b0..92a1085 100644 --- a/scripts/synapse_marker_detection/train_synapse_detection.py +++ b/scripts/synapse_marker_detection/train_synapse_detection.py @@ -6,7 +6,7 @@ # sys.path.append("/home/pape/Work/my_projects/czii-protein-challenge") sys.path.append("/user/pape41/u12086/Work/my_projects/czii-protein-challenge") -from utils.training import supervised_training # noqa +from utils.training.training import supervised_training # noqa TRAIN_ROOT = "./training_data/images" LABEL_ROOT = "./training_data/labels" @@ -49,9 +49,8 @@ def train(): print(len(train_paths), "tomograms for training") print(len(val_paths), "tomograms for validation") - patch_shape = [32, 96, 96] - - batch_size = 8 + patch_shape = [40, 112, 112] + batch_size = 32 check = False supervised_training( @@ -60,10 +59,11 @@ def train(): train_label_paths=train_label_paths, val_paths=val_paths, val_label_paths=val_label_paths, + raw_key="raw", patch_shape=patch_shape, batch_size=batch_size, check=check, lr=1e-4, - n_iterations=int(2.5e4), + n_iterations=int(5e4), out_channels=1, augmentations=None, eps=1e-5, @@ -74,8 +74,8 @@ def train(): test_label_paths=test_label_paths, # save_root="", dataset_class=DetectionDataset, - n_samples_train=800, - n_samples_val=80, + n_samples_train=3200, + n_samples_val=160, ) From b788da7d22f213e72355491d036c67ee0243a47f Mon Sep 17 00:00:00 2001 From: Constantin Pape Date: Thu, 6 Mar 2025 15:53:26 +0100 Subject: [PATCH 6/9] Update prediction check --- .../check_synapse_prediction.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/scripts/synapse_marker_detection/check_synapse_prediction.py b/scripts/synapse_marker_detection/check_synapse_prediction.py index 970c320..9e12548 100644 --- a/scripts/synapse_marker_detection/check_synapse_prediction.py +++ b/scripts/synapse_marker_detection/check_synapse_prediction.py @@ -1,7 +1,10 @@ import h5py import napari +import pandas as pd import zarr +# from skimage.feature import blob_dog +from skimage.feature import peak_local_max from torch_em.util import load_model from torch_em.util.prediction import predict_with_halo from train_synapse_detection import get_paths @@ -16,8 +19,9 @@ def run_prediction(val_image): def main(): - val_paths, _ = get_paths("val") + val_paths, val_labels = get_paths("val") val_image = zarr.open(val_paths[0])["raw"][:] + val_labels = pd.read_csv(val_labels[0])[["axis-0", "axis-1", "axis-2"]] # pred = run_prediction(val_image) # with h5py.File("pred.h5", "a") as f: @@ -26,9 +30,17 @@ def main(): with h5py.File("pred.h5", "r") as f: pred = f["pred"][:] + print("Running local max ...") + # coords = blob_dog(pred) + coords = peak_local_max(pred, min_distance=2, threshold_abs=0.2) + # breakpoint() + print("... done") + v = napari.Viewer() v.add_image(val_image) v.add_image(pred) + v.add_points(coords) + v.add_points(val_labels, face_color="green") napari.run() From 81db81e4cbd13922c81c73d1dcf94093373afb09 Mon Sep 17 00:00:00 2001 From: Constantin Pape Date: Sun, 16 Mar 2025 18:32:42 +0100 Subject: [PATCH 7/9] Update synapse prediction script --- .../check_synapse_prediction.py | 61 ++++++++++++++++--- .../train_synapse_detection.py | 9 +-- 2 files changed, 56 insertions(+), 14 deletions(-) diff --git a/scripts/synapse_marker_detection/check_synapse_prediction.py b/scripts/synapse_marker_detection/check_synapse_prediction.py index 970c320..86f50d2 100644 --- a/scripts/synapse_marker_detection/check_synapse_prediction.py +++ b/scripts/synapse_marker_detection/check_synapse_prediction.py @@ -1,10 +1,17 @@ +import os +from glob import glob + import h5py +import imageio.v3 as imageio import napari import zarr from torch_em.util import load_model from torch_em.util.prediction import predict_with_halo from train_synapse_detection import get_paths +from tqdm import tqdm + +OUTPUT_ROOT = "./predictions" def run_prediction(val_image): @@ -15,22 +22,56 @@ def run_prediction(val_image): return pred.squeeze() -def main(): - val_paths, _ = get_paths("val") - val_image = zarr.open(val_paths[0])["raw"][:] - - # pred = run_prediction(val_image) - # with h5py.File("pred.h5", "a") as f: - # f.create_dataset("pred", data=pred, compression="gzip") +def require_prediction(image_data, output_path): + key = "prediction" + if os.path.exists(output_path): + with h5py.File(output_path, "r") as f: + pred = f[key][:] + else: + pred = run_prediction(image_data) + with h5py.File(output_path, "w") as f: + f.create_dataset(key, data=pred, compression="gzip") + return pred - with h5py.File("pred.h5", "r") as f: - pred = f["pred"][:] +def visualize_results(image_data, pred): v = napari.Viewer() - v.add_image(val_image) + v.add_image(image_data) v.add_image(pred) napari.run() +def check_val_image(): + val_paths, _ = get_paths("val") + val_path = val_paths[0] + val_image = zarr.open(val_path)["raw"][:] + + os.makedirs(os.path.join(OUTPUT_ROOT, "val"), exist_ok=True) + output_path = os.path.join(OUTPUT_ROOT, "val", os.path.basename(val_path).replace(".zarr", ".h5")) + pred = require_prediction(val_image, output_path) + + visualize_results(val_image, pred) + + +def check_new_images(): + input_root = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/training_data/synapses/test_crops" + inputs = glob(os.path.join(input_root, "*.tif")) + output_folder = os.path.join(OUTPUT_ROOT, "new_crops") + os.makedirs(output_folder, exist_ok=True) + for path in tqdm(inputs): + name = os.path.basename(path) + if name == "M_AMD_58L_avgblendfused_RibB.tif": + continue + image_data = imageio.imread(path) + output_path = os.path.join(output_folder, name.replace(".tif", ".h5")) + require_prediction(image_data, output_path) + + +# TODO update to support post-processing and showing annotations for the val data +def main(): + # check_val_image() + check_new_images() + + if __name__ == "__main__": main() diff --git a/scripts/synapse_marker_detection/train_synapse_detection.py b/scripts/synapse_marker_detection/train_synapse_detection.py index bfac242..ea6ae78 100644 --- a/scripts/synapse_marker_detection/train_synapse_detection.py +++ b/scripts/synapse_marker_detection/train_synapse_detection.py @@ -3,13 +3,14 @@ from detection_dataset import DetectionDataset -sys.path.append("/home/pape/Work/my_projects/czii-protein-challenge") -# sys.path.append("/user/pape41/u12086/Work/my_projects/czii-protein-challenge") +# sys.path.append("/home/pape/Work/my_projects/czii-protein-challenge") +sys.path.append("/user/pape41/u12086/Work/my_projects/czii-protein-challenge") from utils.training.training import supervised_training # noqa -TRAIN_ROOT = "./training_data/images" -LABEL_ROOT = "./training_data/labels" +ROOT = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/training_data/synapses/training_data/v1" # noqa +TRAIN_ROOT = os.path.join(ROOT, "images") +LABEL_ROOT = os.path.join(ROOT, "labels") def get_paths(split): From cd0cc0edb9ff2320ea5b747957a512447b948b89 Mon Sep 17 00:00:00 2001 From: Constantin Pape Date: Fri, 4 Apr 2025 15:05:30 +0200 Subject: [PATCH 8/9] Update synapse detection scripts --- .../check_synapse_prediction.py | 41 ++++++++++++++----- .../train_synapse_detection.py | 2 +- 2 files changed, 32 insertions(+), 11 deletions(-) diff --git a/scripts/synapse_marker_detection/check_synapse_prediction.py b/scripts/synapse_marker_detection/check_synapse_prediction.py index e12fbaf..5cad371 100644 --- a/scripts/synapse_marker_detection/check_synapse_prediction.py +++ b/scripts/synapse_marker_detection/check_synapse_prediction.py @@ -1,9 +1,11 @@ import os from glob import glob +from pathlib import Path import h5py import imageio.v3 as imageio import napari +import numpy as np import pandas as pd import zarr @@ -14,7 +16,10 @@ from train_synapse_detection import get_paths from tqdm import tqdm +# INPUT_ROOT = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/training_data/synapses/test_crops" +INPUT_ROOT = "./data/test_crops" OUTPUT_ROOT = "./predictions" +DETECTION_OUT_ROOT = "./detections" def run_prediction(val_image): @@ -40,19 +45,22 @@ def require_prediction(image_data, output_path): def run_postprocessing(pred): # print("Running local max ...") # coords = blob_dog(pred) - coords = peak_local_max(pred, min_distance=2, threshold_abs=0.2) + coords = peak_local_max(pred, min_distance=2, threshold_abs=0.5) # print("... done") return coords -def visualize_results(image_data, pred, coords=None, val_coords=None): +def visualize_results(image_data, pred, coords=None, val_coords=None, title=None): v = napari.Viewer() v.add_image(image_data) + pred = pred.clip(0, pred.max()) v.add_image(pred) - if coords is None: - v.add_points(coords, name="predicted_synapses") - if val_coords is None: + if coords is not None: + v.add_points(coords, name="predicted_synapses", face_color="yellow") + if val_coords is not None: v.add_points(val_coords, face_color="green", name="synapse_annotations") + if title is not None: + v.title = title napari.run() @@ -68,9 +76,8 @@ def check_val_image(): visualize_results(val_image, pred) -def check_new_images(): - input_root = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/training_data/synapses/test_crops" - inputs = glob(os.path.join(input_root, "*.tif")) +def check_new_images(view=False, save_detection=False): + inputs = glob(os.path.join(INPUT_ROOT, "*.tif")) output_folder = os.path.join(OUTPUT_ROOT, "new_crops") os.makedirs(output_folder, exist_ok=True) for path in tqdm(inputs): @@ -80,13 +87,27 @@ def check_new_images(): continue image_data = imageio.imread(path) output_path = os.path.join(output_folder, name.replace(".tif", ".h5")) - require_prediction(image_data, output_path) + # if not os.path.exists(output_path): + # continue + pred = require_prediction(image_data, output_path) + if view or save_detection: + coords = run_postprocessing(pred) + if view: + print("Number of synapses:", len(coords)) + visualize_results(image_data, pred, coords=coords, title=name) + if save_detection: + os.makedirs(DETECTION_OUT_ROOT, exist_ok=True) + coords = np.concatenate([np.arange(0, len(coords))[:, None], coords], axis=1) + coords = pd.DataFrame(coords, columns=["index", "axis-0", "axis-1", "axis-2"]) + fname = Path(path).stem + detection_save_path = os.path.join(DETECTION_OUT_ROOT, f"{fname}.csv") + coords.to_csv(detection_save_path, index=False) # TODO update to support post-processing and showing annotations for the val data def main(): # check_val_image() - check_new_images() + check_new_images(view=False, save_detection=True) if __name__ == "__main__": diff --git a/scripts/synapse_marker_detection/train_synapse_detection.py b/scripts/synapse_marker_detection/train_synapse_detection.py index ea6ae78..2a7d6af 100644 --- a/scripts/synapse_marker_detection/train_synapse_detection.py +++ b/scripts/synapse_marker_detection/train_synapse_detection.py @@ -3,7 +3,7 @@ from detection_dataset import DetectionDataset -# sys.path.append("/home/pape/Work/my_projects/czii-protein-challenge") +sys.path.append("/home/pape/Work/my_projects/czii-protein-challenge") sys.path.append("/user/pape41/u12086/Work/my_projects/czii-protein-challenge") from utils.training.training import supervised_training # noqa From 3230fa283317b70186a4ae9d0d30a0c2085a6448 Mon Sep 17 00:00:00 2001 From: Constantin Pape Date: Sun, 13 Apr 2025 12:13:02 +0200 Subject: [PATCH 9/9] Updates to support large-scale prediction with the synapse detection model --- .../segmentation/unet_prediction.py | 38 ++++++++++--- .../run_prediction.py | 56 +++++++++++++++++++ 2 files changed, 85 insertions(+), 9 deletions(-) create mode 100644 scripts/synapse_marker_detection/run_prediction.py diff --git a/flamingo_tools/segmentation/unet_prediction.py b/flamingo_tools/segmentation/unet_prediction.py index fde52a0..5cc2228 100644 --- a/flamingo_tools/segmentation/unet_prediction.py +++ b/flamingo_tools/segmentation/unet_prediction.py @@ -8,6 +8,7 @@ import numpy as np import nifty.tools as nt import vigra +import tifffile import torch import z5py @@ -37,7 +38,10 @@ def ndim(self): return self._volume.ndim - 1 -def prediction_impl(input_path, input_key, output_folder, model_path, scale, block_shape, halo): +def prediction_impl( + input_path, input_key, output_folder, model_path, scale, block_shape, halo, + output_channels=3, apply_postprocessing=True, +): with warnings.catch_warnings(): warnings.simplefilter("ignore") if os.path.isdir(model_path): @@ -46,10 +50,16 @@ def prediction_impl(input_path, input_key, output_folder, model_path, scale, blo model = torch.load(model_path) mask_path = os.path.join(output_folder, "mask.zarr") - image_mask = z5py.File(mask_path, "r")["mask"] + if os.path.exists(mask_path): + image_mask = z5py.File(mask_path, "r")["mask"] + else: + image_mask = None if input_key is None: - input_ = imageio.imread(input_path) + try: + input_ = tifffile.memmap(input_path) + except Exception: + input_ = imageio.imread(input_path) else: input_ = open_file(input_path, "r")[input_key] @@ -93,17 +103,27 @@ def preprocess(raw): raw /= std return raw - # Smooth the distance prediction channel. - def postprocess(x): - x[1] = vigra.filters.gaussianSmoothing(x[1], sigma=2.0) - return x + if apply_postprocessing: + # Smooth the distance prediction channel. + def postprocess(x): + x[1] = vigra.filters.gaussianSmoothing(x[1], sigma=2.0) + return x + else: + postprocess = None if output_channels > 1 else lambda x: x.squeeze() + + if output_channels > 1: + output_shape = (output_channels,) + input_.shape + output_chunks = (1,) + block_shape + else: + output_shape = input_.shape + output_chunks = block_shape output_path = os.path.join(output_folder, "predictions.zarr") with open_file(output_path, "a") as f: output = f.require_dataset( "prediction", - shape=(3,) + input_.shape, - chunks=(1,) + block_shape, + shape=output_shape, + chunks=output_chunks, compression="gzip", dtype="float32", ) diff --git a/scripts/synapse_marker_detection/run_prediction.py b/scripts/synapse_marker_detection/run_prediction.py new file mode 100644 index 0000000..feb4ad7 --- /dev/null +++ b/scripts/synapse_marker_detection/run_prediction.py @@ -0,0 +1,56 @@ +import argparse +import os +import sys + +import pandas as pd +import numpy as np +import zarr + +from elf.parallel.local_maxima import find_local_maxima + +sys.path.append("../..") + + +def main(): + from flamingo_tools.segmentation.unet_prediction import prediction_impl + + parser = argparse.ArgumentParser() + parser.add_argument("-i", "--input", required=True) + parser.add_argument("-o", "--output_folder", required=True) + parser.add_argument("-m", "--model", required=True) + parser.add_argument("-k", "--input_key", default=None) + args = parser.parse_args() + + block_shape = (64, 256, 256) + halo = (16, 64, 64) + + # Skip existing prediction, which is saved in output_folder/predictions.zarr + skip_prediction = False + output_path = os.path.join(args.output_folder, "predictions.zarr") + prediction_key = "prediction" + if os.path.exists(output_path) and prediction_key in zarr.open(output_path, "r"): + skip_prediction = True + + if not skip_prediction: + prediction_impl( + args.input, args.input_key, args.output_folder, args.model, + scale=None, block_shape=block_shape, halo=halo, + apply_postprocessing=False, output_channels=1, + ) + + detection_path = os.path.join(args.output_folder, "synapse_detection.tsv") + if not os.path.exists(detection_path): + input_ = zarr.open(output_path, "r")[prediction_key] + detections = find_local_maxima( + input_, block_shape=block_shape, min_distance=2, threshold_abs=0.5, verbose=True, n_threads=16, + ) + # Save the result in mobie compatible format. + detections = np.concatenate( + [np.arange(1, len(detections) + 1)[:, None], detections[:, ::-1]], axis=1 + ) + detections = pd.DataFrame(detections, columns=["spot_id", "x", "y", "z"]) + detections.to_csv(detection_path, index=False, sep="\t") + + +if __name__ == "__main__": + main()