From 0a8101e1dbbb1b0df1b3eec79ddf25ff97f953f9 Mon Sep 17 00:00:00 2001 From: Constantin Pape Date: Thu, 10 Jul 2025 17:55:22 +0200 Subject: [PATCH 1/5] Update vesicle inference --- synapse_net/inference/vesicles.py | 35 +++++++++++++++++-------------- 1 file changed, 19 insertions(+), 16 deletions(-) diff --git a/synapse_net/inference/vesicles.py b/synapse_net/inference/vesicles.py index b7eb8dfc..0964fd68 100644 --- a/synapse_net/inference/vesicles.py +++ b/synapse_net/inference/vesicles.py @@ -1,4 +1,5 @@ import time +import warnings from typing import Dict, List, Optional, Tuple, Union import elf.parallel as parallel @@ -8,6 +9,7 @@ from synapse_net.inference.util import apply_size_filter, get_prediction, _Scaler from synapse_net.inference.postprocessing.vesicles import filter_border_objects, filter_border_vesicles +from skimage.segmentation import relabel_sequential def distance_based_vesicle_segmentation( @@ -148,6 +150,10 @@ def segment_vesicles( return_predictions: Whether to return the predictions (foreground, boundaries) alongside the segmentation. scale: The scale factor to use for rescaling the input volume before prediction. exclude_boundary: Whether to exclude vesicles that touch the upper / lower border in z. + exclude_boundary_vesicles: Whether to exlude vesicles on the boundary that have less than the full diameter + inside of the volume. This is an alternative to post-processing with `exclude_boundary` that filters + out less vesicles at the boundary and is better suited for volumes with small context in z. + If `exclude_boundary` is also set to True, then this option will have no effect. mask: An optional mask that is used to restrict the segmentation. Returns: @@ -181,26 +187,23 @@ def segment_vesicles( foreground, boundaries, verbose=verbose, min_size=min_size, **kwargs ) - if exclude_boundary: + if exclude_boundary and exclude_boundary_vesicles: + warnings.warn( + "You have set both 'exclude_boundary' and 'exclude_boundary_vesicles' to True." + "The 'exclude_boundary_vesicles' option will have no effect." + ) seg = filter_border_objects(seg) - if exclude_boundary_vesicles: - seg_ids = filter_border_vesicles(seg) - # Step 1: Zero out everything not in seg_ids - seg[~np.isin(seg, seg_ids)] = 0 - - # Step 2: Relabel remaining IDs to be consecutive starting from 1 - unique_ids = np.unique(seg) - unique_ids = unique_ids[unique_ids != 0] # Exclude background (0) - label_map = {old_label: new_label for new_label, old_label in enumerate(unique_ids, start=1)} + elif exclude_boundary: + seg = filter_border_objects(seg) - # Apply relabeling using a temp array (to avoid large ints in-place) - new_seg = np.zeros_like(seg, dtype=np.int32) - for old_label, new_label in label_map.items(): - new_seg[seg == old_label] = new_label + elif exclude_boundary_vesicles: + # Filter the vesicles that are at the z-border with less than their full diameter. + seg_ids = filter_border_vesicles(seg) - # Final step: replace original seg with relabelled and casted version - seg = new_seg + # Remove everything not in seg ids and relable the remaining IDs consecutively. + seg[~np.isin(seg, seg_ids)] = 0 + seg = relabel_sequential(seg)[0] seg = scaler.rescale_output(seg, is_segmentation=True) From 70628f6d11704f4455caf122ac7ac7c9e3da2bf0 Mon Sep 17 00:00:00 2001 From: Constantin Pape Date: Thu, 10 Jul 2025 20:37:45 +0200 Subject: [PATCH 2/5] Implement CLI for supervised training --- setup.py | 1 + synapse_net/training/supervised_training.py | 125 +++++++++++++++++++- 2 files changed, 123 insertions(+), 3 deletions(-) diff --git a/setup.py b/setup.py index d0fc19f3..86989759 100644 --- a/setup.py +++ b/setup.py @@ -16,6 +16,7 @@ "synapse_net.run_segmentation = synapse_net.tools.cli:segmentation_cli", "synapse_net.export_to_imod_points = synapse_net.tools.cli:imod_point_cli", "synapse_net.export_to_imod_objects = synapse_net.tools.cli:imod_object_cli", + "synapse_net.run_supervised_training = synapse_net.training.supervised_training:main", ], "napari.manifest": [ "synapse_net = synapse_net:napari.yaml", diff --git a/synapse_net/training/supervised_training.py b/synapse_net/training/supervised_training.py index 37566b36..e2387310 100644 --- a/synapse_net/training/supervised_training.py +++ b/synapse_net/training/supervised_training.py @@ -1,7 +1,10 @@ +import os +from glob import glob from typing import Optional, Tuple import torch import torch_em +from sklearn.model_selection import train_test_split from torch_em.model import AnisotropicUNet, UNet2d @@ -95,6 +98,7 @@ def get_supervised_loader( sampler: Optional[callable] = None, ignore_label: Optional[int] = None, label_transform: Optional[callable] = None, + label_paths: Optional[Tuple[str]] = None, **loader_kwargs, ) -> torch.utils.data.DataLoader: """Get a dataloader for supervised segmentation training. @@ -118,6 +122,8 @@ def get_supervised_loader( ignored in the loss computation. By default this option is not used. label_transform: Label transform that is applied to the segmentation to compute the targets. If no label transform is passed (the default) a boundary transform is used. + label_paths: Optional paths containing the labels / annotations for training. + If not given, the labels are expected to be contained in the `data_paths`. loader_kwargs: Additional keyword arguments for the dataloader. Returns: @@ -155,9 +161,14 @@ def get_supervised_loader( if sampler is None: sampler = torch_em.data.sampler.MinInstanceSampler(min_num_instances=4) + if label_paths is None: + label_paths = data_paths + elif len(label_paths) != len(data_paths): + raise ValueError(f"Data paths and label paths don't match: {len(data_paths)} != {len(label_paths)}") + loader = torch_em.default_segmentation_loader( data_paths, raw_key, - data_paths, label_key, sampler=sampler, + label_paths, label_key, sampler=sampler, batch_size=batch_size, patch_shape=patch_shape, ndim=ndim, is_seg_dataset=True, label_transform=label_transform, transform=transform, num_workers=num_workers, shuffle=shuffle, n_samples=n_samples, @@ -177,6 +188,8 @@ def supervised_training( batch_size: int = 1, lr: float = 1e-4, n_iterations: int = int(1e5), + train_label_paths: Optional[Tuple[str]] = None, + val_label_paths: Optional[Tuple[str]] = None, train_rois: Optional[Tuple[Tuple[slice]]] = None, val_rois: Optional[Tuple[Tuple[slice]]] = None, sampler: Optional[callable] = None, @@ -210,6 +223,10 @@ def supervised_training( batch_size: The batch size for training. lr: The initial learning rate. n_iterations: The number of iterations to train for. + train_label_paths: Optional paths containing the label data for training. + If not given, the labels are expected to be part of `train_paths`. + val_label_paths: Optional paths containing the label data for validation. + If not given, the labels are expected to be part of `val_paths`. train_rois: Optional region of interests for training. val_rois: Optional region of interests for validation. sampler: Optional sampler for selecting blocks for training. @@ -231,11 +248,11 @@ def supervised_training( train_loader = get_supervised_loader(train_paths, raw_key, label_key, patch_shape, batch_size, n_samples=n_samples_train, rois=train_rois, sampler=sampler, ignore_label=ignore_label, label_transform=label_transform, - **loader_kwargs) + label_paths=train_label_paths, **loader_kwargs) val_loader = get_supervised_loader(val_paths, raw_key, label_key, patch_shape, batch_size, n_samples=n_samples_val, rois=val_rois, sampler=sampler, ignore_label=ignore_label, label_transform=label_transform, - **loader_kwargs) + label_paths=val_label_paths, **loader_kwargs) if check: from torch_em.util.debug import check_loader @@ -287,3 +304,105 @@ def supervised_training( metric=metric, ) trainer.fit(n_iterations) + + +def _parse_input_folder(folder, pattern, key): + files = sorted(glob(os.path.join(folder, "**", pattern))) + # Get all file extensions (general wild-cards may pick up files with multiple extensions). + extensions = [os.path.splitext(ff)[1] for ff in files] + + # If we have more than 1 file extension we just use the key that was passed, + # as it is unclear how to derive a consistent key. + if len(extensions) > 1: + return files, key + + ext = extensions[0] + extension_to_key = {".tif": None, ".mrc": "data", ".rec": "data"} + + # Derive the key from the extension if the key is None. + if key is None and ext in extension_to_key: + key = extension_to_key[ext] + # If the key is None and can't be derived raise an error. + elif key is None and ext not in extension_to_key: + raise ValueError( + f"You have not passed a key for the data in {folder}, but the key could not be derived for{ext} format." + ) + # If the key was passed and doesn't match the extension raise an error. + elif key is not None and ext in extension_to_key and key != extension_to_key[ext]: + raise ValueError( + f"The expected key {extension_to_key[ext]} for format {ext} did not match the passed key {key}." + ) + return files, key + + +def _parse_input_files(args): + train_image_paths, raw_key = _parse_input_folder(args.train_folder, args.image_file_pattern, args.raw_key) + train_label_paths, label_key = _parse_input_folder(args.label_folder, args.label_file_pattern, args.label_key) + if len(train_image_paths) != len(train_label_paths): + raise ValueError( + f"The image and label paths parsed from {args.train_folder} and {args.label_folder} don't match." + f"The image folder contains {len(train_image_paths)}, the label folder contains {len(train_label_paths)}." + ) + + if args.val_folder is None: + if args.val_label_folder is not None: + raise ValueError("You have passed a val_label_folder, but not a val_folder.") + train_image_paths, val_image_paths, train_label_paths, val_label_paths = train_test_split( + train_image_paths, train_label_paths, test_size=args.val_fraction, random_state=42 + ) + else: + if args.val_label_folder is None: + raise ValueError("You have passed a val_folder, but not a val_label_folder.") + val_image_paths = _parse_input_folder(args.val_image_folder, args.image_file_pattern, raw_key) + val_label_paths = _parse_input_folder(args.val_label_folder, args.label_file_pattern, label_key) + + return train_image_paths, train_label_paths, val_image_paths, val_label_paths, raw_key, label_key + + +# TODO enable initialization with a pre-trained model. +def main(): + """@private + """ + import argparse + + parser = argparse.ArgumentParser( + description="Train a model for foreground and boundary segmentation via supervised learning." + ) + parser.add_argument("-n", "--name", required=True, help="The name of the model to be trained.") + parser.add_argument("-p", "--patch_shape", nargs=3, type=int, help="The patch shape for training.") + + # Folders with training data, containing raw/image data and labels. + parser.add_argument("--i", "--train_folder", required=True, help="The input folder with the training image data.") + parser.add_argument("--image_file_pattern", default="*", + help="The pattern for selecting image files. For example, '*.mrc' to select all mrc files.") + parser.add_argument("--raw_key", + help="The internal path for the raw data. If not given, will be determined based on the file extension.") # noqa + parser.add_argument("-l", "--label_folder", required=True, help="The input folder with the training labels.") + parser.add_argument("--label_file_pattern", default="*", + help="The pattern for selecting label files. For example, '*.tif' to select all tif files.") + parser.add_argument("--label_key", + help="The internal path for the label data. If not given, will be determined based on the file extension.") # noqa + + # Optional folders with validation data. If not given the training data is split into train/val. + parser.add_argument("--val_folder", + help="The input folder with the validation data. If not given the training data will be split for validation") # noqa + parser.add_argument("--val_label_folder", + help="The input folder with the validation labels. If not given the training data will be split for validation.") # noqa + + # More optional argument: + parser.add_argument("--batch_size", type=int, default=1, help="The batch size for training.") + parser.add_argument("--n_samples_train", type=int, help="The number of samples per epoch for training. If not given will be derived from the data size.") # noqa + parser.add_argument("--n_samples_val", type=int, help="The number of samples per epoch for validation. If not given will be derived from the data size.") # noqa + parser.add_argument("--val_fraction", type=float, default=0.15, help="The fraction of the data to use for validation. This has no effect if 'val_folder' and 'val_label_folder' were passed.") # noqa + args = parser.parse_args() + + train_image_paths, train_label_paths, val_image_paths, val_label_paths, raw_key, label_key =\ + _parse_input_files(args) + + supervised_training( + name=args.name, train_paths=train_image_paths, val_paths=val_image_paths, + train_label_paths=train_label_paths, val_label_paths=val_label_paths, + raw_key=raw_key, label_key=label_key, patch_shape=args.patch_shape, batch_size=args.batch_size, + n_samples_train=args.n_samples_train, n_samples_val=args.n_samples_val, + check=args.check, + ) From be0917a204e856629221f6125148c5b2daa92df2 Mon Sep 17 00:00:00 2001 From: Constantin Pape Date: Thu, 10 Jul 2025 21:42:06 +0200 Subject: [PATCH 3/5] Fix issues in training CLI and add domain adaptation CLI --- scripts/cooper/revision/az_prediction.py | 9 +- scripts/cooper/revision/common.py | 2 +- setup.py | 1 + synapse_net/training/domain_adaptation.py | 116 +++++++++++++++++++- synapse_net/training/supervised_training.py | 7 +- 5 files changed, 125 insertions(+), 10 deletions(-) diff --git a/scripts/cooper/revision/az_prediction.py b/scripts/cooper/revision/az_prediction.py index 4fca9717..6db432f6 100644 --- a/scripts/cooper/revision/az_prediction.py +++ b/scripts/cooper/revision/az_prediction.py @@ -24,7 +24,7 @@ def run_prediction(model, name, split_folder, version, split_names, in_path): for fname in tqdm(file_names): if in_path: - input_path=os.path.join(in_path, name, fname) + input_path = os.path.join(in_path, name, fname) else: input_path = os.path.join(INPUT_ROOT, name, fname) print(f"segmenting {input_path}") @@ -50,7 +50,6 @@ def run_prediction(model, name, split_folder, version, split_names, in_path): print(f"{output_key_seg} already saved") else: f.create_dataset(output_key_seg, data=seg, compression="lzf") - def get_model(version): @@ -58,7 +57,7 @@ def get_model(version): split_folder = get_split_folder(version) if version == 3: model_path = os.path.join(split_folder, "checkpoints", "3D-AZ-model-TEM_STEM_ChemFix_wichmann-v3") - elif version ==6: + elif version == 6: model_path = "/mnt/ceph-hdd/cold/nim00007/models/AZ/v6/" elif version == 7: model_path = "/mnt/lustre-emmy-hdd/usr/u12095/synapse_net/models/ConstantinAZ/checkpoints/v7/" @@ -79,7 +78,7 @@ def main(): args = parser.parse_args() if args.model_path: - model = load_model(model_path) + model = load_model(args.model_path) else: model = get_model(args.version) @@ -87,7 +86,7 @@ def main(): for name in args.names: run_prediction(model, name, split_folder, args.version, args.splits, args.input) - + print("Finished segmenting!") diff --git a/scripts/cooper/revision/common.py b/scripts/cooper/revision/common.py index 603a73de..cfcac211 100644 --- a/scripts/cooper/revision/common.py +++ b/scripts/cooper/revision/common.py @@ -65,7 +65,7 @@ def get_split_folder(version): if version == 3: split_folder = "splits" elif version == 6: - split_folder= "/mnt/ceph-hdd/cold/nim00007/new_AZ_train_data/splits" + split_folder = "/mnt/ceph-hdd/cold/nim00007/new_AZ_train_data/splits" else: split_folder = "models_az_thin" return split_folder diff --git a/setup.py b/setup.py index 86989759..8a53e38b 100644 --- a/setup.py +++ b/setup.py @@ -17,6 +17,7 @@ "synapse_net.export_to_imod_points = synapse_net.tools.cli:imod_point_cli", "synapse_net.export_to_imod_objects = synapse_net.tools.cli:imod_object_cli", "synapse_net.run_supervised_training = synapse_net.training.supervised_training:main", + "synapse_net.run_domain_adaptation = synapse_net.training.domain_adaptation:main", ], "napari.manifest": [ "synapse_net = synapse_net:napari.yaml", diff --git a/synapse_net/training/domain_adaptation.py b/synapse_net/training/domain_adaptation.py index 215d7faa..50f426fd 100644 --- a/synapse_net/training/domain_adaptation.py +++ b/synapse_net/training/domain_adaptation.py @@ -1,12 +1,20 @@ import os +import tempfile +from glob import glob +from pathlib import Path from typing import Optional, Tuple +import mrcfile import torch import torch_em import torch_em.self_training as self_training +from elf.io import open_file +from sklearn.model_selection import train_test_split from .semisupervised_training import get_unsupervised_loader from .supervised_training import get_2d_model, get_3d_model, get_supervised_loader, _determine_ndim +from ..inference.inference import get_model_path, compute_scale_from_voxel_size +from ..inference.util import _Scaler def mean_teacher_adaptation( @@ -91,7 +99,7 @@ def mean_teacher_adaptation( if os.path.isdir(source_checkpoint): model = torch_em.util.load_model(source_checkpoint) else: - model = torch.load(source_checkpoint) + model = torch.load(source_checkpoint, weights_only=False) reinit_teacher = False optimizer = torch.optim.Adam(model.parameters(), lr=1e-4) @@ -148,3 +156,109 @@ def mean_teacher_adaptation( sampler=sampler, ) trainer.fit(n_iterations) + + +# TODO patch shapes for other models +PATCH_SHAPES = { + "vesicles_3d": [48, 256, 256], +} +"""@private +""" + + +def _get_paths(input_folder, pattern, resize_training_data, model_name, tmp_dir): + files = sorted(glob(os.path.join(input_folder, "**", pattern), recursive=True)) + if len(files) == 0: + raise ValueError(f"Could not load any files from {input_folder} with pattern {pattern}") + + val_fraction = 0.15 + + # Heuristic: if we have less then 4 files then we crop a part of the volumes for validation. + # And resave the volumes. + resave_val_crops = len(files) < 4 + + # We only resave the data if we resave val crops or resize the training data + resave_data = resave_val_crops or resize_training_data + if not resave_data: + train_paths, val_paths = train_test_split(files, test_size=val_fraction) + return train_paths, val_paths + + train_paths, val_paths = [], [] + for file_path in files: + file_name = os.path.basename(file_path) + data = open_file(file_path, mode="r")["data"][:] + + if resize_training_data: + with mrcfile.open(file_path) as f: + voxel_size = f.voxel_size + voxel_size = {ax: vox_size / 10.0 for ax, vox_size in zip("xyz", voxel_size.item())} + scale = compute_scale_from_voxel_size(voxel_size, model_name) + scaler = _Scaler(scale, verbose=False) + data = scaler.sale_input(data) + + if resave_val_crops: + n_slices = data.shape[0] + val_slice = int((1.0 - val_fraction) * n_slices) + train_data, val_data = data[:val_slice], data[val_slice:] + + train_path = os.path.join(tmp_dir, Path(file_name).with_suffix(".h5")).replace(".h5", "_train.h5") + with open_file(train_path, mode="w") as f: + f.create_dataset("data", data=train_data, compression="lzf") + train_paths.append(train_path) + + val_path = os.path.join(tmp_dir, Path(file_name).with_suffix(".h5")).replace(".h5", "_val.h5") + with open_file(val_path, mode="w") as f: + f.create_dataset("data", data=val_data, compression="lzf") + val_paths.append(val_path) + + else: + output_path = os.path.join(tmp_dir, Path(file_name).with_suffix(".h5")) + with open_file(output_path, mode="w") as f: + f.create_dataset("data", data=data, compression="lzf") + train_paths.append(output_path) + + if not resave_val_crops: + train_paths, val_paths = train_test_split(train_paths, test_size=val_fraction) + + return train_paths, val_paths + + +def _parse_patch_shape(patch_shape, model_name): + if patch_shape is None: + patch_shape = PATCH_SHAPES[model_name] + return patch_shape + + +def main(): + """@private + """ + import argparse + + parser = argparse.ArgumentParser( + description="" + ) + parser.add_argument("--name", "-n", required=True) + parser.add_argument("--input", "-i", required=True) + parser.add_argument("--pattern", "-p", default="*.mrc") + parser.add_argument("--source_model", default="vesicles_3d") + parser.add_argument("--resize_training_data", action="store_true") + parser.add_argument("--n_iterations", type=int, default=int(1e4)) + parser.add_argument("--patch_shape", nargs="+", type=int) + args = parser.parse_args() + + source_checkpoint = get_model_path(args.source_model) + patch_shape = _parse_patch_shape(args.patch_shape, args.source_model) + with tempfile.TemporaryDirectory() as tmp_dir: + unsupervised_train_paths, unsupervised_val_paths = _get_paths( + args.input, args.pattern, args.resize_training_data, args.source_model, tmp_dir + ) + + mean_teacher_adaptation( + name=args.name, + unsupervised_train_paths=unsupervised_train_paths, + unsupervised_val_paths=unsupervised_val_paths, + patch_shape=patch_shape, + source_checkpoint=source_checkpoint, + raw_key="data", + n_iterations=args.n_iterations, + ) diff --git a/synapse_net/training/supervised_training.py b/synapse_net/training/supervised_training.py index e2387310..ac20eccf 100644 --- a/synapse_net/training/supervised_training.py +++ b/synapse_net/training/supervised_training.py @@ -307,9 +307,9 @@ def supervised_training( def _parse_input_folder(folder, pattern, key): - files = sorted(glob(os.path.join(folder, "**", pattern))) + files = sorted(glob(os.path.join(folder, "**", pattern), recursive=True)) # Get all file extensions (general wild-cards may pick up files with multiple extensions). - extensions = [os.path.splitext(ff)[1] for ff in files] + extensions = list(set([os.path.splitext(ff)[1] for ff in files])) # If we have more than 1 file extension we just use the key that was passed, # as it is unclear how to derive a consistent key. @@ -372,7 +372,7 @@ def main(): parser.add_argument("-p", "--patch_shape", nargs=3, type=int, help="The patch shape for training.") # Folders with training data, containing raw/image data and labels. - parser.add_argument("--i", "--train_folder", required=True, help="The input folder with the training image data.") + parser.add_argument("-i", "--train_folder", required=True, help="The input folder with the training image data.") parser.add_argument("--image_file_pattern", default="*", help="The pattern for selecting image files. For example, '*.mrc' to select all mrc files.") parser.add_argument("--raw_key", @@ -394,6 +394,7 @@ def main(): parser.add_argument("--n_samples_train", type=int, help="The number of samples per epoch for training. If not given will be derived from the data size.") # noqa parser.add_argument("--n_samples_val", type=int, help="The number of samples per epoch for validation. If not given will be derived from the data size.") # noqa parser.add_argument("--val_fraction", type=float, default=0.15, help="The fraction of the data to use for validation. This has no effect if 'val_folder' and 'val_label_folder' were passed.") # noqa + parser.add_argument("--check", action="store_true", help="Visualize samples from the data loaders to ensure correct data instead of running training.") # noqa args = parser.parse_args() train_image_paths, train_label_paths, val_image_paths, val_label_paths, raw_key, label_key =\ From 480d71465d9c96db0d1a8f85f31c68294bb12784 Mon Sep 17 00:00:00 2001 From: Constantin Pape Date: Fri, 11 Jul 2025 09:45:54 +0200 Subject: [PATCH 4/5] Update the SynapseNet trainign CLI --- doc/start_page.md | 29 +++++++--- synapse_net/training/domain_adaptation.py | 59 ++++++++++++++++----- synapse_net/training/supervised_training.py | 17 ++++-- 3 files changed, 81 insertions(+), 24 deletions(-) diff --git a/doc/start_page.md b/doc/start_page.md index 7b2a8aa7..cd752139 100644 --- a/doc/start_page.md +++ b/doc/start_page.md @@ -147,10 +147,12 @@ For more options supported by the IMOD exports, please run `synapse_net.export_t > Note: to use these commands you have to install IMOD. +SynapseNet also provides two CLI comamnds for training models, one for supervised network training (see [Supervised Training](#supervised-training) for details) and one for domain adaptation (see [Domain Adaptation](#domain-adaptation) for details). + ## Python Library -Using the `synapse_net` python library offers the most flexibility for using the SynapseNet functionality. +Using the `synapse_net` python library offers the most flexibility for using SynapseNet's functionality. You can find an example analysis pipeline implemented with SynapseNet [here](https://github.com/computational-cell-analytics/synapse-net/blob/main/examples/analysis_pipeline.py). We offer different functionality for segmenting and analyzing synapses in electron microscopy: @@ -161,17 +163,32 @@ We offer different functionality for segmenting and analyzing synapses in electr Please refer to the module documentation below for a full overview of our library's functionality. +### Supervised Training + +SynapseNet provides functionality for training a UNet for segmentation tasks using supervised learning. +In this case, you have to provide data **and** (manual) annotations for the structure(s) you want to segment. +This functionality is implemented in `synapse_net.training.supervised_training`. You can find an example script that shows how to use it [here](https://github.com/computational-cell-analytics/synapse-net/blob/main/examples/network_training.py). + +We also provide a command line function to run supervised training: `synapse_net.run_supervised_training`. Run +```bash +synapse_net.run_supervised_training -h +``` +for more information and instructions on how to use it. + ### Domain Adaptation -We provide functionality for domain adaptation. It implements a special form of neural network training that can improve segmentation for data from a different condition (e.g. different sample preparation, electron microscopy technique or different specimen), **without requiring additional annotated structures**. +SynapseNet provides functionality for (unsupervised) domain adaptation. +This functionality is implemented through a student-teacher training approach that can improve segmentation for data from a different condition (for example different sample preparation, imaging technique, or different specimen), **without requiring additional annotated structures**. Domain adaptation is implemented in `synapse_net.training.domain_adaptation`. You can find an example script that shows how to use it [here](https://github.com/computational-cell-analytics/synapse-net/blob/main/examples/domain_adaptation.py). -> Note: Domain adaptation only works if the initial model you adapt already finds some of the structures in the data from a new condition. If it does not work you will have to train a network on annotated data. +We also provide a command line function to run domain adaptation: `synapse_net.run_domain_adaptation`. Run +```bash +synapse_net.run_domain_adaptation -h +``` +for more information and instructions on how to use it. -### Network Training +> Note: Domain adaptation only works if the initial model already finds some of the structures in the data from a new condition. If it does not work you will have to train a network on annotated data. -We also provide functionality for 'regular' neural network training. In this case, you have to provide data **and** manual annotations for the structure(s) you want to segment. -This functionality is implemented in `synapse_net.training.supervised_training`. You can find an example script that shows how to use it [here](https://github.com/computational-cell-analytics/synapse-net/blob/main/examples/network_training.py). ## Segmentation for the CryoET Data Portal diff --git a/synapse_net/training/domain_adaptation.py b/synapse_net/training/domain_adaptation.py index 50f426fd..ad3597c0 100644 --- a/synapse_net/training/domain_adaptation.py +++ b/synapse_net/training/domain_adaptation.py @@ -12,7 +12,9 @@ from sklearn.model_selection import train_test_split from .semisupervised_training import get_unsupervised_loader -from .supervised_training import get_2d_model, get_3d_model, get_supervised_loader, _determine_ndim +from .supervised_training import ( + get_2d_model, get_3d_model, get_supervised_loader, _determine_ndim, _derive_key_from_files +) from ..inference.inference import get_model_path, compute_scale_from_voxel_size from ..inference.util import _Scaler @@ -166,13 +168,11 @@ def mean_teacher_adaptation( """ -def _get_paths(input_folder, pattern, resize_training_data, model_name, tmp_dir): +def _get_paths(input_folder, pattern, resize_training_data, model_name, tmp_dir, val_fraction): files = sorted(glob(os.path.join(input_folder, "**", pattern), recursive=True)) if len(files) == 0: raise ValueError(f"Could not load any files from {input_folder} with pattern {pattern}") - val_fraction = 0.15 - # Heuristic: if we have less then 4 files then we crop a part of the volumes for validation. # And resave the volumes. resave_val_crops = len(files) < 4 @@ -235,23 +235,50 @@ def main(): import argparse parser = argparse.ArgumentParser( - description="" + description="Adapt a model to data from a different domain using unsupervised domain adaptation.\n\n" + "You can use this function to adapt the SynapseNet model for vesicle segmentation like this:\n" + "synapse_net.run_domain_adaptation -n adapted_model -i /path/to/data --file_pattern *.mrc --source_model vesicles_3d\n" # noqa + "The trained model will be saved in the folder 'checkpoints/adapted_model' (or whichever name you pass to the '-n' argument)." # noqa + "You can then use this model for segmentation with the SynapseNet GUI or CLI. " + "Check out the information below for details on the arguments of this function." + ) + parser.add_argument("--name", "-n", required=True, help="The name of the model to be trained. ") + parser.add_argument("--input_folder", "-i", required=True, help="The folder with the training data.") + parser.add_argument("--file_pattern", default="*", + help="The pattern for selecting files for training. For example '*.mrc' to select mrc files.") + parser.add_argument("--key", help="The internal file path for the training data. Will be derived from the file extension by default.") # noqa + parser.add_argument( + "--source_model", + default="vesicles_3d", + help="The source model used for weight initialization of teacher and student model. " + "By default the model 'vesicles_3d' for vesicle segmentation in volumetric data is used." + ) + parser.add_argument( + "--resize_training_data", action="store_true", + help="Whether to resize the training data to fit the voxel size of the source model's trainign data." ) - parser.add_argument("--name", "-n", required=True) - parser.add_argument("--input", "-i", required=True) - parser.add_argument("--pattern", "-p", default="*.mrc") - parser.add_argument("--source_model", default="vesicles_3d") - parser.add_argument("--resize_training_data", action="store_true") - parser.add_argument("--n_iterations", type=int, default=int(1e4)) - parser.add_argument("--patch_shape", nargs="+", type=int) + parser.add_argument("--n_iterations", type=int, default=int(1e4), help="The number of iterations for training.") + parser.add_argument( + "--patch_shape", nargs=3, type=int, + help="The patch shape for training. By default the patch shape the source model was trained with is used." + ) + + # More optional argument: + parser.add_argument("--batch_size", type=int, default=1, help="The batch size for training.") + parser.add_argument("--n_samples_train", type=int, help="The number of samples per epoch for training. If not given will be derived from the data size.") # noqa + parser.add_argument("--n_samples_val", type=int, help="The number of samples per epoch for validation. If not given will be derived from the data size.") # noqa + parser.add_argument("--val_fraction", type=float, default=0.15, help="The fraction of the data to use for validation. This has no effect if 'val_folder' and 'val_label_folder' were passed.") # noqa + parser.add_argument("--check", action="store_true", help="Visualize samples from the data loaders to ensure correct data instead of running training.") # noqa + args = parser.parse_args() source_checkpoint = get_model_path(args.source_model) patch_shape = _parse_patch_shape(args.patch_shape, args.source_model) with tempfile.TemporaryDirectory() as tmp_dir: unsupervised_train_paths, unsupervised_val_paths = _get_paths( - args.input, args.pattern, args.resize_training_data, args.source_model, tmp_dir + args.input, args.pattern, args.resize_training_data, args.source_model, tmp_dir, args.val_fraction, ) + unsupervised_train_paths, raw_key = _derive_key_from_files(unsupervised_train_paths, args.key) mean_teacher_adaptation( name=args.name, @@ -259,6 +286,10 @@ def main(): unsupervised_val_paths=unsupervised_val_paths, patch_shape=patch_shape, source_checkpoint=source_checkpoint, - raw_key="data", + raw_key=raw_key, n_iterations=args.n_iterations, + batch_size=args.batch_size, + n_samples_train=args.n_samples_train, + n_samples_val=args.n_samples_val, + check=args.check, ) diff --git a/synapse_net/training/supervised_training.py b/synapse_net/training/supervised_training.py index ac20eccf..328df446 100644 --- a/synapse_net/training/supervised_training.py +++ b/synapse_net/training/supervised_training.py @@ -306,8 +306,7 @@ def supervised_training( trainer.fit(n_iterations) -def _parse_input_folder(folder, pattern, key): - files = sorted(glob(os.path.join(folder, "**", pattern), recursive=True)) +def _derive_key_from_files(files, key): # Get all file extensions (general wild-cards may pick up files with multiple extensions). extensions = list(set([os.path.splitext(ff)[1] for ff in files])) @@ -325,7 +324,7 @@ def _parse_input_folder(folder, pattern, key): # If the key is None and can't be derived raise an error. elif key is None and ext not in extension_to_key: raise ValueError( - f"You have not passed a key for the data in {folder}, but the key could not be derived for{ext} format." + f"You have not passed a key for the data in {ext} format, for which the key cannot be derived." ) # If the key was passed and doesn't match the extension raise an error. elif key is not None and ext in extension_to_key and key != extension_to_key[ext]: @@ -335,6 +334,11 @@ def _parse_input_folder(folder, pattern, key): return files, key +def _parse_input_folder(folder, pattern, key): + files = sorted(glob(os.path.join(folder, "**", pattern), recursive=True)) + return _derive_key_from_files(files, key) + + def _parse_input_files(args): train_image_paths, raw_key = _parse_input_folder(args.train_folder, args.image_file_pattern, args.raw_key) train_label_paths, label_key = _parse_input_folder(args.label_folder, args.label_file_pattern, args.label_key) @@ -366,7 +370,12 @@ def main(): import argparse parser = argparse.ArgumentParser( - description="Train a model for foreground and boundary segmentation via supervised learning." + description="Train a model for foreground and boundary segmentation via supervised learning.\n\n" + "You can use this function to train a model for vesicle segmentation, or another segmentation task, like this:\n" # noqa + "synapse_net.run_supervised_training -n my_model -i /path/to/images -l /path/to/labels --patch_shape 32 192 192\n" # noqa + "The trained model will be saved in the folder 'checkpoints/my_model' (or whichever name you pass to the '-n' argument)." # noqa + "You can then use this model for segmentation with the SynapseNet GUI or CLI. " + "Check out the information below for details on the arguments of this function." ) parser.add_argument("-n", "--name", required=True, help="The name of the model to be trained.") parser.add_argument("-p", "--patch_shape", nargs=3, type=int, help="The patch shape for training.") From 1f3ad5f6da3477132983c5c9247525512bff81be Mon Sep 17 00:00:00 2001 From: Constantin Pape Date: Fri, 11 Jul 2025 10:28:03 +0200 Subject: [PATCH 5/5] Update CLI training info and add community submission info to doc --- doc/start_page.md | 26 +++++++++++++++++++++ synapse_net/training/domain_adaptation.py | 3 ++- synapse_net/training/supervised_training.py | 3 ++- 3 files changed, 30 insertions(+), 2 deletions(-) diff --git a/doc/start_page.md b/doc/start_page.md index cd752139..0986056d 100644 --- a/doc/start_page.md +++ b/doc/start_page.md @@ -196,3 +196,29 @@ We have published segmentation results for tomograms of synapses stored in the [ - [CZCDP-10330](https://cryoetdataportal.czscience.com/depositions/10330): Contains synaptic vesicle segmentations for over 50 tomograms of synaptosomes. The segmentations were made with a model domain adapted to the synaptosome tomograms. The scripts for the submissions can be found in [scripts/cryo/cryo-et-portal](https://github.com/computational-cell-analytics/synapse-net/tree/main/scripts/cryo/cryo-et-portal). + + +## Community Data Submission + +We are looking to extend and improve the SynapseNet models by training on more annotated data from electron tomography or (volume) electron microscopy. +For this, we plan to collect data from community submissions. + +If you are using SynapseNet for a task where it does not perform well, or if you would like to use it for a new segmentation task not offered by it, and have annotations for your data, then you can submit this data to us, so that we can use it to train our next version of improved models. +To do this, please create an [issue on github](https://github.com/computational-cell-analytics/synapse-net/issues) and: +- Use a title "Data submission: ..." ("..." should be a title for your data, e.g. "smooth ER in electron tomography") +- Briefly describe your data and add an image that shows the microscopy data and the segmentation masks you have. +- Make sure to describe: + - The imaging modality and the structure(s) that you have segmented. + - How many images and annotations you have / can submit and how you have created the annotations. + - You should submit at least 5 images or crops and 20 annotated objects. If you are unsure if you have enough data please go ahead and create the issue / post and we can discuss the details. + - Which data-format your images and annotations are stored in. We recommend using either `tif`, `mrc`, or `ome.zarr` files. +- Please indicate that you are willing to share the data for training purpose (see also next paragraph). + +Once you have created the post / issue, we will check if your data is suitable for submission or discuss with you how it could be extended to be suitable. Then: +- We will share an agreement for data sharing. You can find **a draft** [here](https://docs.google.com/document/d/1vf5Efp5EJcS1ivuWM4f3pO5kBqEZfJcXucXL5ot0eqg/edit?usp=sharing). +- You will be able to choose how you want to submit / publish your data. + - Share it under a CC0 license. In this case, we will use the data for re-training and also make it publicly available as soon as the next model versions become available. + - Share it for training with the option to publish it later. For example, if your data is unpublished and you want to only published once the respective publication is available. In this case, we will use the data for re-training, but not make it freely available yet. We will check with you peridiodically to see if your data can now be published. + - Share it for training only. In this case, we will re-train the model on it, but not make it publicly available. +- We encourage you to choose the first option (making the data available under CC0). +- We will then send you a link to upload your data, after you have agreed to these terms. diff --git a/synapse_net/training/domain_adaptation.py b/synapse_net/training/domain_adaptation.py index ad3597c0..031fd4af 100644 --- a/synapse_net/training/domain_adaptation.py +++ b/synapse_net/training/domain_adaptation.py @@ -240,7 +240,8 @@ def main(): "synapse_net.run_domain_adaptation -n adapted_model -i /path/to/data --file_pattern *.mrc --source_model vesicles_3d\n" # noqa "The trained model will be saved in the folder 'checkpoints/adapted_model' (or whichever name you pass to the '-n' argument)." # noqa "You can then use this model for segmentation with the SynapseNet GUI or CLI. " - "Check out the information below for details on the arguments of this function." + "Check out the information below for details on the arguments of this function.", + formatter_class=argparse.RawTextHelpFormatter ) parser.add_argument("--name", "-n", required=True, help="The name of the model to be trained. ") parser.add_argument("--input_folder", "-i", required=True, help="The folder with the training data.") diff --git a/synapse_net/training/supervised_training.py b/synapse_net/training/supervised_training.py index 328df446..3a2cebc0 100644 --- a/synapse_net/training/supervised_training.py +++ b/synapse_net/training/supervised_training.py @@ -375,7 +375,8 @@ def main(): "synapse_net.run_supervised_training -n my_model -i /path/to/images -l /path/to/labels --patch_shape 32 192 192\n" # noqa "The trained model will be saved in the folder 'checkpoints/my_model' (or whichever name you pass to the '-n' argument)." # noqa "You can then use this model for segmentation with the SynapseNet GUI or CLI. " - "Check out the information below for details on the arguments of this function." + "Check out the information below for details on the arguments of this function.", + formatter_class=argparse.RawTextHelpFormatter ) parser.add_argument("-n", "--name", required=True, help="The name of the model to be trained.") parser.add_argument("-p", "--patch_shape", nargs=3, type=int, help="The patch shape for training.")