# boundary mask for unsupervised training #132
Changes to `mean_teacher_adaptation`:

```diff
@@ -18,7 +18,6 @@
 from ..inference.inference import get_model_path, compute_scale_from_voxel_size
 from ..inference.util import _Scaler


 def mean_teacher_adaptation(
     name: str,
     unsupervised_train_paths: Tuple[str],
```
```diff
@@ -37,7 +36,10 @@ def mean_teacher_adaptation(
     n_iterations: int = int(1e4),
     n_samples_train: Optional[int] = None,
     n_samples_val: Optional[int] = None,
+    train_mask_paths: Optional[Tuple[str]] = None,
+    val_mask_paths: Optional[Tuple[str]] = None,
+    sampler: Optional[callable] = None,
     device: int = 0,
 ) -> None:
     """Run domain adaptation to transfer a network trained on a source domain for a supervised
     segmentation task to perform this task on a different target domain.
```
```diff
@@ -82,6 +84,10 @@ def mean_teacher_adaptation(
             based on the patch_shape and size of the volumes used for training.
         n_samples_val: The number of val samples per epoch. By default this will be estimated
             based on the patch_shape and size of the volumes used for validation.
+        train_mask_paths: Boundary masks used by the sampler to accept or reject patches for training.
+        val_mask_paths: Boundary masks used by the sampler to accept or reject patches for validation.
+        sampler: Accept or reject patches based on a condition.
+
         device: GPU ID for training.
     """
     assert (supervised_train_paths is None) == (supervised_val_paths is None)
     is_2d, _ = _determine_ndim(patch_shape)
```
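For orientation (not part of the diff): the sampler passed here is ultimately called with the raw patch and the mask patch as two separate arguments, via the `ChannelSplitterSampler` introduced in the second file below. A minimal sketch of a compatible sampler — the class name and threshold are made up for illustration:

```python
import numpy as np


class MaskFractionSampler:
    """Hypothetical sampler: accept a patch if enough of it lies inside the boundary mask."""

    def __init__(self, min_fraction: float = 0.5):
        self.min_fraction = min_fraction

    def __call__(self, raw: np.ndarray, mask: np.ndarray) -> bool:
        # The dataset retries patches until the sampler returns True.
        return (mask > 0).mean() >= self.min_fraction
```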
```diff
@@ -113,10 +119,21 @@ def mean_teacher_adaptation(
     loss_and_metric = self_training.DefaultSelfTrainingLossAndMetric()

     unsupervised_train_loader = get_unsupervised_loader(
-        unsupervised_train_paths, raw_key, patch_shape, batch_size, n_samples=n_samples_train
+        data_paths=unsupervised_train_paths,
+        raw_key=raw_key,
+        patch_shape=patch_shape,
+        batch_size=batch_size,
+        n_samples=n_samples_train,
+        boundary_mask_paths=train_mask_paths,
+        sampler=sampler
     )
     unsupervised_val_loader = get_unsupervised_loader(
-        unsupervised_val_paths, raw_key, patch_shape, batch_size, n_samples=n_samples_val
+        data_paths=unsupervised_val_paths,
+        raw_key=raw_key,
+        patch_shape=patch_shape,
+        batch_size=batch_size,
+        n_samples=n_samples_val,
+        boundary_mask_paths=val_mask_paths, sampler=sampler
     )

     if supervised_train_paths is not None:
```
```diff
@@ -133,7 +150,7 @@ def mean_teacher_adaptation(
         supervised_train_loader = None
         supervised_val_loader = None

-    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+    device = torch.device(f"cuda:{device}") if torch.cuda.is_available() else torch.device("cpu")
     trainer = self_training.MeanTeacherTrainer(
         name=name,
         model=model,
```
```diff
@@ -155,7 +172,7 @@ def mean_teacher_adaptation(
         device=device,
         reinit_teacher=reinit_teacher,
         save_root=save_root,
-        sampler=sampler,
+        sampler=None,  # TODO currently set to None because I didn't want to pass the same sampler used by get_unsupervised_loader
     )
     trainer.fit(n_iterations)
```
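Putting the changes to this file together, a call might look like the sketch below. All paths and the run name are hypothetical, `MaskFractionSampler` is the illustrative sampler sketched above, and the remaining arguments keep their defaults:

```python
from synapse_net.training.domain_adaptation import mean_teacher_adaptation  # module path assumed

mean_teacher_adaptation(
    name="vesicle-da-with-masks",                              # hypothetical checkpoint name
    unsupervised_train_paths=("tomo_a.mrc", "tomo_b.mrc"),     # hypothetical tomograms
    unsupervised_val_paths=("tomo_c.mrc",),
    raw_key="raw",
    patch_shape=(48, 256, 256),
    train_mask_paths=("tomo_a_mask.mrc", "tomo_b_mask.mrc"),   # one boundary mask per tomogram
    val_mask_paths=("tomo_c_mask.mrc",),
    sampler=MaskFractionSampler(min_fraction=0.5),
    device=0,
)
```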
Changes to the module that provides `get_unsupervised_loader`:
| @@ -1,11 +1,14 @@ | ||
| from typing import Optional, Tuple | ||
|
|
||
| import numpy as np | ||
| import uuid | ||
| import h5py | ||
| import torch | ||
| import torch_em | ||
| import torch_em.self_training as self_training | ||
| from torchvision import transforms | ||
|
|
||
| from synapse_net.file_utils import read_mrc | ||
| from .supervised_training import get_2d_model, get_3d_model, get_supervised_loader, _determine_ndim | ||
|
|
||
|
|
||
|
|
```diff
@@ -28,14 +31,36 @@ def weak_augmentations(p: float = 0.75) -> callable:
     ])
     return torch_em.transform.raw.get_raw_transform(normalizer=norm, augmentation1=aug)


+def drop_mask_channel(x):
+    x = x[:1]
+    return x
+
+
+class ComposedTransform:
+    def __init__(self, *funcs):
+        self.funcs = funcs
+
+    def __call__(self, x):
+        for f in self.funcs:
+            x = f(x)
+        return x
+
+
+class ChannelSplitterSampler:
+    def __init__(self, sampler):
+        self.sampler = sampler
+
+    def __call__(self, x):
+        raw, mask = x[0], x[1]
+        return self.sampler(raw, mask)
+
+
 def get_unsupervised_loader(
     data_paths: Tuple[str],
     raw_key: str,
     patch_shape: Tuple[int, int, int],
     batch_size: int,
     n_samples: Optional[int],
-    exclude_top_and_bottom: bool = False,
+    boundary_mask_paths: Optional[Tuple[str]] = None,
+    sampler: Optional[callable] = None,
+    exclude_top_and_bottom: bool = False,  # TODO this seems unnecessary if we have a boundary mask - remove?
 ) -> torch.utils.data.DataLoader:
     """Get a dataloader for unsupervised segmentation training.
```
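To make the interplay of these three helpers concrete (illustration only, assuming they are in scope, and reusing the hypothetical `MaskFractionSampler` from above): the raw transform acts on the stacked two-channel array and then drops the mask channel, while the sampler wrapper splits the channels apart again before deciding on the patch.

```python
import numpy as np

# A stacked patch as built further below: channel 0 is the tomogram, channel 1 the mask.
stacked = np.stack([np.random.rand(8, 64, 64), np.ones((8, 64, 64))], axis=0)

# ComposedTransform applies its functions in order; an identity stands in
# for torch_em's raw transform here, followed by dropping the mask channel.
raw_transform = ComposedTransform(lambda x: x, drop_mask_channel)
print(raw_transform(stacked).shape)  # -> (1, 8, 64, 64), mask channel removed

# ChannelSplitterSampler unpacks raw and mask before calling the wrapped sampler.
sampler = ChannelSplitterSampler(MaskFractionSampler(min_fraction=0.5))
print(sampler(stacked))  # -> True, since the all-ones mask passes the threshold
```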
```diff
@@ -50,19 +75,46 @@ def get_unsupervised_loader(
             based on the patch_shape and size of the volumes used for training.
         exclude_top_and_bottom: Whether to exclude the five top and bottom slices to
             avoid artifacts at the border of tomograms.
+        boundary_mask_paths: The filepaths to the corresponding boundary masks for each tomogram.
+        sampler: Accept or reject patches based on a condition.
     Returns:
         The PyTorch dataloader.
     """

     # We exclude the top and bottom slices where the tomogram reconstruction is bad.
+    # TODO this seems unnecessary if we have a boundary mask - remove?
     if exclude_top_and_bottom:
         roi = np.s_[5:-5, :, :]
     else:
         roi = None
+    # stack tomograms and masks and write to temp files to use as input to RawDataset()
+    if boundary_mask_paths is not None:
+        assert len(data_paths) == len(boundary_mask_paths), \
+            f"Expected equal number of data_paths and boundary_mask_paths, " \
+            f"got {len(data_paths)} data paths and {len(boundary_mask_paths)} mask paths."
+
+        stacked_paths = []
+        for i, (data_path, mask_path) in enumerate(zip(data_paths, boundary_mask_paths)):
+            raw = read_mrc(data_path)[0]
+            mask = read_mrc(mask_path)[0]
+            stacked = np.stack([raw, mask], axis=0)
+
+            tmp_path = f"/tmp/stacked{i}_{uuid.uuid4().hex}.h5"
+            with h5py.File(tmp_path, "w") as f:
+                f.create_dataset("raw", data=stacked, compression="gzip")
+            stacked_paths.append(tmp_path)
+
+        # update variables for RawDataset()
+        data_paths = tuple(stacked_paths)
+        base_transform = torch_em.transform.get_raw_transform()
```
**Review comment** on `base_transform = torch_em.transform.get_raw_transform()`:

> This should be adapted to only act on channel 0 (the actual raw data).

> Let's leave this for the next PR.
```diff
+        raw_transform = ComposedTransform(base_transform, drop_mask_channel)
+        sampler = ChannelSplitterSampler(sampler)
+        with_channels = True
+    else:
+        raw_transform = torch_em.transform.get_raw_transform()
+        with_channels = False
+        sampler = None

     _, ndim = _determine_ndim(patch_shape)
-    raw_transform = torch_em.transform.get_raw_transform()
     transform = torch_em.transform.get_augmentations(ndim=ndim)

     if n_samples is None:
```
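As a follow-up idea for the review comment above (explicitly deferred to the next PR): the base transform could be wrapped so that normalization statistics are computed from channel 0 only, leaving the mask channel untouched. A sketch with a made-up class name:

```python
import numpy as np


class RawChannelOnlyTransform:
    """Hypothetical wrapper: apply a transform to the raw channel only."""

    def __init__(self, transform):
        self.transform = transform

    def __call__(self, x):
        raw = self.transform(x[:1])       # normalize only the raw data in channel 0
        mask = x[1:].astype(raw.dtype)    # pass the mask channel through unchanged
        return np.concatenate([raw, mask], axis=0)


# Hypothetical usage in the branch above:
# base_transform = RawChannelOnlyTransform(torch_em.transform.get_raw_transform())
# raw_transform = ComposedTransform(base_transform, drop_mask_channel)
```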
```diff
@@ -71,18 +123,19 @@ def get_unsupervised_loader(
         n_samples_per_ds = int(n_samples / len(data_paths))

     augmentations = (weak_augmentations(), weak_augmentations())

     datasets = [
-        torch_em.data.RawDataset(path, raw_key, patch_shape, raw_transform, transform,
-                                 augmentations=augmentations, roi=roi, ndim=ndim, n_samples=n_samples_per_ds)
+        torch_em.data.RawDataset(path, raw_key, patch_shape, raw_transform, transform, roi=roi,
+                                 n_samples=n_samples_per_ds, sampler=sampler, ndim=ndim, with_channels=with_channels,
+                                 augmentations=augmentations)
         for path in data_paths
     ]
     ds = torch.utils.data.ConcatDataset(datasets)

     num_workers = 4 * batch_size
-    loader = torch_em.segmentation.get_data_loader(ds, batch_size=batch_size, num_workers=num_workers, shuffle=True)
+    loader = torch_em.segmentation.get_data_loader(ds, batch_size=batch_size,
+                                                   num_workers=num_workers, shuffle=True)
     return loader


 # TODO: use different paths for supervised and unsupervised training
 # (We are currently not using this functionality directly, so this is not a high priority)
 def semisupervised_training(
```
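For completeness, exercising the new loader path directly might look like the sketch below (module path and file names assumed; note that `raw_key` should be `"raw"` when masks are given, since the stacked temp volumes are written under that key):

```python
from synapse_net.training.semisupervised_training import get_unsupervised_loader  # module path assumed

loader = get_unsupervised_loader(
    data_paths=("tomo_a.mrc", "tomo_b.mrc"),      # hypothetical tomograms
    raw_key="raw",                                # the stacked temp HDF5 files store data under "raw"
    patch_shape=(48, 256, 256),
    batch_size=1,
    n_samples=100,
    boundary_mask_paths=("tomo_a_mask.mrc", "tomo_b_mask.mrc"),
    sampler=MaskFractionSampler(min_fraction=0.5),
)
x1, x2 = next(iter(loader))  # two weakly augmented views of the same accepted patch
```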
**Review thread** on the "boundary mask" naming:

> Minor semantic comment: I think that this is not necessarily a boundary mask. I think just calling it mask is more precise.

> Jonathan's lamella masker uses the term "boundary mask", so that is why I used it. It makes sense because the mask defines the boundary of the signal.

> Since we are using three different masks in this pipeline now (gradient mask, boundary mask, membrane mask), we need different words to describe them. Correct me if there is a clearer way to refer to it.

> In the context here: