-
Notifications
You must be signed in to change notification settings - Fork 3
boundary mask for unsupervised training #132
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
constantinpape
merged 6 commits into
computational-cell-analytics:main
from
stmartineau99:main
Jul 18, 2025
Merged
Changes from 5 commits
Commits
Show all changes
6 commits
Select commit
Hold shift + click to select a range
4c17c0d
boundary mask for unsupervised training
6e865c3
implement lamella mask
stmartineau99 2df5a6c
boundary mask for unsupervised training
stmartineau99 0257cdb
boundary mask for unsupervised training
stmartineau99 3e454d7
boundary mask for unsupervised training
stmartineau99 702138f
Update synapse_net/training/domain_adaptation.py
constantinpape File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,11 +1,14 @@ | ||
| from typing import Optional, Tuple | ||
|
|
||
| import numpy as np | ||
| import uuid | ||
| import h5py | ||
| import torch | ||
| import torch_em | ||
| import torch_em.self_training as self_training | ||
| from torchvision import transforms | ||
|
|
||
| from synapse_net.file_utils import read_mrc | ||
| from .supervised_training import get_2d_model, get_3d_model, get_supervised_loader, _determine_ndim | ||
|
|
||
|
|
||
|
|
@@ -28,14 +31,36 @@ def weak_augmentations(p: float = 0.75) -> callable: | |
| ]) | ||
| return torch_em.transform.raw.get_raw_transform(normalizer=norm, augmentation1=aug) | ||
|
|
||
| def drop_mask_channel(x): | ||
| x = x[:1] | ||
| return x | ||
|
|
||
| class ComposedTransform: | ||
| def __init__(self, *funcs): | ||
| self.funcs = funcs | ||
|
|
||
| def __call__(self, x): | ||
| for f in self.funcs: | ||
| x = f(x) | ||
| return x | ||
|
|
||
| class ChannelSplitterSampler: | ||
| def __init__(self, sampler): | ||
| self.sampler = sampler | ||
|
|
||
| def __call__(self, x): | ||
| raw, mask = x[0], x[1] | ||
| return self.sampler(raw, mask) | ||
|
|
||
| def get_unsupervised_loader( | ||
| data_paths: Tuple[str], | ||
| raw_key: str, | ||
| patch_shape: Tuple[int, int, int], | ||
| batch_size: int, | ||
| n_samples: Optional[int], | ||
| exclude_top_and_bottom: bool = False, | ||
| sample_mask_paths: Optional[Tuple[str]] = None, | ||
| sampler: Optional[callable] = None, | ||
| exclude_top_and_bottom: bool = False, | ||
| ) -> torch.utils.data.DataLoader: | ||
| """Get a dataloader for unsupervised segmentation training. | ||
|
|
||
|
|
@@ -50,19 +75,46 @@ def get_unsupervised_loader( | |
| based on the patch_shape and size of the volumes used for training. | ||
| exclude_top_and_bottom: Whether to exclude the five top and bottom slices to | ||
| avoid artifacts at the border of tomograms. | ||
| sample_mask_paths: The filepaths to the corresponding sample masks for each tomogram. | ||
| sampler: Accept or reject patches based on a condition. | ||
|
|
||
| Returns: | ||
| The PyTorch dataloader. | ||
| """ | ||
|
|
||
| # We exclude the top and bottom slices where the tomogram reconstruction is bad. | ||
| # TODO this seems unnecessary if we have a boundary mask - remove? | ||
| if exclude_top_and_bottom: | ||
| roi = np.s_[5:-5, :, :] | ||
| else: | ||
| roi = None | ||
| # stack tomograms and masks and write to temp files to use as input to RawDataset() | ||
| if sample_mask_paths is not None: | ||
| assert len(data_paths) == len(sample_mask_paths), \ | ||
| f"Expected equal number of data_paths and and sample_masks_paths, got {len(data_paths)} data paths and {len(sample_mask_paths)} mask paths." | ||
|
|
||
| stacked_paths = [] | ||
| for i, (data_path, mask_path) in enumerate(zip(data_paths, sample_mask_paths)): | ||
| raw = read_mrc(data_path)[0] | ||
| mask = read_mrc(mask_path)[0] | ||
| stacked = np.stack([raw, mask], axis=0) | ||
|
|
||
| tmp_path = f"/tmp/stacked{i}_{uuid.uuid4().hex}.h5" | ||
| with h5py.File(tmp_path, "w") as f: | ||
| f.create_dataset("raw", data=stacked, compression="gzip") | ||
| stacked_paths.append(tmp_path) | ||
|
|
||
| # update variables for RawDataset() | ||
| data_paths = tuple(stacked_paths) | ||
| base_transform = torch_em.transform.get_raw_transform() | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This should be adapted to only act on channel 0 (the actual raw data.) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Let's leave this for the next PR. |
||
| raw_transform = ComposedTransform(base_transform, drop_mask_channel) | ||
| sampler = ChannelSplitterSampler(sampler) | ||
| with_channels = True | ||
| else: | ||
| raw_transform = torch_em.transform.get_raw_transform() | ||
| with_channels = False | ||
| sampler = None | ||
|
|
||
| _, ndim = _determine_ndim(patch_shape) | ||
| raw_transform = torch_em.transform.get_raw_transform() | ||
| transform = torch_em.transform.get_augmentations(ndim=ndim) | ||
|
|
||
| if n_samples is None: | ||
|
|
@@ -71,15 +123,17 @@ def get_unsupervised_loader( | |
| n_samples_per_ds = int(n_samples / len(data_paths)) | ||
|
|
||
| augmentations = (weak_augmentations(), weak_augmentations()) | ||
|
|
||
| datasets = [ | ||
| torch_em.data.RawDataset(path, raw_key, patch_shape, raw_transform, transform, | ||
| augmentations=augmentations, roi=roi, ndim=ndim, n_samples=n_samples_per_ds) | ||
| torch_em.data.RawDataset(path, raw_key, patch_shape, raw_transform, transform, roi=roi, | ||
| n_samples=n_samples_per_ds, sampler=sampler, ndim=ndim, with_channels=with_channels, augmentations=augmentations) | ||
| for path in data_paths | ||
| ] | ||
| ds = torch.utils.data.ConcatDataset(datasets) | ||
|
|
||
| num_workers = 4 * batch_size | ||
| loader = torch_em.segmentation.get_data_loader(ds, batch_size=batch_size, num_workers=num_workers, shuffle=True) | ||
| loader = torch_em.segmentation.get_data_loader(ds, batch_size=batch_size, | ||
| num_workers=num_workers, shuffle=True) | ||
| return loader | ||
|
|
||
|
|
||
|
|
||
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.