diff --git a/synapse_net/training/domain_adaptation.py b/synapse_net/training/domain_adaptation.py index 031fd4af..c57c8bf2 100644 --- a/synapse_net/training/domain_adaptation.py +++ b/synapse_net/training/domain_adaptation.py @@ -18,7 +18,6 @@ from ..inference.inference import get_model_path, compute_scale_from_voxel_size from ..inference.util import _Scaler - def mean_teacher_adaptation( name: str, unsupervised_train_paths: Tuple[str], @@ -37,9 +36,13 @@ def mean_teacher_adaptation( n_iterations: int = int(1e4), n_samples_train: Optional[int] = None, n_samples_val: Optional[int] = None, - sampler: Optional[callable] = None, + train_mask_paths: Optional[Tuple[str]] = None, + val_mask_paths: Optional[Tuple[str]] = None, + patch_sampler: Optional[callable] = None, + pseudo_label_sampler: Optional[callable] = None, + device: int = 0, ) -> None: - """Run domain adapation to transfer a network trained on a source domain for a supervised + """Run domain adaptation to transfer a network trained on a source domain for a supervised segmentation task to perform this task on a different target domain. We support different domain adaptation settings: @@ -82,6 +85,11 @@ def mean_teacher_adaptation( based on the patch_shape and size of the volumes used for training. n_samples_val: The number of val samples per epoch. By default this will be estimated based on the patch_shape and size of the volumes used for validation. + train_mask_paths: Sample masks used by the patch sampler to accept or reject patches for training. + val_mask_paths: Sample masks used by the patch sampler to accept or reject patches for validation. + patch_sampler: Accept or reject patches based on a condition. + pseudo_label_sampler: Mask out regions of the pseudo labels where the teacher is not confident before updating the gradients. + device: GPU ID for training. """ assert (supervised_train_paths is None) == (supervised_val_paths is None) is_2d, _ = _determine_ndim(patch_shape) @@ -97,7 +105,7 @@ def mean_teacher_adaptation( model = get_3d_model(out_channels=2) reinit_teacher = True else: - print("Mean teacehr training initialized from source model:", source_checkpoint) + print("Mean teacher training initialized from source model:", source_checkpoint) if os.path.isdir(source_checkpoint): model = torch_em.util.load_model(source_checkpoint) else: @@ -111,12 +119,24 @@ def mean_teacher_adaptation( pseudo_labeler = self_training.DefaultPseudoLabeler(confidence_threshold=confidence_threshold) loss = self_training.DefaultSelfTrainingLoss() loss_and_metric = self_training.DefaultSelfTrainingLossAndMetric() - + unsupervised_train_loader = get_unsupervised_loader( - unsupervised_train_paths, raw_key, patch_shape, batch_size, n_samples=n_samples_train + data_paths=unsupervised_train_paths, + raw_key=raw_key, + patch_shape=patch_shape, + batch_size=batch_size, + n_samples=n_samples_train, + sample_mask_paths=train_mask_paths, + sampler=patch_sampler ) unsupervised_val_loader = get_unsupervised_loader( - unsupervised_val_paths, raw_key, patch_shape, batch_size, n_samples=n_samples_val + data_paths=unsupervised_val_paths, + raw_key=raw_key, + patch_shape=patch_shape, + batch_size=batch_size, + n_samples=n_samples_val, + sample_mask_paths=val_mask_paths, + sampler=patch_sampler ) if supervised_train_paths is not None: @@ -133,7 +153,7 @@ def mean_teacher_adaptation( supervised_train_loader = None supervised_val_loader = None - device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") + device = torch.device(f"cuda:{device}") if torch.cuda.is_available() else torch.device("cpu") trainer = self_training.MeanTeacherTrainer( name=name, model=model, @@ -155,11 +175,11 @@ def mean_teacher_adaptation( device=device, reinit_teacher=reinit_teacher, save_root=save_root, - sampler=sampler, + sampler=pseudo_label_sampler, ) trainer.fit(n_iterations) - - + + # TODO patch shapes for other models PATCH_SHAPES = { "vesicles_3d": [48, 256, 256], @@ -228,7 +248,6 @@ def _parse_patch_shape(patch_shape, model_name): patch_shape = PATCH_SHAPES[model_name] return patch_shape - def main(): """@private """ @@ -293,4 +312,4 @@ def main(): n_samples_train=args.n_samples_train, n_samples_val=args.n_samples_val, check=args.check, - ) + ) \ No newline at end of file diff --git a/synapse_net/training/semisupervised_training.py b/synapse_net/training/semisupervised_training.py index 1c9c0b88..084d7989 100644 --- a/synapse_net/training/semisupervised_training.py +++ b/synapse_net/training/semisupervised_training.py @@ -1,11 +1,14 @@ from typing import Optional, Tuple import numpy as np +import uuid +import h5py import torch import torch_em import torch_em.self_training as self_training from torchvision import transforms +from synapse_net.file_utils import read_mrc from .supervised_training import get_2d_model, get_3d_model, get_supervised_loader, _determine_ndim @@ -28,6 +31,26 @@ def weak_augmentations(p: float = 0.75) -> callable: ]) return torch_em.transform.raw.get_raw_transform(normalizer=norm, augmentation1=aug) +def drop_mask_channel(x): + x = x[:1] + return x + +class ComposedTransform: + def __init__(self, *funcs): + self.funcs = funcs + + def __call__(self, x): + for f in self.funcs: + x = f(x) + return x + +class ChannelSplitterSampler: + def __init__(self, sampler): + self.sampler = sampler + + def __call__(self, x): + raw, mask = x[0], x[1] + return self.sampler(raw, mask) def get_unsupervised_loader( data_paths: Tuple[str], @@ -35,7 +58,9 @@ def get_unsupervised_loader( patch_shape: Tuple[int, int, int], batch_size: int, n_samples: Optional[int], - exclude_top_and_bottom: bool = False, + sample_mask_paths: Optional[Tuple[str]] = None, + sampler: Optional[callable] = None, + exclude_top_and_bottom: bool = False, ) -> torch.utils.data.DataLoader: """Get a dataloader for unsupervised segmentation training. @@ -50,19 +75,46 @@ def get_unsupervised_loader( based on the patch_shape and size of the volumes used for training. exclude_top_and_bottom: Whether to exluce the five top and bottom slices to avoid artifacts at the border of tomograms. + sample_mask_paths: The filepaths to the corresponding sample masks for each tomogram. + sampler: Accept or reject patches based on a condition. Returns: The PyTorch dataloader. """ - # We exclude the top and bottom slices where the tomogram reconstruction is bad. + # TODO this seems unneccesary if we have a boundary mask - remove? if exclude_top_and_bottom: roi = np.s_[5:-5, :, :] else: roi = None + # stack tomograms and masks and write to temp files to use as input to RawDataset() + if sample_mask_paths is not None: + assert len(data_paths) == len(sample_mask_paths), \ + f"Expected equal number of data_paths and and sample_masks_paths, got {len(data_paths)} data paths and {len(sample_mask_paths)} mask paths." + + stacked_paths = [] + for i, (data_path, mask_path) in enumerate(zip(data_paths, sample_mask_paths)): + raw = read_mrc(data_path)[0] + mask = read_mrc(mask_path)[0] + stacked = np.stack([raw, mask], axis=0) + + tmp_path = f"/tmp/stacked{i}_{uuid.uuid4().hex}.h5" + with h5py.File(tmp_path, "w") as f: + f.create_dataset("raw", data=stacked, compression="gzip") + stacked_paths.append(tmp_path) + + # update variables for RawDataset() + data_paths = tuple(stacked_paths) + base_transform = torch_em.transform.get_raw_transform() + raw_transform = ComposedTransform(base_transform, drop_mask_channel) + sampler = ChannelSplitterSampler(sampler) + with_channels = True + else: + raw_transform = torch_em.transform.get_raw_transform() + with_channels = False + sampler = None _, ndim = _determine_ndim(patch_shape) - raw_transform = torch_em.transform.get_raw_transform() transform = torch_em.transform.get_augmentations(ndim=ndim) if n_samples is None: @@ -71,15 +123,17 @@ def get_unsupervised_loader( n_samples_per_ds = int(n_samples / len(data_paths)) augmentations = (weak_augmentations(), weak_augmentations()) + datasets = [ - torch_em.data.RawDataset(path, raw_key, patch_shape, raw_transform, transform, - augmentations=augmentations, roi=roi, ndim=ndim, n_samples=n_samples_per_ds) + torch_em.data.RawDataset(path, raw_key, patch_shape, raw_transform, transform, roi=roi, + n_samples=n_samples_per_ds, sampler=sampler, ndim=ndim, with_channels=with_channels, augmentations=augmentations) for path in data_paths ] ds = torch.utils.data.ConcatDataset(datasets) num_workers = 4 * batch_size - loader = torch_em.segmentation.get_data_loader(ds, batch_size=batch_size, num_workers=num_workers, shuffle=True) + loader = torch_em.segmentation.get_data_loader(ds, batch_size=batch_size, + num_workers=num_workers, shuffle=True) return loader