Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions flamingo_tools/training/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from .util import get_3d_model, get_supervised_loader
from .mean_teacher_training import mean_teacher_training
219 changes: 219 additions & 0 deletions flamingo_tools/training/mean_teacher_training.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,219 @@
import os
from typing import Optional, Tuple

import torch
import torch_em
import torch_em.self_training as self_training
from torchvision import transforms

from .util import get_supervised_loader, get_3d_model


def weak_augmentations(p: float = 0.75) -> callable:
"""The weak augmentations used in the unsupervised data loader.

Args:
p: The probability for applying one of the augmentations.

Returns:
The transformation function applying the augmentation.
"""
norm = torch_em.transform.raw.standardize
aug = transforms.Compose([
norm,
transforms.RandomApply([torch_em.transform.raw.GaussianBlur()], p=p),
transforms.RandomApply([torch_em.transform.raw.AdditiveGaussianNoise(
scale=(0, 0.15), clip_kwargs=False)], p=p
),
])
return torch_em.transform.raw.get_raw_transform(normalizer=norm, augmentation1=aug)


def get_unsupervised_loader(
data_paths: Tuple[str],
raw_key: Optional[str],
patch_shape: Tuple[int, int, int],
batch_size: int,
n_samples: Optional[int],
) -> torch.utils.data.DataLoader:
"""Get a dataloader for unsupervised segmentation training.

Args:
data_paths: The filepaths to the hdf5 files containing the training data.
raw_key: The key that holds the raw data inside of the hdf5.
patch_shape: The patch shape used for a training example.
In order to run 2d training pass a patch shape with a singleton in the z-axis,
e.g. 'patch_shape = [1, 512, 512]'.
batch_size: The batch size for training.
n_samples: The number of samples per epoch. By default this will be estimated
based on the patch_shape and size of the volumes used for training.

Returns:
The PyTorch dataloader.
"""
raw_transform = torch_em.transform.get_raw_transform()
transform = torch_em.transform.get_augmentations(ndim=3)

if n_samples is None:
n_samples_per_ds = None
else:
n_samples_per_ds = int(n_samples / len(data_paths))

augmentations = (weak_augmentations(), weak_augmentations())
datasets = [
torch_em.data.RawDataset(path, raw_key, patch_shape, raw_transform, transform,
augmentations=augmentations, ndim=3, n_samples=n_samples_per_ds)
for path in data_paths
]
ds = torch.utils.data.ConcatDataset(datasets)

# num_workers = 4 * batch_size
num_workers = batch_size
loader = torch_em.segmentation.get_data_loader(ds, batch_size=batch_size, num_workers=num_workers, shuffle=True)
return loader


def mean_teacher_training(
name: str,
unsupervised_train_paths: Tuple[str],
unsupervised_val_paths: Tuple[str],
patch_shape: Tuple[int, int, int],
save_root: Optional[str] = None,
source_checkpoint: Optional[str] = None,
supervised_train_image_paths: Optional[Tuple[str]] = None,
supervised_val_image_paths: Optional[Tuple[str]] = None,
supervised_train_label_paths: Optional[Tuple[str]] = None,
supervised_val_label_paths: Optional[Tuple[str]] = None,
confidence_threshold: float = 0.9,
raw_key: Optional[str] = None,
raw_key_supervised: Optional[str] = None,
label_key: Optional[str] = None,
batch_size: int = 1,
lr: float = 1e-4,
n_iterations: int = int(1e4),
n_samples_train: Optional[int] = None,
n_samples_val: Optional[int] = None,
sampler: Optional[callable] = None,
) -> None:
"""This function implements network training with a mean teacher approach.

It can be used for semi-supervised learning, unsupervised domain adaptation and supervised domain adaptation.
These different training modes can be used as this:
- semi-supervised learning: pass 'unsupervised_train/val_paths' and 'supervised_train/val_paths'.
- unsupervised domain adaptation: pass 'unsupervised_train/val_paths' and 'source_checkpoint'.
- supervised domain adaptation: pass 'unsupervised_train/val_paths', 'supervised_train/val_paths', 'source_checkpoint'.

Args:
name: The name for the checkpoint to be trained.
unsupervsied_train_paths: Filepaths to the hdf5 files or similar file formats
for the training data in the target domain.
This training data is used for unsupervised learning, so it does not require labels.
unsupervised_val_paths: Filepaths to the hdf5 files or similar file formats
for the validation data in the target domain.
This validation data is used for unsupervised learning, so it does not require labels.
patch_shape: The patch shape used for a training example.
In order to run 2d training pass a patch shape with a singleton in the z-axis,
e.g. 'patch_shape = [1, 512, 512]'.
save_root: Folder where the checkpoint will be saved.
source_checkpoint: Checkpoint to the initial model trained on the source domain.
This is used to initialize the teacher model.
If the checkpoint is not given, then both student and teacher model are initialized
from scratch. In this case `supervised_train_paths` and `supervised_val_paths` have to
be given in order to provide training data from the source domain.
supervised_train_image_paths: Paths to the files for the supervised image data; training split.
This training data is optional. If given, it also requires labels.
supervised_val_image_paths: Ppaths to the files for the supervised image data; validation split.
This validation data is optional. If given, it also requires labels.
supervised_train_label_paths: Filepaths to the files for the supervised label masks; training split.
This training data is optional.
supervised_val_label_paths: Filepaths to the files for the supervised label masks; validation split.
This tvalidation data is optional.
confidence_threshold: The threshold for filtering data in the unsupervised loss.
The label filtering is done based on the uncertainty of network predictions, and only
the data with higher certainty than this threshold is used for training.
raw_key: The key that holds the raw data inside of the hdf5 or similar files;
for the unsupervised training data. Set to None for tifs.
raw_key_supervised: The key that holds the raw data inside of the hdf5 or similar files;
for the supervised training data. Set to None for tifs.
label_key: The key that holds the labels inside of the hdf5 files for supervised learning.
This is only required if `supervised_train_label_paths` and `supervised_val_label_paths` are given.
Set to None for tifs.
batch_size: The batch size for training.
lr: The initial learning rate.
n_iterations: The number of iterations to train for.
n_samples_train: The number of train samples per epoch. By default this will be estimated
based on the patch_shape and size of the volumes used for training.
n_samples_val: The number of val samples per epoch. By default this will be estimated
based on the patch_shape and size of the volumes used for validation.
""" # noqa
assert (supervised_train_image_paths is None) == (supervised_val_image_paths is None)

if source_checkpoint is None:
# Training from scratch only makes sense if we have supervised training data
# that's why we have the assertion here.
assert supervised_train_image_paths is not None
model = get_3d_model(out_channels=3)
reinit_teacher = True
else:
print("Mean teacehr training initialized from source model:", source_checkpoint)
if os.path.isdir(source_checkpoint):
model = torch_em.util.load_model(source_checkpoint)
else:
model = torch.load(source_checkpoint, weights_only=False)
reinit_teacher = False

optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.5, patience=5)

# self training functionality
pseudo_labeler = self_training.DefaultPseudoLabeler(confidence_threshold=confidence_threshold, mask_channel=0)
loss = self_training.DefaultSelfTrainingLoss()
loss_and_metric = self_training.DefaultSelfTrainingLossAndMetric()

unsupervised_train_loader = get_unsupervised_loader(
unsupervised_train_paths, raw_key, patch_shape, batch_size, n_samples=n_samples_train
)
unsupervised_val_loader = get_unsupervised_loader(
unsupervised_val_paths, raw_key, patch_shape, batch_size, n_samples=n_samples_val
)

if supervised_train_image_paths is not None:
supervised_train_loader = get_supervised_loader(
supervised_train_image_paths, supervised_train_label_paths,
patch_shape=patch_shape, batch_size=batch_size, n_samples=n_samples_train,
image_key=raw_key_supervised, label_key=label_key,
)
supervised_val_loader = get_supervised_loader(
supervised_val_image_paths, supervised_val_label_paths,
patch_shape=patch_shape, batch_size=batch_size, n_samples=n_samples_val,
image_key=raw_key_supervised, label_key=label_key,
)
else:
supervised_train_loader = None
supervised_val_loader = None

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
trainer = self_training.MeanTeacherTrainer(
name=name,
model=model,
optimizer=optimizer,
lr_scheduler=scheduler,
pseudo_labeler=pseudo_labeler,
unsupervised_loss=loss,
unsupervised_loss_and_metric=loss_and_metric,
supervised_train_loader=supervised_train_loader,
unsupervised_train_loader=unsupervised_train_loader,
supervised_val_loader=supervised_val_loader,
unsupervised_val_loader=unsupervised_val_loader,
supervised_loss=loss,
supervised_loss_and_metric=loss_and_metric,
logger=self_training.SelfTrainingTensorboardLogger,
mixed_precision=True,
log_image_interval=100,
compile_model=False,
device=device,
reinit_teacher=reinit_teacher,
save_root=save_root,
sampler=sampler,
)
trainer.fit(n_iterations)
57 changes: 57 additions & 0 deletions flamingo_tools/training/util.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
from typing import Optional, Sequence, Tuple

import torch.nn as nn
import torch_em
from torch_em.model import UNet3d
from torch.utils.data import DataLoader


def get_3d_model(out_channels: int = 3, final_activation: Optional[str] = "Sigmoid") -> nn.Module:
"""Get a 3D U-Net for segmentation or detection tasks.

Args:
out_channels: The number of output channels of the network.
final_activation: The activation applied to the last layer.
Set to 'None' for no activation; by default this applies a Sigmoid activation.

Returns:
The 3D U-Net.
"""
return UNet3d(in_channels=1, out_channels=out_channels, initial_features=32, final_activation=final_activation)


def get_supervised_loader(
image_paths: Sequence[str],
label_paths: Sequence[str],
patch_shape: Tuple[int, int, int],
batch_size: int,
image_key: Optional[str] = None,
label_key: Optional[str] = None,
n_samples: Optional[int] = None,
) -> DataLoader:
"""Get a data loader for a supervised segmentation task.

Args:
image_paths: The filepaths to the image data. These can be stored either in tif or in hdf5/zarr/n5.
image_paths: The filepaths to the label masks. These can be stored either in tif or in hdf5/zarr/n5.
patch_shape: The 3D patch shape for training.
batch_Size: The batch size for training.
image_key: Internal path for the image data. This is only required for hdf5/zarr/n5 data.
image_key: Internal path for the label masks. This is only required for hdf5/zarr/n5 data.
n_samples: The number of samples to use for training.

Returns:
The data loader.
"""
assert len(image_paths) == len(label_paths)
assert len(image_paths) > 0
label_transform = torch_em.transform.label.PerObjectDistanceTransform(
distances=True, boundary_distances=True, foreground=True,
)
sampler = torch_em.data.sampler.MinInstanceSampler(p_reject=0.8)
loader = torch_em.default_segmentation_loader(
raw_paths=image_paths, raw_key=image_key, label_paths=label_paths, label_key=label_key,
batch_size=batch_size, patch_shape=patch_shape, label_transform=label_transform,
n_samples=n_samples, num_workers=4, shuffle=True, sampler=sampler
)
return loader
86 changes: 86 additions & 0 deletions scripts/training/ihc_semi_supervised.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
import os
from glob import glob

import torch
from torch_em.util import load_model
from flamingo_tools.training import mean_teacher_training


def get_paths():
root = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/training_data/IHC/2025-05-IHC_semi-supervised"
annotated_folders = ["annotated_train", "empty"]
train_image = []
train_label = []
for folder in annotated_folders:
with os.scandir(os.path.join(root, folder)) as direc:
for entry in direc:
if "annotations" not in entry.name and entry.is_file():
basename = os.path.basename(entry.name)
name_no_extension = ".".join(basename.split(".")[:-1])
label_name = name_no_extension + "_annotations.tif"
train_image.extend(glob(os.path.join(root, folder, entry.name)))
train_label.extend(glob(os.path.join(root, folder, label_name)))

annotated_folders = ["annotated_val"]
val_image = []
val_label = []
for folder in annotated_folders:
with os.scandir(os.path.join(root, folder)) as direc:
for entry in direc:
if "annotations" not in entry.name and entry.is_file():
basename = os.path.basename(entry.name)
name_no_extension = ".".join(basename.split(".")[:-1])
label_name = name_no_extension + "_annotations.tif"
val_image.extend(glob(os.path.join(root, folder, entry.name)))
val_label.extend(glob(os.path.join(root, folder, label_name)))

domain_folders = ["domain_Aleyna", "domain_Lennart"]
paths_domain = []
for folder in domain_folders:
paths_domain.extend(glob(os.path.join(root, folder, "*.tif")))

return train_image, train_label, val_image, val_label, paths_domain[:-2], paths_domain[-2:]


def run_training(name):
patch_shape = (64, 128, 128)
batch_size = 8

super_train_img, super_train_label, super_val_img, super_val_label, unsuper_train, unsuper_val = get_paths()

mean_teacher_training(
name=name,
unsupervised_train_paths=unsuper_train,
unsupervised_val_paths=unsuper_val,
patch_shape=patch_shape,
supervised_train_image_paths=super_train_img,
supervised_val_image_paths=super_val_img,
supervised_train_label_paths=super_train_label,
supervised_val_label_paths=super_val_label,
batch_size=batch_size,
n_iterations=int(1e5),
n_samples_train=1000,
n_samples_val=80,
)


def export_model(name, export_path):
model = load_model(os.path.join("checkpoints", name), state_key="teacher")
torch.save(model, export_path)


def main():
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--export_path")
args = parser.parse_args()
name = "IHC_semi-supervised_2025-05-22"
if args.export_path is None:
run_training(name)
else:
export_model(name, args.export_path)


if __name__ == "__main__":
main()
Loading