diff --git a/scripts/README.md b/scripts/README.md
index 6bfbbd0..0673875 100644
--- a/scripts/README.md
+++ b/scripts/README.md
@@ -13,6 +13,7 @@ conda install -c conda-forge mobie_utils
 
 ## Training
 
 Contains the scripts for training a U-Net that predicts foreground probabilties and normalized object distances.
+It also contains documentation on how to run training on new annotated data.
 
 ## Prediction
diff --git a/scripts/data_transfer/README.md b/scripts/data_transfer/README.md
index edc4a2c..766abca 100644
--- a/scripts/data_transfer/README.md
+++ b/scripts/data_transfer/README.md
@@ -33,3 +33,19 @@ Try to automate via https://github.com/jborean93/smbprotocol see `sync_smb.py` f
 
 For transfering back MoBIE results.
 ...
+
+# Data Transfer Huisken
+
+See "Transfer via smbclient" above:
+```
+smbclient \\\\wfs-biologie-spezial.top.gwdg.de\\UBM1-all\$\\ -U GWDG\\pape41
+```
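+
+Once connected, data can be pulled with the standard smbclient commands. This is a sketch of a typical session; the remote folder and local path are placeholders:
+```
+smb: \> recurse ON
+smb: \> prompt OFF
+smb: \> lcd /path/to/local/folder
+smb: \> cd remote-folder
+smb: \> mget *
+```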
+ROOT_CP = "/home/pape/Work/data/moser/lightsheet" -def check_visually(check_downsampled=False): - if check_downsampled: - images = sorted(glob(os.path.join(root, "images_s2", "*.tif"))) - masks = sorted(glob(os.path.join(root, "masks_s2", "*.tif"))) - else: - images = sorted(glob(os.path.join(root, "images", "*.tif"))) - masks = sorted(glob(os.path.join(root, "masks", "*.tif"))) - assert len(images) == len(masks) - for im, mask in zip(images, masks): - print(im) +def check_visually(images, labels): + for im, label in tqdm(zip(images, labels), total=len(images)): vol = imageio.imread(im) - seg = imageio.imread(mask).astype("uint32") + seg = imageio.imread(label).astype("uint32") v = napari.Viewer() - v.add_image(vol) - v.add_labels(seg) + v.add_image(vol, name="pv-channel") + v.add_labels(seg, name="annotations") + folder, name = os.path.split(im) + folder = os.path.basename(folder) + v.title = f"{folder}/{name}" napari.run() -def check_labels(): - masks = sorted(glob(os.path.join(root, "masks", "*.tif"))) - for mask_path in masks: - labels = imageio.imread(mask_path) +def check_labels(images, labels): + for label_path in labels: + labels = imageio.imread(label_path) n_labels = len(np.unique(labels)) - print(mask_path, n_labels) + print(label_path, n_labels) + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--root", "-i", help="The root folder with the annotated training crops.", + default=ROOT_CP, + ) + parser.add_argument("--check_labels", "-l", action="store_true") + args = parser.parse_args() + root = args.root + + images, labels = get_image_and_label_paths(root) + + check_visually(images, labels) + if args.check_labels: + check_labels(images, labels) if __name__ == "__main__": - check_visually(True) - # check_labels() + main() diff --git a/scripts/training/train_distance_unet.py b/scripts/training/train_distance_unet.py index 2b6ee61..fbddb2d 100644 --- a/scripts/training/train_distance_unet.py +++ b/scripts/training/train_distance_unet.py @@ -1,15 +1,45 @@ +import argparse import os +from datetime import datetime from glob import glob import torch_em - from torch_em.model import UNet3d -# DATA_ROOT = "/home/pape/Work/data/moser/lightsheet" -DATA_ROOT = "/scratch-grete/usr/nimcpape/data/moser/lightsheet" +ROOT_CLUSTER = "/scratch-grete/usr/nimcpape/data/moser/lightsheet/training" + + +def get_image_and_label_paths(root): + exclude_names = ["annotations", "cp_masks"] + all_image_paths = sorted(glob(os.path.join(root, "**/**.tif"), recursive=True)) + all_image_paths = [ + path for path in all_image_paths if not any(exclude in path for exclude in exclude_names) + ] + + image_paths, label_paths = [], [] + label_extensions = ["_annotations.tif"] + for path in all_image_paths: + folder, fname = os.path.split(path) + fname = os.path.splitext(fname)[0] + label_path = None + for ext in label_extensions: + candidate_label_path = os.path.join(folder, f"{fname}{ext}") + if os.path.exists(candidate_label_path): + label_path = candidate_label_path + break + + if label_path is None: + print("Did not find annotations for", path) + print("This image will not be used for training.") + else: + image_paths.append(path) + label_paths.append(label_path) + + assert len(image_paths) == len(label_paths) + return image_paths, label_paths -def get_paths(image_paths, label_paths, split, filter_empty): +def select_paths(image_paths, label_paths, split, filter_empty): if filter_empty: image_paths = [imp for imp in image_paths if "empty" not in imp] label_paths = [imp for imp in 
label_paths if "empty" not in imp] @@ -17,38 +47,26 @@ def get_paths(image_paths, label_paths, split, filter_empty): n_files = len(image_paths) - train_fraction = 0.8 - val_fraction = 0.1 + train_fraction = 0.85 n_train = int(train_fraction * n_files) - n_val = int(val_fraction * n_files) if split == "train": image_paths = image_paths[:n_train] label_paths = label_paths[:n_train] elif split == "val": - image_paths = image_paths[n_train:(n_train + n_val)] - label_paths = label_paths[n_train:(n_train + n_val)] + image_paths = image_paths[n_train:] + label_paths = label_paths[n_train:] return image_paths, label_paths -def get_loader(split, patch_shape, batch_size, filter_empty, train_on=["default"]): - image_paths, label_paths = [], [] - - if "default" in train_on: - all_image_paths = sorted(glob(os.path.join(DATA_ROOT, "images", "*.tif"))) - all_label_paths = sorted(glob(os.path.join(DATA_ROOT, "masks", "*.tif"))) - this_image_paths, this_label_paths = get_paths(all_image_paths, all_label_paths, split, filter_empty) - image_paths.extend(this_image_paths) - label_paths.extend(this_label_paths) +def get_loader(root, split, patch_shape, batch_size, filter_empty): + image_paths, label_paths = get_image_and_label_paths(root) + this_image_paths, this_label_paths = select_paths(image_paths, label_paths, split, filter_empty) - if "downsampled" in train_on: - all_image_paths = sorted(glob(os.path.join(DATA_ROOT, "images_s2", "*.tif"))) - all_label_paths = sorted(glob(os.path.join(DATA_ROOT, "masks_s2", "*.tif"))) - this_image_paths, this_label_paths = get_paths(all_image_paths, all_label_paths, split, filter_empty) - image_paths.extend(this_image_paths) - label_paths.extend(this_label_paths) + assert len(this_image_paths) == len(this_label_paths) + assert len(this_image_paths) > 0 label_transform = torch_em.transform.label.PerObjectDistanceTransform( distances=True, boundary_distances=True, foreground=True, @@ -59,7 +77,7 @@ def get_loader(split, patch_shape, batch_size, filter_empty, train_on=["default" elif split == "val": n_samples = 20 * batch_size - sampler = torch_em.data.sampler.MinInstanceSampler(p_reject=0.95) + sampler = torch_em.data.sampler.MinInstanceSampler(p_reject=0.8) loader = torch_em.default_segmentation_loader( raw_paths=image_paths, raw_key=None, label_paths=label_paths, label_key=None, batch_size=batch_size, patch_shape=patch_shape, label_transform=label_transform, @@ -69,26 +87,45 @@ def get_loader(split, patch_shape, batch_size, filter_empty, train_on=["default" return loader -def main(check_loaders=False): - # Parameters for training: +def main(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--root", "-i", help="The root folder with the annotated training crops.", + default=ROOT_CLUSTER, + ) + parser.add_argument( + "--batch_size", "-b", help="The batch size for training. Set to 8 by default." + "You may need to choose a smaller batch size to train on yoru GPU.", + default=8, type=int, + ) + parser.add_argument( + "--check_loaders", "-l", action="store_true", + help="Visualize the data loader output instead of starting a training run." + ) + parser.add_argument( + "--filter_empty", "-f", action="store_true", + help="Whether to exclude blocks with empty annotations from the training process." + ) + parser.add_argument( + "--name", help="Optional name for the model to be trained. If not given the current date is used." 
diff --git a/scripts/training/train_distance_unet.py b/scripts/training/train_distance_unet.py
index 2b6ee61..fbddb2d 100644
--- a/scripts/training/train_distance_unet.py
+++ b/scripts/training/train_distance_unet.py
@@ -1,15 +1,45 @@
+import argparse
 import os
+from datetime import datetime
 from glob import glob
 
 import torch_em
-
 from torch_em.model import UNet3d
 
-# DATA_ROOT = "/home/pape/Work/data/moser/lightsheet"
-DATA_ROOT = "/scratch-grete/usr/nimcpape/data/moser/lightsheet"
+ROOT_CLUSTER = "/scratch-grete/usr/nimcpape/data/moser/lightsheet/training"
+
+
+def get_image_and_label_paths(root):
+    exclude_names = ["annotations", "cp_masks"]
+    all_image_paths = sorted(glob(os.path.join(root, "**/**.tif"), recursive=True))
+    all_image_paths = [
+        path for path in all_image_paths if not any(exclude in path for exclude in exclude_names)
+    ]
+
+    image_paths, label_paths = [], []
+    label_extensions = ["_annotations.tif"]
+    for path in all_image_paths:
+        folder, fname = os.path.split(path)
+        fname = os.path.splitext(fname)[0]
+        label_path = None
+        for ext in label_extensions:
+            candidate_label_path = os.path.join(folder, f"{fname}{ext}")
+            if os.path.exists(candidate_label_path):
+                label_path = candidate_label_path
+                break
+
+        if label_path is None:
+            print("Did not find annotations for", path)
+            print("This image will not be used for training.")
+        else:
+            image_paths.append(path)
+            label_paths.append(label_path)
+
+    assert len(image_paths) == len(label_paths)
+    return image_paths, label_paths
 
 
-def get_paths(image_paths, label_paths, split, filter_empty):
+def select_paths(image_paths, label_paths, split, filter_empty):
     if filter_empty:
         image_paths = [imp for imp in image_paths if "empty" not in imp]
         label_paths = [imp for imp in label_paths if "empty" not in imp]
@@ -17,38 +47,26 @@ def get_paths(image_paths, label_paths, split, filter_empty):
     n_files = len(image_paths)
 
-    train_fraction = 0.8
-    val_fraction = 0.1
+    train_fraction = 0.85
 
     n_train = int(train_fraction * n_files)
-    n_val = int(val_fraction * n_files)
 
     if split == "train":
         image_paths = image_paths[:n_train]
         label_paths = label_paths[:n_train]
     elif split == "val":
-        image_paths = image_paths[n_train:(n_train + n_val)]
-        label_paths = label_paths[n_train:(n_train + n_val)]
+        image_paths = image_paths[n_train:]
+        label_paths = label_paths[n_train:]
 
     return image_paths, label_paths
 
 
-def get_loader(split, patch_shape, batch_size, filter_empty, train_on=["default"]):
-    image_paths, label_paths = [], []
-
-    if "default" in train_on:
-        all_image_paths = sorted(glob(os.path.join(DATA_ROOT, "images", "*.tif")))
-        all_label_paths = sorted(glob(os.path.join(DATA_ROOT, "masks", "*.tif")))
-        this_image_paths, this_label_paths = get_paths(all_image_paths, all_label_paths, split, filter_empty)
-        image_paths.extend(this_image_paths)
-        label_paths.extend(this_label_paths)
+def get_loader(root, split, patch_shape, batch_size, filter_empty):
+    all_image_paths, all_label_paths = get_image_and_label_paths(root)
+    image_paths, label_paths = select_paths(all_image_paths, all_label_paths, split, filter_empty)
 
-    if "downsampled" in train_on:
-        all_image_paths = sorted(glob(os.path.join(DATA_ROOT, "images_s2", "*.tif")))
-        all_label_paths = sorted(glob(os.path.join(DATA_ROOT, "masks_s2", "*.tif")))
-        this_image_paths, this_label_paths = get_paths(all_image_paths, all_label_paths, split, filter_empty)
-        image_paths.extend(this_image_paths)
-        label_paths.extend(this_label_paths)
+    assert len(image_paths) == len(label_paths)
+    assert len(image_paths) > 0
 
     label_transform = torch_em.transform.label.PerObjectDistanceTransform(
         distances=True, boundary_distances=True, foreground=True,
@@ -59,7 +77,7 @@ def get_loader(split, patch_shape, batch_size, filter_empty, train_on=["default"
     elif split == "val":
         n_samples = 20 * batch_size
 
-    sampler = torch_em.data.sampler.MinInstanceSampler(p_reject=0.95)
+    sampler = torch_em.data.sampler.MinInstanceSampler(p_reject=0.8)
     loader = torch_em.default_segmentation_loader(
         raw_paths=image_paths, raw_key=None, label_paths=label_paths, label_key=None,
         batch_size=batch_size, patch_shape=patch_shape, label_transform=label_transform,
@@ -69,26 +87,45 @@ def get_loader(split, patch_shape, batch_size, filter_empty, train_on=["default"
     return loader
 
 
-def main(check_loaders=False):
-    # Parameters for training:
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--root", "-i", help="The root folder with the annotated training crops.",
+        default=ROOT_CLUSTER,
+    )
+    parser.add_argument(
+        "--batch_size", "-b", help="The batch size for training. Set to 8 by default. "
+        "You may need to choose a smaller batch size to train on your GPU.",
+        default=8, type=int,
+    )
+    parser.add_argument(
+        "--check_loaders", "-l", action="store_true",
+        help="Visualize the data loader output instead of starting a training run."
+    )
+    parser.add_argument(
+        "--filter_empty", "-f", action="store_true",
+        help="Whether to exclude blocks with empty annotations from the training process."
+    )
+    parser.add_argument(
+        "--name", help="Optional name for the model to be trained. If not given, the current date is used."
+    )
+    args = parser.parse_args()
+    root = args.root
+    batch_size = args.batch_size
+    check_loaders = args.check_loaders
+    filter_empty = args.filter_empty
+    run_name = datetime.now().strftime("%Y%m%d") if args.name is None else args.name
+
+    # Parameters for training on A100.
     n_iterations = 1e5
-    batch_size = 8
-    filter_empty = False
-    train_on = ["downsampled"]
-    # train_on = ["downsampled", "default"]
-
-    patch_shape = (32, 128, 128) if "downsampled" in train_on else (64, 128, 128)
+    patch_shape = (64, 128, 128)
 
     # The U-Net.
     model = UNet3d(in_channels=1, out_channels=3, initial_features=32, final_activation="Sigmoid")
 
     # Create the training loader with train and val set.
-    train_loader = get_loader(
-        "train", patch_shape, batch_size, filter_empty=filter_empty, train_on=train_on
-    )
-    val_loader = get_loader(
-        "val", patch_shape, batch_size, filter_empty=filter_empty, train_on=train_on
-    )
+    train_loader = get_loader(root, "train", patch_shape, batch_size, filter_empty=filter_empty)
+    val_loader = get_loader(root, "val", patch_shape, batch_size, filter_empty=filter_empty)
 
     if check_loaders:
         from torch_em.util.debug import check_loader
@@ -99,12 +136,7 @@
     loss = torch_em.loss.distance_based.DiceBasedDistanceLoss(mask_distances_in_bg=True)
 
     # Create the trainer.
-    name = "cochlea_distance_unet"
-    if filter_empty:
-        name += "-filter-empty"
-    if train_on == ["downsampled"]:
-        name += "-train-downsampled"
-
+    name = f"cochlea_distance_unet_{run_name}"
     trainer = torch_em.default_segmentation_trainer(
         name=name,
         model=model,
@@ -123,4 +155,4 @@
 
 
 if __name__ == "__main__":
-    main(check_loaders=False)
+    main()