diff --git a/flamingo_tools/segmentation/unet_prediction.py b/flamingo_tools/segmentation/unet_prediction.py index bd1670a..2e0163e 100644 --- a/flamingo_tools/segmentation/unet_prediction.py +++ b/flamingo_tools/segmentation/unet_prediction.py @@ -60,6 +60,8 @@ def prediction_impl( scale, block_shape, halo, + output_channels=3, + apply_postprocessing=True, prediction_instances=1, slurm_task_id=0, mean=None, @@ -75,7 +77,10 @@ def prediction_impl( model = torch.load(model_path, weights_only=False) mask_path = os.path.join(output_folder, "mask.zarr") - image_mask = z5py.File(mask_path, "r")["mask"] + if os.path.exists(mask_path): + image_mask = z5py.File(mask_path, "r")["mask"] + else: + image_mask = None input_ = read_image_data(input_path, input_key) chunks = getattr(input_, "chunks", (64, 64, 64)) @@ -122,10 +127,20 @@ def preprocess(raw): raw /= std return raw - # Smooth the distance prediction channel. - def postprocess(x): - x[1] = vigra.filters.gaussianSmoothing(x[1], sigma=2.0) - return x + if apply_postprocessing: + # Smooth the distance prediction channel. + def postprocess(x): + x[1] = vigra.filters.gaussianSmoothing(x[1], sigma=2.0) + return x + else: + postprocess = None if output_channels > 1 else lambda x: x.squeeze() + + if output_channels > 1: + output_shape = (output_channels,) + input_.shape + output_chunks = (1,) + block_shape + else: + output_shape = input_.shape + output_chunks = block_shape shape = input_.shape ndim = len(shape) @@ -142,8 +157,8 @@ def postprocess(x): with open_file(output_path, "a") as f: output = f.require_dataset( "prediction", - shape=(3,) + input_.shape, - chunks=(1,) + block_shape, + shape=output_shape, + chunks=output_chunks, compression="gzip", dtype="float32", ) diff --git a/scripts/synapse_marker_detection/.gitignore b/scripts/synapse_marker_detection/.gitignore new file mode 100644 index 0000000..8fce603 --- /dev/null +++ b/scripts/synapse_marker_detection/.gitignore @@ -0,0 +1 @@ +data/ diff --git a/scripts/synapse_marker_detection/check_synapse_prediction.py b/scripts/synapse_marker_detection/check_synapse_prediction.py new file mode 100644 index 0000000..5cad371 --- /dev/null +++ b/scripts/synapse_marker_detection/check_synapse_prediction.py @@ -0,0 +1,114 @@ +import os +from glob import glob +from pathlib import Path + +import h5py +import imageio.v3 as imageio +import napari +import numpy as np +import pandas as pd +import zarr + +# from skimage.feature import blob_dog +from skimage.feature import peak_local_max +from torch_em.util import load_model +from torch_em.util.prediction import predict_with_halo +from train_synapse_detection import get_paths +from tqdm import tqdm + +# INPUT_ROOT = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/training_data/synapses/test_crops" +INPUT_ROOT = "./data/test_crops" +OUTPUT_ROOT = "./predictions" +DETECTION_OUT_ROOT = "./detections" + + +def run_prediction(val_image): + model = load_model("./checkpoints/synapse_detection_v1") + block_shape = (32, 384, 384) + halo = (8, 64, 64) + pred = predict_with_halo(val_image, model, [0], block_shape, halo) + return pred.squeeze() + + +def require_prediction(image_data, output_path): + key = "prediction" + if os.path.exists(output_path): + with h5py.File(output_path, "r") as f: + pred = f[key][:] + else: + pred = run_prediction(image_data) + with h5py.File(output_path, "w") as f: + f.create_dataset(key, data=pred, compression="gzip") + return pred + + +def run_postprocessing(pred): + # print("Running local max ...") + # coords = blob_dog(pred) 
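+    # peak_local_max returns the voxel coordinates of all local maxima that are
+    # at least min_distance voxels apart and whose value exceeds threshold_abs.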
+    coords = peak_local_max(pred, min_distance=2, threshold_abs=0.5)
+    # print("... done")
+    return coords
+
+
+def visualize_results(image_data, pred, coords=None, val_coords=None, title=None):
+    v = napari.Viewer()
+    v.add_image(image_data)
+    pred = pred.clip(0, pred.max())
+    v.add_image(pred)
+    if coords is not None:
+        v.add_points(coords, name="predicted_synapses", face_color="yellow")
+    if val_coords is not None:
+        v.add_points(val_coords, face_color="green", name="synapse_annotations")
+    if title is not None:
+        v.title = title
+    napari.run()
+
+
+def check_val_image():
+    val_paths, _ = get_paths("val")
+    val_path = val_paths[0]
+    val_image = zarr.open(val_path)["raw"][:]
+
+    os.makedirs(os.path.join(OUTPUT_ROOT, "val"), exist_ok=True)
+    output_path = os.path.join(OUTPUT_ROOT, "val", os.path.basename(val_path).replace(".zarr", ".h5"))
+    pred = require_prediction(val_image, output_path)
+
+    visualize_results(val_image, pred)
+
+
+def check_new_images(view=False, save_detection=False):
+    inputs = glob(os.path.join(INPUT_ROOT, "*.tif"))
+    output_folder = os.path.join(OUTPUT_ROOT, "new_crops")
+    os.makedirs(output_folder, exist_ok=True)
+    for path in tqdm(inputs):
+        print(path)
+        name = os.path.basename(path)
+        if name == "M_AMD_58L_avgblendfused_RibB.tif":
+            continue
+        image_data = imageio.imread(path)
+        output_path = os.path.join(output_folder, name.replace(".tif", ".h5"))
+        # if not os.path.exists(output_path):
+        #     continue
+        pred = require_prediction(image_data, output_path)
+        if view or save_detection:
+            coords = run_postprocessing(pred)
+        if view:
+            print("Number of synapses:", len(coords))
+            visualize_results(image_data, pred, coords=coords, title=name)
+        if save_detection:
+            os.makedirs(DETECTION_OUT_ROOT, exist_ok=True)
+            coords = np.concatenate([np.arange(0, len(coords))[:, None], coords], axis=1)
+            coords = pd.DataFrame(coords, columns=["index", "axis-0", "axis-1", "axis-2"])
+            fname = Path(path).stem
+            detection_save_path = os.path.join(DETECTION_OUT_ROOT, f"{fname}.csv")
+            coords.to_csv(detection_save_path, index=False)
+
+
+# TODO update to support post-processing and showing annotations for the val data
+def main():
+    # check_val_image()
+    check_new_images(view=False, save_detection=True)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/scripts/synapse_marker_detection/detection_dataset.py b/scripts/synapse_marker_detection/detection_dataset.py
new file mode 100644
index 0000000..b194bb1
--- /dev/null
+++ b/scripts/synapse_marker_detection/detection_dataset.py
@@ -0,0 +1,196 @@
+import numpy as np
+import pandas as pd
+import torch
+import zarr
+
+from skimage.filters import gaussian
+from torch_em.util import ensure_tensor_with_channels
+
+
+# Process point labels stored in the napari points-layer CSV format.
+# The epsilon is probably not needed here, but we keep it for now.
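+# Each annotated point is written as a single-voxel spike, blurred with a
+# Gaussian of width sigma, and the heatmap is rescaled to a fixed peak height
+# of 4, giving the network a regression target with a stable value range.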
+def process_labels(label_path, shape, sigma, eps, bb=None):
+    points = pd.read_csv(label_path)
+
+    if bb is not None:
+        (z_min, z_max), (y_min, y_max), (x_min, x_max) = [(s.start, s.stop) for s in bb]
+        restricted_shape = (z_max - z_min, y_max - y_min, x_max - x_min)
+        labels = np.zeros(restricted_shape, dtype="float32")
+        shape = restricted_shape
+    else:
+        labels = np.zeros(shape, dtype="float32")
+
+    assert len(points.columns) == len(shape)
+    z_coords, y_coords, x_coords = points["axis-0"], points["axis-1"], points["axis-2"]
+    if bb is not None:
+        # Shift the points into the bounding box and discard points outside of it.
+        z_coords -= z_min
+        y_coords -= y_min
+        x_coords -= x_min
+        mask = np.logical_and.reduce([
+            np.logical_and(z_coords >= 0, z_coords < (z_max - z_min)),
+            np.logical_and(y_coords >= 0, y_coords < (y_max - y_min)),
+            np.logical_and(x_coords >= 0, x_coords < (x_max - x_min)),
+        ])
+        z_coords, y_coords, x_coords = z_coords[mask], y_coords[mask], x_coords[mask]
+
+    coords = tuple(
+        np.clip(np.round(coord).astype("int"), 0, coord_max - 1) for coord, coord_max in zip(
+            (z_coords, y_coords, x_coords), shape
+        )
+    )
+
+    labels[coords] = 1
+    labels = gaussian(labels, sigma)
+    # TODO better normalization?
+    labels /= (labels.max() + eps)
+    labels *= 4
+    return labels
+
+
+class DetectionDataset(torch.utils.data.Dataset):
+    max_sampling_attempts = 500
+
+    @staticmethod
+    def compute_len(shape, patch_shape):
+        if patch_shape is None:
+            return 1
+        return int(np.prod([float(sh / csh) for sh, csh in zip(shape, patch_shape)]))
+
+    def __init__(
+        self,
+        raw_path,
+        label_path,
+        patch_shape,
+        raw_key,
+        raw_transform=None,
+        label_transform=None,
+        transform=None,
+        dtype=torch.float32,
+        label_dtype=torch.float32,
+        n_samples=None,
+        sampler=None,
+        eps=1e-8,
+        sigma=None,
+        **kwargs,
+    ):
+        self.raw_path = raw_path
+        self.label_path = label_path
+        self.raw_key = raw_key
+        self._ndim = 3
+
+        assert len(patch_shape) == self._ndim
+        self.patch_shape = patch_shape
+
+        self.raw_transform = raw_transform
+        self.label_transform = label_transform
+        self.transform = transform
+        self.sampler = sampler
+
+        self.dtype = dtype
+        self.label_dtype = label_dtype
+
+        self.eps = eps
+        self.sigma = sigma
+
+        # zarr groups are not context managers, so open without `with`.
+        f = zarr.open(self.raw_path, "r")
+        self.shape = f[self.raw_key].shape
+
+        self._len = self.compute_len(self.shape, self.patch_shape) if n_samples is None else n_samples
+
+    def __len__(self):
+        return self._len
+
+    @property
+    def ndim(self):
+        return self._ndim
+
+    def _sample_bounding_box(self, shape):
+        if any(sh < psh for sh, psh in zip(shape, self.patch_shape)):
+            raise NotImplementedError(
+                f"Image padding is not supported yet. Data shape {shape}, patch shape {self.patch_shape}"
+            )
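+        # Draw an independent random start along each axis; axes where the data
+        # shape equals the patch shape always start at 0.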
+        bb_start = [
+            np.random.randint(0, sh - psh) if sh - psh > 0 else 0
+            for sh, psh in zip(shape, self.patch_shape)
+        ]
+        return tuple(slice(start, start + psh) for start, psh in zip(bb_start, self.patch_shape))
+
+    def _get_sample(self, index):
+        raw, label_path = self.raw_path, self.label_path
+
+        raw = zarr.open(raw)[self.raw_key]
+        shape = raw.shape
+
+        bb = self._sample_bounding_box(shape)
+        label = process_labels(label_path, shape, self.sigma, self.eps, bb=bb)
+
+        have_raw_channels = raw.ndim == 4  # 3D with channels
+        have_label_channels = label.ndim == 4
+        if have_label_channels:
+            raise NotImplementedError("Multi-channel labels are not supported.")
+
+        prefix_box = tuple()
+        if have_raw_channels:
+            # Heuristic: a small last axis means channels-last, otherwise channels-first.
+            if shape[-1] < 16:
+                shape = shape[:-1]
+            else:
+                shape = shape[1:]
+            prefix_box = (slice(None), )
+
+        raw_patch = np.array(raw[prefix_box + bb])
+        label_patch = np.array(label)
+
+        if self.sampler is not None:
+            raise NotImplementedError("Samplers are not supported yet.")
+            # sample_id = 0
+            # while not self.sampler(raw_patch, label_patch):
+            #     bb = self._sample_bounding_box(shape)
+            #     raw_patch = np.array(raw[prefix_box + bb])
+            #     label_patch = np.array(label[bb])
+            #     sample_id += 1
+            #     if sample_id > self.max_sampling_attempts:
+            #         raise RuntimeError(f"Could not sample a valid batch in {self.max_sampling_attempts} attempts")
+
+        if have_raw_channels and len(prefix_box) == 0:
+            raw_patch = raw_patch.transpose((3, 0, 1, 2))  # Channels, Depth, Height, Width
+
+        return raw_patch, label_patch
+
+    def __getitem__(self, index):
+        raw, labels = self._get_sample(index)
+        # initial_label_dtype = labels.dtype
+
+        if self.raw_transform is not None:
+            raw = self.raw_transform(raw)
+
+        if self.label_transform is not None:
+            labels = self.label_transform(labels)
+
+        if self.transform is not None:
+            raw, labels = self.transform(raw, labels)
+
+        raw = ensure_tensor_with_channels(raw, ndim=self._ndim, dtype=self.dtype)
+        labels = ensure_tensor_with_channels(labels, ndim=self._ndim, dtype=self.label_dtype)
+        return raw, labels
+
+
+if __name__ == "__main__":
+    import napari
+
+    raw_path = "training_data/images/10.1L_mid_IHCribboncount_5_Z.zarr"
+    label_path = "training_data/labels/10.1L_mid_IHCribboncount_5_Z.csv"
+
+    f = zarr.open(raw_path, "r")
+    raw = f["raw"][:]
+
+    labels = process_labels(label_path, shape=raw.shape, sigma=1, eps=1e-7)
+
+    v = napari.Viewer()
+    v.add_image(raw)
+    v.add_image(labels)
+    napari.run()
diff --git a/scripts/synapse_marker_detection/extract_training_data.py b/scripts/synapse_marker_detection/extract_training_data.py
new file mode 100644
index 0000000..3017577
--- /dev/null
+++ b/scripts/synapse_marker_detection/extract_training_data.py
@@ -0,0 +1,79 @@
+import os
+from glob import glob
+from pathlib import Path
+
+import h5py
+import napari
+import numpy as np
+import pandas as pd
+import zarr
+
+
+def get_voxel_size(imaris_file):
+    with h5py.File(imaris_file, "r") as f:
+        info = f["/DataSetInfo/Image"]
+        ext = [[float(b"".join(info.attrs[f"ExtMin{i}"]).decode()),
+                float(b"".join(info.attrs[f"ExtMax{i}"]).decode())] for i in range(3)]
+        size = [int(b"".join(info.attrs[dim]).decode()) for dim in ["X", "Y", "Z"]]
+        vsize = np.array([(max_ - min_) / s for (min_, max_), s in zip(ext, size)])
+    return vsize
+
+
+def extract_training_data(imaris_file, output_folder):
+    with h5py.File(imaris_file, "r") as f:
+        data = f["/DataSet/ResolutionLevel 0/TimePoint 0/Channel 0/Data"][:]
+        points = f["/Scene/Content/Points0/CoordsXYZR"][:]
+    points = points[:, :-1]
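+    # CoordsXYZR stores (x, y, z, radius) per point: the radius column was
+    # dropped above, and reversing the remaining columns yields (z, y, x),
+    # matching the axis order of the numpy volume.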
+    points = points[:, ::-1]
+
+    # TODO crop the data to the original shape.
+    # Can we just crop the zero-padding?!
+    crop_box = np.where(data != 0)
+    crop_box = tuple(slice(0, int(cb.max() + 1)) for cb in crop_box)
+    data = data[crop_box]
+    print(data.shape)
+
+    # Scale the points from world coordinates to voxel coordinates.
+    voxel_size = get_voxel_size(imaris_file)
+    points /= voxel_size[None]
+
+    if output_folder is None:
+        v = napari.Viewer()
+        v.add_image(data)
+        v.add_points(points)
+        v.title = os.path.basename(imaris_file)
+        napari.run()
+    else:
+        image_folder = os.path.join(output_folder, "images")
+        os.makedirs(image_folder, exist_ok=True)
+
+        label_folder = os.path.join(output_folder, "labels")
+        os.makedirs(label_folder, exist_ok=True)
+
+        fname = Path(imaris_file).stem
+        image_file = os.path.join(image_folder, f"{fname}.zarr")
+        label_file = os.path.join(label_folder, f"{fname}.csv")
+
+        coords = pd.DataFrame(points, columns=["axis-0", "axis-1", "axis-2"])
+        coords.to_csv(label_file, index=False)
+
+        f = zarr.open(image_file, "a")
+        f.create_dataset("raw", data=data)
+
+
+# Files that look good for training:
+# - 4.1L_apex_IHCribboncount_Z.ims
+# - 4.1L_base_IHCribbons_Z.ims
+# - 4.1L_mid_IHCribboncount_Z.ims
+# - 4.2R_apex_IHCribboncount_Z.ims
+# - 4.2R_apex_IHCribboncount_Z.ims
+# - 6.2R_apex_IHCribboncount_Z.ims (very small crop)
+# - 6.2R_base_IHCribbons_Z.ims
+def main():
+    files = sorted(glob("./data/synapse_stains/*.ims"))
+    for ff in files:
+        extract_training_data(ff, output_folder="./training_data")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/scripts/synapse_marker_detection/run_prediction.py b/scripts/synapse_marker_detection/run_prediction.py
new file mode 100644
index 0000000..1195f31
--- /dev/null
+++ b/scripts/synapse_marker_detection/run_prediction.py
@@ -0,0 +1,53 @@
+import argparse
+import os
+
+import numpy as np
+import pandas as pd
+import zarr
+
+from elf.parallel.local_maxima import find_local_maxima
+from flamingo_tools.segmentation.unet_prediction import prediction_impl
+
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("-i", "--input", required=True)
+    parser.add_argument("-o", "--output_folder", required=True)
+    parser.add_argument("-m", "--model", required=True)
+    parser.add_argument("-k", "--input_key", default=None)
+    args = parser.parse_args()
+
+    block_shape = (64, 256, 256)
+    halo = (16, 64, 64)
+
+    # Skip the prediction if it was already computed and saved in output_folder/predictions.zarr.
+    skip_prediction = False
+    output_path = os.path.join(args.output_folder, "predictions.zarr")
+    prediction_key = "prediction"
+    if os.path.exists(output_path) and prediction_key in zarr.open(output_path, "r"):
+        skip_prediction = True
+
+    if not skip_prediction:
+        prediction_impl(
+            args.input, args.input_key, args.output_folder, args.model,
+            scale=None, block_shape=block_shape, halo=halo,
+            apply_postprocessing=False, output_channels=1,
+        )
+
+    detection_path = os.path.join(args.output_folder, "synapse_detection.tsv")
+    if not os.path.exists(detection_path):
+        input_ = zarr.open(output_path, "r")[prediction_key]
+        detections = find_local_maxima(
+            input_, block_shape=block_shape, min_distance=2, threshold_abs=0.5, verbose=True, n_threads=16,
+        )
+        # Save the result in a MoBIE-compatible format.
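+        # The spot table uses a 1-based spot_id and world (x, y, z) columns,
+        # so prepend a running id and flip the (z, y, x) maxima coordinates.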
+        detections = np.concatenate(
+            [np.arange(1, len(detections) + 1)[:, None], detections[:, ::-1]], axis=1
+        )
+        detections = pd.DataFrame(detections, columns=["spot_id", "x", "y", "z"])
+        detections.to_csv(detection_path, index=False, sep="\t")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/scripts/synapse_marker_detection/train_synapse_detection.py b/scripts/synapse_marker_detection/train_synapse_detection.py
new file mode 100644
index 0000000..2a7d6af
--- /dev/null
+++ b/scripts/synapse_marker_detection/train_synapse_detection.py
@@ -0,0 +1,88 @@
+import os
+import sys
+
+from detection_dataset import DetectionDataset
+
+sys.path.append("/home/pape/Work/my_projects/czii-protein-challenge")
+sys.path.append("/user/pape41/u12086/Work/my_projects/czii-protein-challenge")
+
+from utils.training.training import supervised_training  # noqa
+
+ROOT = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/training_data/synapses/training_data/v1"  # noqa
+TRAIN_ROOT = os.path.join(ROOT, "images")
+LABEL_ROOT = os.path.join(ROOT, "labels")
+
+
+def get_paths(split):
+    file_names = [
+        "4.1L_apex_IHCribboncount_Z",
+        "4.1L_base_IHCribbons_Z",
+        "4.1L_mid_IHCribboncount_Z",
+        "4.2R_apex_IHCribboncount_Z",
+        "4.2R_apex_IHCribboncount_Z",
+        "6.2R_apex_IHCribboncount_Z",
+        "6.2R_base_IHCribbons_Z",
+    ]
+    image_paths = [os.path.join(TRAIN_ROOT, f"{fname}.zarr") for fname in file_names]
+    label_paths = [os.path.join(LABEL_ROOT, f"{fname}.csv") for fname in file_names]
+
+    # Use the last volume for validation and the rest for training.
+    if split == "train":
+        image_paths = image_paths[:-1]
+        label_paths = label_paths[:-1]
+    else:
+        image_paths = image_paths[-1:]
+        label_paths = label_paths[-1:]
+
+    return image_paths, label_paths
+
+
+# TODO maybe add a sampler for the label data
+def train():
+    model_name = "synapse_detection_v1"
+
+    train_paths, train_label_paths = get_paths("train")
+    val_paths, val_label_paths = get_paths("val")
+    # We need to pass paths for the test loader, although it is never used.
+    test_paths, test_label_paths = val_paths, val_label_paths
+
+    print("Start training with:")
+    print(len(train_paths), "volumes for training")
+    print(len(val_paths), "volumes for validation")
+
+    patch_shape = [40, 112, 112]
+    batch_size = 32
+    check = False
+
+    supervised_training(
+        name=model_name,
+        train_paths=train_paths,
+        train_label_paths=train_label_paths,
+        val_paths=val_paths,
+        val_label_paths=val_label_paths,
+        raw_key="raw",
+        patch_shape=patch_shape, batch_size=batch_size,
+        check=check,
+        lr=1e-4,
+        n_iterations=int(5e4),
+        out_channels=1,
+        augmentations=None,
+        eps=1e-5,
+        sigma=1,
+        lower_bound=None,
+        upper_bound=None,
+        test_paths=test_paths,
+        test_label_paths=test_label_paths,
+        # save_root="",
+        dataset_class=DetectionDataset,
+        n_samples_train=3200,
+        n_samples_val=160,
+    )
+
+
+def main():
+    train()
+
+
+if __name__ == "__main__":
+    main()
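+
+
+# Hypothetical usage sketch (paths are placeholders, not part of this PR):
+# first train the model via `python train_synapse_detection.py`, then run
+# `python run_prediction.py -i <volume> -k <key> -o <output_folder> -m <model>`
+# to write predictions.zarr and the MoBIE-style synapse_detection.tsv.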