
Commit 4c17c0d

Sage authored and committed
boundary mask for unsupervised training
1 parent 928f330 commit 4c17c0d

File tree

synapse_net/training/domain_adaptation.py
synapse_net/training/semisupervised_training.py

2 files changed: +83, -13 lines


synapse_net/training/domain_adaptation.py

Lines changed: 22 additions & 5 deletions
@@ -18,7 +18,6 @@
 from ..inference.inference import get_model_path, compute_scale_from_voxel_size
 from ..inference.util import _Scaler
 
-
 def mean_teacher_adaptation(
     name: str,
     unsupervised_train_paths: Tuple[str],
@@ -37,7 +36,10 @@ def mean_teacher_adaptation(
     n_iterations: int = int(1e4),
     n_samples_train: Optional[int] = None,
     n_samples_val: Optional[int] = None,
+    train_mask_paths: Optional[Tuple[str]] = None,
+    val_mask_paths: Optional[Tuple[str]] = None,
     sampler: Optional[callable] = None,
+    device: int = 0,
 ) -> None:
     """Run domain adapation to transfer a network trained on a source domain for a supervised
     segmentation task to perform this task on a different target domain.
@@ -82,6 +84,10 @@ def mean_teacher_adaptation(
             based on the patch_shape and size of the volumes used for training.
         n_samples_val: The number of val samples per epoch. By default this will be estimated
             based on the patch_shape and size of the volumes used for validation.
+        train_mask_paths: Boundary masks used by the sampler to accept or reject patches for training.
+        val_mask_paths: Boundary masks used by the sampler to accept or reject patches for validation.
+        sampler: Accept or reject patches based on a condition.
+        device: GPU ID for training.
     """
     assert (supervised_train_paths is None) == (supervised_val_paths is None)
     is_2d, _ = _determine_ndim(patch_shape)
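Note: the sampler here is any callable that returns True to accept a patch and False to reject it; with boundary masks it is invoked with the raw patch and the matching mask patch (see ChannelSplitterSampler in semisupervised_training.py below). A minimal sketch of such a sampler; the function name and the 25% threshold are illustrative, not part of this commit:

import numpy as np

def min_mask_fraction_sampler(raw: np.ndarray, mask: np.ndarray, min_fraction: float = 0.25) -> bool:
    # Accept the patch only if at least `min_fraction` of its voxels lie inside the boundary mask.
    return float((mask > 0).mean()) >= min_fraction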
@@ -113,10 +119,21 @@ def mean_teacher_adaptation(
     loss_and_metric = self_training.DefaultSelfTrainingLossAndMetric()
 
     unsupervised_train_loader = get_unsupervised_loader(
-        unsupervised_train_paths, raw_key, patch_shape, batch_size, n_samples=n_samples_train
+        data_paths=unsupervised_train_paths,
+        raw_key=raw_key,
+        patch_shape=patch_shape,
+        batch_size=batch_size,
+        n_samples=n_samples_train,
+        boundary_mask_paths=train_mask_paths,
+        sampler=sampler
     )
     unsupervised_val_loader = get_unsupervised_loader(
-        unsupervised_val_paths, raw_key, patch_shape, batch_size, n_samples=n_samples_val
+        data_paths=unsupervised_val_paths,
+        raw_key=raw_key,
+        patch_shape=patch_shape,
+        batch_size=batch_size,
+        n_samples=n_samples_val,
+        boundary_mask_paths=val_mask_paths, sampler=sampler
     )
 
     if supervised_train_paths is not None:
@@ -133,7 +150,7 @@ def mean_teacher_adaptation(
         supervised_train_loader = None
         supervised_val_loader = None
 
-    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+    device = torch.device(f"cuda:{device}") if torch.cuda.is_available() else torch.device("cpu")
     trainer = self_training.MeanTeacherTrainer(
         name=name,
         model=model,
@@ -155,7 +172,7 @@ def mean_teacher_adaptation(
         device=device,
         reinit_teacher=reinit_teacher,
         save_root=save_root,
-        sampler=sampler,
+        sampler=None,  # TODO currently set to none cause I didn't want to pass the same sampler used by get_unsupervised_loader
     )
     trainer.fit(n_iterations)
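A usage sketch for the new arguments. All file paths are placeholders, raw_key="raw" matches the key under which the stacked temporary volumes are written (see semisupervised_training.py below), and the remaining keyword arguments are assumed to keep their defaults:

from synapse_net.training.domain_adaptation import mean_teacher_adaptation

mean_teacher_adaptation(
    name="mean-teacher-with-boundary-mask",
    unsupervised_train_paths=("/data/tomo_train.mrc",),
    unsupervised_val_paths=("/data/tomo_val.mrc",),
    raw_key="raw",
    patch_shape=(48, 256, 256),
    train_mask_paths=("/data/tomo_train_mask.mrc",),
    val_mask_paths=("/data/tomo_val_mask.mrc",),
    sampler=min_mask_fraction_sampler,  # e.g. the sketch above
    device=0,  # GPU ID, forwarded to torch.device(f"cuda:{device}")
)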
synapse_net/training/semisupervised_training.py

Lines changed: 61 additions & 8 deletions
@@ -1,11 +1,14 @@
 from typing import Optional, Tuple
 
 import numpy as np
+import uuid
+import h5py
 import torch
 import torch_em
 import torch_em.self_training as self_training
 from torchvision import transforms
 
+from synapse_net.file_utils import read_mrc
 from .supervised_training import get_2d_model, get_3d_model, get_supervised_loader, _determine_ndim
@@ -28,14 +31,36 @@ def weak_augmentations(p: float = 0.75) -> callable:
     ])
     return torch_em.transform.raw.get_raw_transform(normalizer=norm, augmentation1=aug)
 
-
+def drop_mask_channel(x):
+    x = x[:1]
+    return x
+
+class ComposedTransform:
+    def __init__(self, *funcs):
+        self.funcs = funcs
+
+    def __call__(self, x):
+        for f in self.funcs:
+            x = f(x)
+        return x
+
+class ChannelSplitterSampler:
+    def __init__(self, sampler):
+        self.sampler = sampler
+
+    def __call__(self, x):
+        raw, mask = x[0], x[1]
+        return self.sampler(raw, mask)
+
 def get_unsupervised_loader(
     data_paths: Tuple[str],
     raw_key: str,
     patch_shape: Tuple[int, int, int],
     batch_size: int,
     n_samples: Optional[int],
-    exclude_top_and_bottom: bool = False,
+    boundary_mask_paths: Optional[Tuple[str]] = None,
+    sampler: Optional[callable] = None,
+    exclude_top_and_bottom: bool = False,  # TODO this seems unneccesary if we have a boundary mask - remove?
 ) -> torch.utils.data.DataLoader:
     """Get a dataloader for unsupervised segmentation training.
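How the new helpers fit together: with a boundary mask, the dataset yields a stacked (2, z, y, x) patch; ChannelSplitterSampler splits it and forwards (raw, mask) to the user-supplied sampler, and ComposedTransform chains the usual raw transform with drop_mask_channel so that only the raw channel reaches the network. A small self-contained check with random data (array sizes and the 25% threshold are arbitrary):

import numpy as np
from synapse_net.training.semisupervised_training import (
    ChannelSplitterSampler, ComposedTransform, drop_mask_channel
)

patch = np.random.rand(2, 8, 32, 32)  # channel 0: raw, channel 1: boundary mask

accept = ChannelSplitterSampler(lambda raw, mask: (mask > 0.5).mean() >= 0.25)
print(accept(patch))           # True or False, decided from the mask channel only

strip = ComposedTransform(lambda x: x, drop_mask_channel)  # identity stands in for the raw transform
print(strip(patch).shape)      # (1, 8, 32, 32): the mask channel is dropped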
@@ -50,19 +75,46 @@ def get_unsupervised_loader(
             based on the patch_shape and size of the volumes used for training.
         exclude_top_and_bottom: Whether to exluce the five top and bottom slices to
             avoid artifacts at the border of tomograms.
+        boundary_mask_paths: The filepaths to the corresponding boundary masks for each tomogram.
+        sampler: Accept or reject patches based on a condition.
 
     Returns:
         The PyTorch dataloader.
     """
-
     # We exclude the top and bottom slices where the tomogram reconstruction is bad.
+    # TODO this seems unneccesary if we have a boundary mask - remove?
     if exclude_top_and_bottom:
         roi = np.s_[5:-5, :, :]
     else:
         roi = None
+    # stack tomograms and masks and write to temp files to use as input to RawDataset()
+    if boundary_mask_paths is not None:
+        assert len(data_paths) == len(boundary_mask_paths), \
+            f"Expected equal number of data_paths and and boundary_masks_paths, got {len(data_paths)} data paths and {len(boundary_mask_paths)} mask paths."
+
+        stacked_paths = []
+        for i, (data_path, mask_path) in enumerate(zip(data_paths, boundary_mask_paths)):
+            raw = read_mrc(data_path)[0]
+            mask = read_mrc(mask_path)[0]
+            stacked = np.stack([raw, mask], axis=0)
+
+            tmp_path = f"/tmp/stacked{i}_{uuid.uuid4().hex}.h5"
+            with h5py.File(tmp_path, "w") as f:
+                f.create_dataset("raw", data=stacked, compression="gzip")
+            stacked_paths.append(tmp_path)
+
+        # update variables for RawDataset()
+        data_paths = tuple(stacked_paths)
+        base_transform = torch_em.transform.get_raw_transform()
+        raw_transform = ComposedTransform(base_transform, drop_mask_channel)
+        sampler = ChannelSplitterSampler(sampler)
+        with_channels = True
+    else:
+        raw_transform = torch_em.transform.get_raw_transform()
+        with_channels = False
+        sampler = None
 
     _, ndim = _determine_ndim(patch_shape)
-    raw_transform = torch_em.transform.get_raw_transform()
     transform = torch_em.transform.get_augmentations(ndim=ndim)
 
     if n_samples is None:
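The boundary masks are loaded with read_mrc, so each one needs to exist on disk as an MRC volume with the same shape as its tomogram, nonzero inside the region patches may be drawn from. One way such a mask file could be written using the mrcfile library (an assumption, this commit does not use mrcfile; the mask itself is a placeholder that simply trims the volume borders):

import mrcfile
import numpy as np
from synapse_net.file_utils import read_mrc

tomogram = read_mrc("/data/tomo_train.mrc")[0]  # same indexing as in the loader above

# Placeholder mask: accept everything except a 32 voxel border in y and x.
mask = np.zeros(tomogram.shape, dtype=np.float32)
mask[:, 32:-32, 32:-32] = 1

with mrcfile.new("/data/tomo_train_mask.mrc", overwrite=True) as f:
    f.set_data(mask)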
@@ -71,18 +123,19 @@ def get_unsupervised_loader(
         n_samples_per_ds = int(n_samples / len(data_paths))
 
     augmentations = (weak_augmentations(), weak_augmentations())
+
     datasets = [
-        torch_em.data.RawDataset(path, raw_key, patch_shape, raw_transform, transform,
-                                 augmentations=augmentations, roi=roi, ndim=ndim, n_samples=n_samples_per_ds)
+        torch_em.data.RawDataset(path, raw_key, patch_shape, raw_transform, transform, roi=roi,
+                                 n_samples=n_samples_per_ds, sampler=sampler, ndim=ndim, with_channels=with_channels, augmentations=augmentations)
         for path in data_paths
     ]
     ds = torch.utils.data.ConcatDataset(datasets)
 
     num_workers = 4 * batch_size
-    loader = torch_em.segmentation.get_data_loader(ds, batch_size=batch_size, num_workers=num_workers, shuffle=True)
+    loader = torch_em.segmentation.get_data_loader(ds, batch_size=batch_size,
+                                                   num_workers=num_workers, shuffle=True)
     return loader
 
-
 # TODO: use different paths for supervised and unsupervised training
 # (We are currently not using this functionality directly, so this is not a high priority)
 def semisupervised_training(
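A direct usage sketch for the updated loader. Paths are placeholders, raw_key="raw" matches the dataset key of the temporary stacked volumes, and the lambda is a stand-in for any (raw, mask) sampler:

from synapse_net.training.semisupervised_training import get_unsupervised_loader

loader = get_unsupervised_loader(
    data_paths=("/data/tomo_train.mrc",),
    raw_key="raw",
    patch_shape=(48, 256, 256),
    batch_size=2,
    n_samples=100,
    boundary_mask_paths=("/data/tomo_train_mask.mrc",),
    sampler=lambda raw, mask: (mask > 0).mean() >= 0.25,
)

The resulting loader can then be passed wherever the unsupervised train or val loaders are expected, e.g. inside mean_teacher_adaptation above; the mask stacking, channel splitting, and channel dropping all happen internally.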

0 commit comments

Comments
 (0)