diff --git a/flamingo_tools/training/__init__.py b/flamingo_tools/training/__init__.py
new file mode 100644
index 0000000..e412448
--- /dev/null
+++ b/flamingo_tools/training/__init__.py
@@ -0,0 +1,2 @@
+from .util import get_3d_model, get_supervised_loader
+from .mean_teacher_training import mean_teacher_training
diff --git a/flamingo_tools/training/mean_teacher_training.py b/flamingo_tools/training/mean_teacher_training.py
new file mode 100644
index 0000000..7234541
--- /dev/null
+++ b/flamingo_tools/training/mean_teacher_training.py
@@ -0,0 +1,219 @@
+import os
+from typing import Optional, Tuple
+
+import torch
+import torch_em
+import torch_em.self_training as self_training
+from torchvision import transforms
+
+from .util import get_supervised_loader, get_3d_model
+
+
+def weak_augmentations(p: float = 0.75) -> callable:
+    """The weak augmentations used in the unsupervised data loader.
+
+    Args:
+        p: The probability for applying one of the augmentations.
+
+    Returns:
+        The transformation function applying the augmentation.
+    """
+    norm = torch_em.transform.raw.standardize
+    aug = transforms.Compose([
+        norm,
+        transforms.RandomApply([torch_em.transform.raw.GaussianBlur()], p=p),
+        transforms.RandomApply([torch_em.transform.raw.AdditiveGaussianNoise(
+            scale=(0, 0.15), clip_kwargs=False)], p=p
+        ),
+    ])
+    return torch_em.transform.raw.get_raw_transform(normalizer=norm, augmentation1=aug)
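+
+
+# Usage sketch (illustrative only): the returned transform is what the unsupervised
+# dataset applies to a raw patch to create one augmented view, roughly:
+#
+#   aug = weak_augmentations(p=0.75)
+#   augmented_view = aug(raw_patch)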
+ """ + raw_transform = torch_em.transform.get_raw_transform() + transform = torch_em.transform.get_augmentations(ndim=3) + + if n_samples is None: + n_samples_per_ds = None + else: + n_samples_per_ds = int(n_samples / len(data_paths)) + + augmentations = (weak_augmentations(), weak_augmentations()) + datasets = [ + torch_em.data.RawDataset(path, raw_key, patch_shape, raw_transform, transform, + augmentations=augmentations, ndim=3, n_samples=n_samples_per_ds) + for path in data_paths + ] + ds = torch.utils.data.ConcatDataset(datasets) + + # num_workers = 4 * batch_size + num_workers = batch_size + loader = torch_em.segmentation.get_data_loader(ds, batch_size=batch_size, num_workers=num_workers, shuffle=True) + return loader + + +def mean_teacher_training( + name: str, + unsupervised_train_paths: Tuple[str], + unsupervised_val_paths: Tuple[str], + patch_shape: Tuple[int, int, int], + save_root: Optional[str] = None, + source_checkpoint: Optional[str] = None, + supervised_train_image_paths: Optional[Tuple[str]] = None, + supervised_val_image_paths: Optional[Tuple[str]] = None, + supervised_train_label_paths: Optional[Tuple[str]] = None, + supervised_val_label_paths: Optional[Tuple[str]] = None, + confidence_threshold: float = 0.9, + raw_key: Optional[str] = None, + raw_key_supervised: Optional[str] = None, + label_key: Optional[str] = None, + batch_size: int = 1, + lr: float = 1e-4, + n_iterations: int = int(1e4), + n_samples_train: Optional[int] = None, + n_samples_val: Optional[int] = None, + sampler: Optional[callable] = None, +) -> None: + """This function implements network training with a mean teacher approach. + + It can be used for semi-supervised learning, unsupervised domain adaptation and supervised domain adaptation. + These different training modes can be used as this: + - semi-supervised learning: pass 'unsupervised_train/val_paths' and 'supervised_train/val_paths'. + - unsupervised domain adaptation: pass 'unsupervised_train/val_paths' and 'source_checkpoint'. + - supervised domain adaptation: pass 'unsupervised_train/val_paths', 'supervised_train/val_paths', 'source_checkpoint'. + + Args: + name: The name for the checkpoint to be trained. + unsupervsied_train_paths: Filepaths to the hdf5 files or similar file formats + for the training data in the target domain. + This training data is used for unsupervised learning, so it does not require labels. + unsupervised_val_paths: Filepaths to the hdf5 files or similar file formats + for the validation data in the target domain. + This validation data is used for unsupervised learning, so it does not require labels. + patch_shape: The patch shape used for a training example. + In order to run 2d training pass a patch shape with a singleton in the z-axis, + e.g. 'patch_shape = [1, 512, 512]'. + save_root: Folder where the checkpoint will be saved. + source_checkpoint: Checkpoint to the initial model trained on the source domain. + This is used to initialize the teacher model. + If the checkpoint is not given, then both student and teacher model are initialized + from scratch. In this case `supervised_train_paths` and `supervised_val_paths` have to + be given in order to provide training data from the source domain. + supervised_train_image_paths: Paths to the files for the supervised image data; training split. + This training data is optional. If given, it also requires labels. + supervised_val_image_paths: Ppaths to the files for the supervised image data; validation split. + This validation data is optional. 
+
+
+def mean_teacher_training(
+    name: str,
+    unsupervised_train_paths: Tuple[str],
+    unsupervised_val_paths: Tuple[str],
+    patch_shape: Tuple[int, int, int],
+    save_root: Optional[str] = None,
+    source_checkpoint: Optional[str] = None,
+    supervised_train_image_paths: Optional[Tuple[str]] = None,
+    supervised_val_image_paths: Optional[Tuple[str]] = None,
+    supervised_train_label_paths: Optional[Tuple[str]] = None,
+    supervised_val_label_paths: Optional[Tuple[str]] = None,
+    confidence_threshold: float = 0.9,
+    raw_key: Optional[str] = None,
+    raw_key_supervised: Optional[str] = None,
+    label_key: Optional[str] = None,
+    batch_size: int = 1,
+    lr: float = 1e-4,
+    n_iterations: int = int(1e4),
+    n_samples_train: Optional[int] = None,
+    n_samples_val: Optional[int] = None,
+    sampler: Optional[callable] = None,
+) -> None:
+    """This function implements network training with a mean teacher approach.
+
+    It can be used for semi-supervised learning, unsupervised domain adaptation and supervised domain adaptation.
+    These different training modes are selected as follows:
+    - semi-supervised learning: pass 'unsupervised_train/val_paths' and the supervised image and label paths.
+    - unsupervised domain adaptation: pass 'unsupervised_train/val_paths' and 'source_checkpoint'.
+    - supervised domain adaptation: pass 'unsupervised_train/val_paths', the supervised image and label paths,
+      and 'source_checkpoint'.
+
+    Args:
+        name: The name for the checkpoint to be trained.
+        unsupervised_train_paths: Filepaths to the hdf5 files or similar file formats
+            for the training data in the target domain.
+            This training data is used for unsupervised learning, so it does not require labels.
+        unsupervised_val_paths: Filepaths to the hdf5 files or similar file formats
+            for the validation data in the target domain.
+            This validation data is used for unsupervised learning, so it does not require labels.
+        patch_shape: The patch shape used for a training example.
+            In order to run 2d training pass a patch shape with a singleton in the z-axis,
+            e.g. 'patch_shape = [1, 512, 512]'.
+        save_root: Folder where the checkpoint will be saved.
+        source_checkpoint: Checkpoint to the initial model trained on the source domain.
+            This is used to initialize the teacher model.
+            If the checkpoint is not given, then both student and teacher model are initialized
+            from scratch. In this case `supervised_train_image_paths` and `supervised_val_image_paths`
+            have to be given in order to provide training data from the source domain.
+        supervised_train_image_paths: Filepaths to the files for the supervised image data; training split.
+            This training data is optional. If given, it also requires labels.
+        supervised_val_image_paths: Filepaths to the files for the supervised image data; validation split.
+            This validation data is optional. If given, it also requires labels.
+        supervised_train_label_paths: Filepaths to the files for the supervised label masks; training split.
+            This training data is optional.
+        supervised_val_label_paths: Filepaths to the files for the supervised label masks; validation split.
+            This validation data is optional.
+        confidence_threshold: The threshold for filtering data in the unsupervised loss.
+            The label filtering is done based on the uncertainty of network predictions, and only
+            the data with higher certainty than this threshold is used for training.
+        raw_key: The key that holds the raw data inside of the hdf5 or similar files;
+            for the unsupervised training data. Set to None for tifs.
+        raw_key_supervised: The key that holds the raw data inside of the hdf5 or similar files;
+            for the supervised training data. Set to None for tifs.
+        label_key: The key that holds the labels inside of the hdf5 files for supervised learning.
+            This is only required if `supervised_train_label_paths` and `supervised_val_label_paths` are given.
+            Set to None for tifs.
+        batch_size: The batch size for training.
+        lr: The initial learning rate.
+        n_iterations: The number of iterations to train for.
+        n_samples_train: The number of train samples per epoch. By default this will be estimated
+            based on the patch_shape and size of the volumes used for training.
+        n_samples_val: The number of val samples per epoch. By default this will be estimated
+            based on the patch_shape and size of the volumes used for validation.
+        sampler: An optional sampler for rejecting patches during training.
+    """  # noqa
+    assert (supervised_train_image_paths is None) == (supervised_val_image_paths is None)
+
+    if source_checkpoint is None:
+        # Training from scratch only makes sense if we have supervised training data,
+        # that's why we have the assertion here.
+        assert supervised_train_image_paths is not None
+        model = get_3d_model(out_channels=3)
+        reinit_teacher = True
+    else:
+        print("Mean teacher training initialized from source model:", source_checkpoint)
+        if os.path.isdir(source_checkpoint):
+            model = torch_em.util.load_model(source_checkpoint)
+        else:
+            model = torch.load(source_checkpoint, weights_only=False)
+        reinit_teacher = False
+
+    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
+    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.5, patience=5)
+
+    # The self-training functionality: pseudo labeling and loss computation.
+    pseudo_labeler = self_training.DefaultPseudoLabeler(confidence_threshold=confidence_threshold, mask_channel=0)
+    loss = self_training.DefaultSelfTrainingLoss()
+    loss_and_metric = self_training.DefaultSelfTrainingLossAndMetric()
+
+    unsupervised_train_loader = get_unsupervised_loader(
+        unsupervised_train_paths, raw_key, patch_shape, batch_size, n_samples=n_samples_train
+    )
+    unsupervised_val_loader = get_unsupervised_loader(
+        unsupervised_val_paths, raw_key, patch_shape, batch_size, n_samples=n_samples_val
+    )
+
+    if supervised_train_image_paths is not None:
+        supervised_train_loader = get_supervised_loader(
+            supervised_train_image_paths, supervised_train_label_paths,
+            patch_shape=patch_shape, batch_size=batch_size, n_samples=n_samples_train,
+            image_key=raw_key_supervised, label_key=label_key,
+        )
+        supervised_val_loader = get_supervised_loader(
+            supervised_val_image_paths, supervised_val_label_paths,
+            patch_shape=patch_shape, batch_size=batch_size, n_samples=n_samples_val,
+            image_key=raw_key_supervised, label_key=label_key,
+        )
+    else:
+        supervised_train_loader = None
+        supervised_val_loader = None
+
+    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+    trainer = self_training.MeanTeacherTrainer(
+        name=name,
+        model=model,
+        optimizer=optimizer,
+        lr_scheduler=scheduler,
+        pseudo_labeler=pseudo_labeler,
+        unsupervised_loss=loss,
+        unsupervised_loss_and_metric=loss_and_metric,
+        supervised_train_loader=supervised_train_loader,
+        unsupervised_train_loader=unsupervised_train_loader,
+        supervised_val_loader=supervised_val_loader,
+        unsupervised_val_loader=unsupervised_val_loader,
+        supervised_loss=loss,
+        supervised_loss_and_metric=loss_and_metric,
+        logger=self_training.SelfTrainingTensorboardLogger,
+        mixed_precision=True,
+        log_image_interval=100,
+        compile_model=False,
+        device=device,
+        reinit_teacher=reinit_teacher,
+        save_root=save_root,
+        sampler=sampler,
+    )
+    trainer.fit(n_iterations)
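For orientation, a minimal sketch of the unsupervised domain adaptation mode (all paths and names below are hypothetical); the other two modes are selected by additionally or instead passing the supervised image and label paths:

# Minimal sketch, assuming hypothetical target volumes and a source checkpoint.
from flamingo_tools.training import mean_teacher_training

mean_teacher_training(
    name="adapted-model",
    unsupervised_train_paths=["target_train.tif"],  # unlabeled target-domain data
    unsupervised_val_paths=["target_val.tif"],
    patch_shape=(64, 128, 128),
    source_checkpoint="/path/to/source_model",  # initializes student and teacher
)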
+ """ + assert len(image_paths) == len(label_paths) + assert len(image_paths) > 0 + label_transform = torch_em.transform.label.PerObjectDistanceTransform( + distances=True, boundary_distances=True, foreground=True, + ) + sampler = torch_em.data.sampler.MinInstanceSampler(p_reject=0.8) + loader = torch_em.default_segmentation_loader( + raw_paths=image_paths, raw_key=image_key, label_paths=label_paths, label_key=label_key, + batch_size=batch_size, patch_shape=patch_shape, label_transform=label_transform, + n_samples=n_samples, num_workers=4, shuffle=True, sampler=sampler + ) + return loader diff --git a/scripts/training/ihc_semi_supervised.py b/scripts/training/ihc_semi_supervised.py new file mode 100644 index 0000000..8c9f722 --- /dev/null +++ b/scripts/training/ihc_semi_supervised.py @@ -0,0 +1,86 @@ +import os +from glob import glob + +import torch +from torch_em.util import load_model +from flamingo_tools.training import mean_teacher_training + + +def get_paths(): + root = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/training_data/IHC/2025-05-IHC_semi-supervised" + annotated_folders = ["annotated_train", "empty"] + train_image = [] + train_label = [] + for folder in annotated_folders: + with os.scandir(os.path.join(root, folder)) as direc: + for entry in direc: + if "annotations" not in entry.name and entry.is_file(): + basename = os.path.basename(entry.name) + name_no_extension = ".".join(basename.split(".")[:-1]) + label_name = name_no_extension + "_annotations.tif" + train_image.extend(glob(os.path.join(root, folder, entry.name))) + train_label.extend(glob(os.path.join(root, folder, label_name))) + + annotated_folders = ["annotated_val"] + val_image = [] + val_label = [] + for folder in annotated_folders: + with os.scandir(os.path.join(root, folder)) as direc: + for entry in direc: + if "annotations" not in entry.name and entry.is_file(): + basename = os.path.basename(entry.name) + name_no_extension = ".".join(basename.split(".")[:-1]) + label_name = name_no_extension + "_annotations.tif" + val_image.extend(glob(os.path.join(root, folder, entry.name))) + val_label.extend(glob(os.path.join(root, folder, label_name))) + + domain_folders = ["domain_Aleyna", "domain_Lennart"] + paths_domain = [] + for folder in domain_folders: + paths_domain.extend(glob(os.path.join(root, folder, "*.tif"))) + + return train_image, train_label, val_image, val_label, paths_domain[:-2], paths_domain[-2:] + + +def run_training(name): + patch_shape = (64, 128, 128) + batch_size = 8 + + super_train_img, super_train_label, super_val_img, super_val_label, unsuper_train, unsuper_val = get_paths() + + mean_teacher_training( + name=name, + unsupervised_train_paths=unsuper_train, + unsupervised_val_paths=unsuper_val, + patch_shape=patch_shape, + supervised_train_image_paths=super_train_img, + supervised_val_image_paths=super_val_img, + supervised_train_label_paths=super_train_label, + supervised_val_label_paths=super_val_label, + batch_size=batch_size, + n_iterations=int(1e5), + n_samples_train=1000, + n_samples_val=80, + ) + + +def export_model(name, export_path): + model = load_model(os.path.join("checkpoints", name), state_key="teacher") + torch.save(model, export_path) + + +def main(): + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument("--export_path") + args = parser.parse_args() + name = "IHC_semi-supervised_2025-05-22" + if args.export_path is None: + run_training(name) + else: + export_model(name, args.export_path) + + +if __name__ == "__main__": + main() diff 
diff --git a/scripts/training/ihc_semi_supervised.py b/scripts/training/ihc_semi_supervised.py
new file mode 100644
index 0000000..8c9f722
--- /dev/null
+++ b/scripts/training/ihc_semi_supervised.py
@@ -0,0 +1,86 @@
+import os
+from glob import glob
+
+import torch
+from torch_em.util import load_model
+from flamingo_tools.training import mean_teacher_training
+
+
+def get_paths():
+    root = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/training_data/IHC/2025-05-IHC_semi-supervised"
+    annotated_folders = ["annotated_train", "empty"]
+    train_image = []
+    train_label = []
+    for folder in annotated_folders:
+        with os.scandir(os.path.join(root, folder)) as direc:
+            for entry in direc:
+                if "annotations" not in entry.name and entry.is_file():
+                    basename = os.path.basename(entry.name)
+                    name_no_extension = ".".join(basename.split(".")[:-1])
+                    label_name = name_no_extension + "_annotations.tif"
+                    train_image.extend(glob(os.path.join(root, folder, entry.name)))
+                    train_label.extend(glob(os.path.join(root, folder, label_name)))
+
+    annotated_folders = ["annotated_val"]
+    val_image = []
+    val_label = []
+    for folder in annotated_folders:
+        with os.scandir(os.path.join(root, folder)) as direc:
+            for entry in direc:
+                if "annotations" not in entry.name and entry.is_file():
+                    basename = os.path.basename(entry.name)
+                    name_no_extension = ".".join(basename.split(".")[:-1])
+                    label_name = name_no_extension + "_annotations.tif"
+                    val_image.extend(glob(os.path.join(root, folder, entry.name)))
+                    val_label.extend(glob(os.path.join(root, folder, label_name)))
+
+    domain_folders = ["domain_Aleyna", "domain_Lennart"]
+    paths_domain = []
+    for folder in domain_folders:
+        paths_domain.extend(glob(os.path.join(root, folder, "*.tif")))
+
+    # Use the last two unlabeled volumes for validation.
+    return train_image, train_label, val_image, val_label, paths_domain[:-2], paths_domain[-2:]
+
+
+def run_training(name):
+    patch_shape = (64, 128, 128)
+    batch_size = 8
+
+    super_train_img, super_train_label, super_val_img, super_val_label, unsuper_train, unsuper_val = get_paths()
+
+    mean_teacher_training(
+        name=name,
+        unsupervised_train_paths=unsuper_train,
+        unsupervised_val_paths=unsuper_val,
+        patch_shape=patch_shape,
+        supervised_train_image_paths=super_train_img,
+        supervised_val_image_paths=super_val_img,
+        supervised_train_label_paths=super_train_label,
+        supervised_val_label_paths=super_val_label,
+        batch_size=batch_size,
+        n_iterations=int(1e5),
+        n_samples_train=1000,
+        n_samples_val=80,
+    )
+
+
+def export_model(name, export_path):
+    model = load_model(os.path.join("checkpoints", name), state_key="teacher")
+    torch.save(model, export_path)
+
+
+def main():
+    import argparse
+
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--export_path")
+    args = parser.parse_args()
+    name = "IHC_semi-supervised_2025-05-22"
+    if args.export_path is None:
+        run_training(name)
+    else:
+        export_model(name, args.export_path)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/scripts/training/sgn_domain_adaptation.py b/scripts/training/sgn_domain_adaptation.py
new file mode 100644
index 0000000..79bd5ca
--- /dev/null
+++ b/scripts/training/sgn_domain_adaptation.py
@@ -0,0 +1,56 @@
+import os
+from glob import glob
+
+import torch
+from torch_em.util import load_model
+from flamingo_tools.training import mean_teacher_training
+
+
+def get_paths():
+    root = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/LS_sampleprepcomparison_crops"
+    folders = ["fHC", "iDISCO", "microwave-fHC", "microwave-iDISCO"]
+    paths = []
+    for folder in folders:
+        paths.extend(glob(os.path.join(root, folder, "*.tif")))
+    # Hold out the last crop for validation.
+    return paths[:-1], paths[-1:]
+
+
+def run_training(name):
+    patch_shape = (64, 128, 128)
+    batch_size = 8
+    source_checkpoint = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/trained_models/SGN/cochlea_distance_unet_SGN_March2025Model"  # noqa
+
+    train_paths, val_paths = get_paths()
+    mean_teacher_training(
+        name=name,
+        unsupervised_train_paths=train_paths,
+        unsupervised_val_paths=val_paths,
+        patch_shape=patch_shape,
+        source_checkpoint=source_checkpoint,
+        batch_size=batch_size,
+        n_iterations=int(2.5e4),
+        n_samples_train=1000,
+        n_samples_val=80,
+    )
+
+
+def export_model(name, export_path):
+    model = load_model(os.path.join("checkpoints", name), state_key="teacher")
+    torch.save(model, export_path)
+
+
+def main():
+    import argparse
+
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--export_path")
+    args = parser.parse_args()
+    name = "sgn-adapted-model"
+    if args.export_path is None:
+        run_training(name)
+    else:
+        export_model(name, args.export_path)
+
+
+if __name__ == "__main__":
+    main()
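Both scripts export the trained teacher via torch.save, so it can be reloaded directly for inference; a sketch with a hypothetical export path:

# Sketch: reload a model exported by export_model (the path is hypothetical).
import torch

model = torch.load("sgn_adapted_model.pt", weights_only=False)
model.eval()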
diff --git a/scripts/training/sgn_semi_supervised.py b/scripts/training/sgn_semi_supervised.py
new file mode 100644
index 0000000..3a76405
--- /dev/null
+++ b/scripts/training/sgn_semi_supervised.py
@@ -0,0 +1,100 @@
+import os
+from glob import glob
+
+import torch
+from torch_em.util import load_model
+from flamingo_tools.training import mean_teacher_training
+
+
+def get_paths():
+    root = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/training_data/SGN/2025-05_semi-supervised"
+    annotated_folders = ["annotated_2025-02", "annotated_2025-05", "empty_2025-02", "empty_2025-05"]
+    train_image = []
+    train_label = []
+    for folder in annotated_folders:
+        with os.scandir(os.path.join(root, folder)) as direc:
+            for entry in direc:
+                if "annotations" not in entry.name and entry.is_file():
+                    basename = os.path.basename(entry.name)
+                    name_no_extension = ".".join(basename.split(".")[:-1])
+                    label_name = name_no_extension + "_annotations.tif"
+                    train_image.extend(glob(os.path.join(root, folder, entry.name)))
+                    train_label.extend(glob(os.path.join(root, folder, label_name)))
+
+    annotated_folders = ["val_data"]
+    val_image = []
+    val_label = []
+    for folder in annotated_folders:
+        with os.scandir(os.path.join(root, folder)) as direc:
+            for entry in direc:
+                if "annotations" not in entry.name and entry.is_file():
+                    basename = os.path.basename(entry.name)
+                    name_no_extension = ".".join(basename.split(".")[:-1])
+                    label_name = name_no_extension + "_annotations.tif"
+                    val_image.extend(glob(os.path.join(root, folder, entry.name)))
+                    val_label.extend(glob(os.path.join(root, folder, label_name)))
+
+    domain_folders = ["domain"]
+    paths_domain = []
+    for folder in domain_folders:
+        paths_domain.extend(glob(os.path.join(root, folder, "*.tif")))
+
+    # Use the last unlabeled volume for validation.
+    return train_image, train_label, val_image, val_label, paths_domain[:-1], paths_domain[-1:]
+
+
+def run_training(name):
+    patch_shape = (64, 128, 128)
+    batch_size = 20
+
+    super_train_img, super_train_label, super_val_img, super_val_label, unsuper_train, unsuper_val = get_paths()
+
+    print("super_train images:", len(super_train_img))
+    print("super_train labels:", len(super_train_label))
+
+    print("super_val images:", len(super_val_img))
+    print("super_val labels:", len(super_val_label))
+
+    print("unsuper_train:", len(unsuper_train))
+    print("unsuper_val:", len(unsuper_val))
+
+    mean_teacher_training(
+        name=name,
+        unsupervised_train_paths=unsuper_train,
+        unsupervised_val_paths=unsuper_val,
+        patch_shape=patch_shape,
+        supervised_train_image_paths=super_train_img,
+        supervised_val_image_paths=super_val_img,
+        supervised_train_label_paths=super_train_label,
+        supervised_val_label_paths=super_val_label,
+        batch_size=batch_size,
+        n_iterations=int(1e5),
+        n_samples_train=1000,
+        n_samples_val=80,
+    )
+
+
+def export_model(name, export_path):
+    model = load_model(os.path.join("checkpoints", name), state_key="teacher")
+    torch.save(model, export_path)
+
+
+def main():
+    import argparse
+
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--export_path")
+    parser.add_argument("--model_name", default=None)
+    args = parser.parse_args()
+    if args.model_name is None:
+        name = "SGN_semi-supervised"
+    else:
+        name = args.model_name
+
+    if args.export_path is None:
+        run_training(name)
+    else:
+        export_model(name, args.export_path)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/scripts/training/train_distance_unet.py b/scripts/training/train_distance_unet.py
index 518123b..98daefa 100644
--- a/scripts/training/train_distance_unet.py
+++ b/scripts/training/train_distance_unet.py
@@ -4,7 +4,7 @@
 from glob import glob
 
 import torch_em
-from torch_em.model import UNet3d
+from flamingo_tools.training import get_supervised_loader, get_3d_model
 
 ROOT_CLUSTER = "/scratch-grete/usr/nimcpape/data/moser/lightsheet/training"
@@ -67,23 +67,12 @@ def get_loader(root, split, patch_shape, batch_size, filter_empty):
     assert len(this_image_paths) == len(this_label_paths)
     assert len(this_image_paths) > 0
 
-    label_transform = torch_em.transform.label.PerObjectDistanceTransform(
-        distances=True, boundary_distances=True, foreground=True,
-    )
-
     if split == "train":
         n_samples = 250 * batch_size
     elif split == "val":
-        n_samples = 20 * batch_size
-
-    sampler = torch_em.data.sampler.MinInstanceSampler(p_reject=0.8)
-    loader = torch_em.default_segmentation_loader(
-        raw_paths=image_paths, raw_key=None, label_paths=label_paths, label_key=None,
-        batch_size=batch_size, patch_shape=patch_shape, label_transform=label_transform,
-        n_samples=n_samples, num_workers=4, shuffle=True,
-        sampler=sampler
-    )
-    return loader
+        n_samples = 16 * batch_size
+
+    return get_supervised_loader(this_image_paths, this_label_paths, patch_shape, batch_size, n_samples=n_samples)
@@ -120,7 +109,7 @@ def main():
     patch_shape = (64, 128, 128)
 
     # The U-Net.
-    model = UNet3d(in_channels=1, out_channels=3, initial_features=32, final_activation="Sigmoid")
+    model = get_3d_model()
 
     # Create the training loader with train and val set.
     train_loader = get_loader(root, "train", patch_shape, batch_size, filter_empty=filter_empty)
diff --git a/test/test_validation.py b/test/test_validation.py
new file mode 100644
index 0000000..4bf67fd
--- /dev/null
+++ b/test/test_validation.py
@@ -0,0 +1,67 @@
+import unittest
+from shutil import rmtree
+
+import imageio.v3 as imageio
+import pandas as pd
+from skimage.measure import regionprops_table
+from skimage.segmentation import relabel_sequential
+
+
+class TestValidation(unittest.TestCase):
+    folder = "./tmp"
+
+    def setUp(self):
+        from flamingo_tools.test_data import get_test_volume_and_segmentation
+
+        _, self.seg_path, _ = get_test_volume_and_segmentation(self.folder)
+
+    def tearDown(self):
+        try:
+            rmtree(self.folder)
+        except Exception:
+            pass
+
+    def test_compute_scores_for_annotated_slice_2d(self):
+        from flamingo_tools.validation import compute_scores_for_annotated_slice
+
+        segmentation = imageio.imread(self.seg_path)
+        segmentation = segmentation[segmentation.shape[0] // 2]
+        segmentation, _, _ = relabel_sequential(segmentation)
+
+        properties = ("label", "centroid")
+        annotations = regionprops_table(segmentation, properties=properties)
+        annotations = pd.DataFrame(annotations).rename(columns={"centroid-0": "axis-0", "centroid-1": "axis-1"})
+        annotations = annotations.drop(columns="label")
+
+        result = compute_scores_for_annotated_slice(segmentation, annotations)
+
+        # Check the results. Note: we actually get 1 FP and 1 FN because 1 of the centroids is outside the object.
+        self.assertEqual(result["fp"], 1)
+        self.assertEqual(result["fn"], 1)
+        self.assertEqual(result["tp"], segmentation.max() - 1)
+
+    def test_compute_scores_for_annotated_slice_3d(self):
+        from flamingo_tools.validation import compute_scores_for_annotated_slice
+
+        segmentation = imageio.imread(self.seg_path)
+        z0, z1 = segmentation.shape[0] // 2 - 2, segmentation.shape[0] // 2 + 2
+        segmentation = segmentation[z0:z1]
+        segmentation, _, _ = relabel_sequential(segmentation)
+
+        properties = ("label", "centroid")
+        annotations = regionprops_table(segmentation, properties=properties)
+        annotations = pd.DataFrame(annotations).rename(
+            columns={"centroid-0": "axis-0", "centroid-1": "axis-1", "centroid-2": "axis-2"}
+        )
+        annotations = annotations.drop(columns="label")
+
+        result = compute_scores_for_annotated_slice(segmentation, annotations)
+
+        # Check the results. Note: we actually get 1 FP and 1 FN because 1 of the centroids is outside the object.
+        self.assertEqual(result["fp"], 1)
+        self.assertEqual(result["fn"], 1)
+        self.assertEqual(result["tp"], segmentation.max() - 1)
+
+
+if __name__ == "__main__":
+    unittest.main()
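The validation API exercised by these tests can also be called directly; a minimal sketch on a synthetic 2d segmentation, mirroring the test setup above:

# Minimal sketch of compute_scores_for_annotated_slice on synthetic data.
import numpy as np
import pandas as pd
from skimage.measure import regionprops_table

from flamingo_tools.validation import compute_scores_for_annotated_slice

# A toy segmentation with one object, annotated by its centroid.
segmentation = np.zeros((64, 64), dtype="uint32")
segmentation[10:20, 10:20] = 1

annotations = pd.DataFrame(regionprops_table(segmentation, properties=("label", "centroid")))
annotations = annotations.rename(columns={"centroid-0": "axis-0", "centroid-1": "axis-1"})
annotations = annotations.drop(columns="label")

result = compute_scores_for_annotated_slice(segmentation, annotations)
print(result["tp"], result["fp"], result["fn"])  # expected: 1 0 0, since the centroid lies inside the object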