1 change: 1 addition & 0 deletions scripts/README.md
@@ -13,6 +13,7 @@ conda install -c conda-forge mobie_utils
## Training

Contains the scripts for training a U-Net that predicts foreground probabilities and normalized object distances.
It also contains documentation for how to run training on new annotated data.
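
To make these targets concrete, the sketch below computes them for a single annotated crop with the same label transform that `scripts/training/train_distance_unet.py` uses. The file name and the comment on the channel layout are illustrative assumptions, not something this repository guarantees.

```
# Minimal sketch, assuming torch_em is installed and a hypothetical
# annotated crop "crop_01_annotations.tif" is available.
import imageio.v3 as imageio
from torch_em.transform.label import PerObjectDistanceTransform

# The same transform configuration that train_distance_unet.py uses.
label_transform = PerObjectDistanceTransform(
    distances=True, boundary_distances=True, foreground=True,
)

annotations = imageio.imread("crop_01_annotations.tif")  # hypothetical file name
targets = label_transform(annotations)
# Assumed layout: targets stacks the foreground channel and the
# normalized object / boundary distance channels the U-Net predicts.
```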


## Prediction
7 changes: 7 additions & 0 deletions scripts/data_transfer/README.md
@@ -33,3 +33,10 @@ Try to automate via https://github.com/jborean93/smbprotocol see `sync_smb.py` f

For transferring MoBIE results back.
...

# Data Transfer Huisken

See "Transfer via smbclient" above:
```
smbclient \\\\wfs-biologie-spezial.top.gwdg.de\\UBM1-all\$\\ -U GWDG\\pape41
```
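
Once connected, data can be pulled with standard smbclient commands. A typical interactive session might look like this (the folder name is a placeholder):
```
smb: \> recurse on
smb: \> prompt off
smb: \> cd some-data-folder
smb: \some-data-folder\> mget *
```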
14 changes: 14 additions & 0 deletions scripts/training/README.md
@@ -0,0 +1,14 @@
# 3D U-Net Training for Cochlea Data

This folder contains the scripts for training a 3D U-Net for cell segmentation in the cochlea data.
It contains two relevant scripts:
- `check_training_data.py`, which visualizes the training data and annotations in napari.
- `train_distance_unet.py`, which trains the 3D U-Net.

Both scripts accept the argument `-i /path/to/data` to specify the root folder with the training data. For example, run `python train_distance_unet.py -i /path/to/data` for training. The scripts consider all tif files in the sub-folders of the root folder for training.
They will load the **image data** according to the following rules:
- Files with the ending `_annotations.tif` or `_cp_masks.tif` will not be considered as image data.
- All other files are considered image data if a corresponding file ending in `_annotations.tif` exists. Files without annotations are excluded, and the scripts print the names of all excluded files (see the example layout below).
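
For example, in the following hypothetical layout only `crop_01.tif` would be used for training:
```
/path/to/data/
└── cochlea_1/
    ├── crop_01.tif               # image data, annotations found
    ├── crop_01_annotations.tif   # annotations for crop_01.tif
    ├── crop_02.tif               # excluded, no annotations found
    └── crop_02_cp_masks.tif      # not considered as image data
```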

The training script will save the trained model in `checkpoints/cochlea_distance_unet_<CURRENT_DATE>`, e.g. `checkpoints/cochlea_distance_unet_20250115`.
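
To use such a checkpoint for inference you can load it along the following lines. This is a sketch that assumes torch_em's default checkpoint layout (a `best.pt` file with a `model_state` entry); double-check against the output of your training run.
```
import torch
from torch_em.model import UNet3d

# The same architecture that train_distance_unet.py trains.
model = UNet3d(in_channels=1, out_channels=3, initial_features=32, final_activation="Sigmoid")

# Assumed checkpoint path and layout; adjust the date suffix to your run.
checkpoint = torch.load("checkpoints/cochlea_distance_unet_20250115/best.pt", map_location="cpu")
model.load_state_dict(checkpoint["model_state"])
model.eval()
```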
For further options for the scripts run `python check_training_data.py -h` / `python train_distance_unet.py -h`.
59 changes: 37 additions & 22 deletions scripts/training/check_training_data.py
@@ -1,42 +1,57 @@
import argparse
import os
from glob import glob

import imageio.v3 as imageio
import napari
import numpy as np

root = "/home/pape/Work/data/moser/lightsheet"
from train_distance_unet import get_image_and_label_paths
from tqdm import tqdm

# Root folder on my laptop.
# This is just for convenience, so that I don't have to pass
# the root argument during development.
ROOT_CP = "/home/pape/Work/data/moser/lightsheet"


def check_visually(images, labels):
    for im, label in tqdm(zip(images, labels), total=len(images)):
        vol = imageio.imread(im)
        seg = imageio.imread(label).astype("uint32")

        v = napari.Viewer()
        v.add_image(vol, name="pv-channel")
        v.add_labels(seg, name="annotations")
        folder, name = os.path.split(im)
        folder = os.path.basename(folder)
        v.title = f"{folder}/{name}"
        napari.run()


def check_labels(images, labels):
    # The images argument is unused here; it is kept so that both
    # check functions share the same signature.
    for label_path in labels:
        label_data = imageio.imread(label_path)
        n_labels = len(np.unique(label_data))
        print(label_path, n_labels)


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--root", "-i", help="The root folder with the annotated training crops.",
        default=ROOT_CP,
    )
    parser.add_argument("--check_labels", "-l", action="store_true")
    args = parser.parse_args()
    root = args.root

    images, labels = get_image_and_label_paths(root)

    check_visually(images, labels)
    if args.check_labels:
        check_labels(images, labels)


if __name__ == "__main__":
    main()
124 changes: 78 additions & 46 deletions scripts/training/train_distance_unet.py
@@ -1,54 +1,72 @@
import argparse
import os
from datetime import datetime
from glob import glob

import torch_em

from torch_em.model import UNet3d

ROOT_CLUSTER = "/scratch-grete/usr/nimcpape/data/moser/lightsheet/training"


def get_image_and_label_paths(root):
    exclude_names = ["annotations", "cp_masks"]
    all_image_paths = sorted(glob(os.path.join(root, "**/**.tif"), recursive=True))
    all_image_paths = [
        path for path in all_image_paths if not any(exclude in path for exclude in exclude_names)
    ]

    image_paths, label_paths = [], []
    label_extensions = ["_annotations.tif"]
    for path in all_image_paths:
        folder, fname = os.path.split(path)
        fname = os.path.splitext(fname)[0]
        label_path = None
        for ext in label_extensions:
            candidate_label_path = os.path.join(folder, f"{fname}{ext}")
            if os.path.exists(candidate_label_path):
                label_path = candidate_label_path
                break

        if label_path is None:
            print("Did not find annotations for", path)
            print("This image will not be used for training.")
        else:
            image_paths.append(path)
            label_paths.append(label_path)

    assert len(image_paths) == len(label_paths)
    return image_paths, label_paths


def select_paths(image_paths, label_paths, split, filter_empty):
    if filter_empty:
        image_paths = [imp for imp in image_paths if "empty" not in imp]
        label_paths = [imp for imp in label_paths if "empty" not in imp]
        assert len(image_paths) == len(label_paths)

    n_files = len(image_paths)
    train_fraction = 0.85
    n_train = int(train_fraction * n_files)

    if split == "train":
        image_paths = image_paths[:n_train]
        label_paths = label_paths[:n_train]
    elif split == "val":
        image_paths = image_paths[n_train:]
        label_paths = label_paths[n_train:]

    return image_paths, label_paths


def get_loader(root, split, patch_shape, batch_size, filter_empty):
    image_paths, label_paths = get_image_and_label_paths(root)
    # Select the paths for this split, so that the loader below
    # only sees the train or val data.
    image_paths, label_paths = select_paths(image_paths, label_paths, split, filter_empty)

    assert len(image_paths) == len(label_paths)
    assert len(image_paths) > 0

    label_transform = torch_em.transform.label.PerObjectDistanceTransform(
        distances=True, boundary_distances=True, foreground=True,
@@ -59,7 +77,7 @@ def get_loader(split, patch_shape, batch_size, filter_empty, train_on=["default"
    elif split == "val":
        n_samples = 20 * batch_size

    sampler = torch_em.data.sampler.MinInstanceSampler(p_reject=0.8)
    loader = torch_em.default_segmentation_loader(
        raw_paths=image_paths, raw_key=None, label_paths=label_paths, label_key=None,
        batch_size=batch_size, patch_shape=patch_shape, label_transform=label_transform,
Expand All @@ -69,26 +87,45 @@ def get_loader(split, patch_shape, batch_size, filter_empty, train_on=["default"
    return loader


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--root", "-i", help="The root folder with the annotated training crops.",
        default=ROOT_CLUSTER,
    )
    parser.add_argument(
        "--batch_size", "-b", help="The batch size for training. Set to 8 by default. "
        "You may need to choose a smaller batch size to train on your GPU.",
        default=8, type=int,
    )
    parser.add_argument(
        "--check_loaders", "-l", action="store_true",
        help="Visualize the data loader output instead of starting a training run."
    )
    parser.add_argument(
        "--filter_empty", "-f", action="store_true",
        help="Whether to exclude blocks with empty annotations from the training process."
    )
    parser.add_argument(
        "--name", help="Optional name for the model to be trained. If not given the current date is used."
    )
    args = parser.parse_args()
    root = args.root
    batch_size = args.batch_size
    check_loaders = args.check_loaders
    filter_empty = args.filter_empty
    run_name = datetime.now().strftime("%Y%m%d") if args.name is None else args.name

    # Parameters for training on A100.
    n_iterations = 1e5
    patch_shape = (64, 128, 128)

    # The U-Net.
    model = UNet3d(in_channels=1, out_channels=3, initial_features=32, final_activation="Sigmoid")

    # Create the training and validation loaders.
    train_loader = get_loader(root, "train", patch_shape, batch_size, filter_empty=filter_empty)
    val_loader = get_loader(root, "val", patch_shape, batch_size, filter_empty=filter_empty)

    if check_loaders:
        from torch_em.util.debug import check_loader
@@ -99,12 +136,7 @@ def main():
    loss = torch_em.loss.distance_based.DiceBasedDistanceLoss(mask_distances_in_bg=True)

    # Create the trainer.
    name = f"cochlea_distance_unet_{run_name}"
    trainer = torch_em.default_segmentation_trainer(
        name=name,
        model=model,
Expand All @@ -123,4 +155,4 @@ def main(check_loaders=False):


if __name__ == "__main__":
    main()