Skip to content

Commit aa8cd15

Browse files
committed
background mask for unsupervised training
1 parent e6f86fc commit aa8cd15

File tree

2 files changed

+1
-82
lines changed

2 files changed

+1
-82
lines changed

synapse_net/training/domain_adaptation.py

Lines changed: 1 addition & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
from ..inference.inference import get_model_path, compute_scale_from_voxel_size
1919
from ..inference.util import _Scaler
2020

21-
<<<<<<< HEAD
2221
class NewPseudoLabeler(self_training.DefaultPseudoLabeler):
2322
"""Compute pseudo labels based on model predictions, typically from a teacher model.
2423
By default, assumes that the first channel contains the transformed data and the second channel contains the background mask. # TODO update description
@@ -94,9 +93,6 @@ def __call__(self, teacher: torch.nn.Module, input_: torch.Tensor) -> torch.Tens
9493

9594
return pseudo_labels, label_mask
9695

97-
98-
=======
99-
>>>>>>> 9c252ed35b26397634947e7bec01ccb222751af6
10096
def mean_teacher_adaptation(
10197
name: str,
10298
unsupervised_train_paths: Tuple[str],
@@ -115,14 +111,11 @@ def mean_teacher_adaptation(
115111
n_iterations: int = int(1e4),
116112
n_samples_train: Optional[int] = None,
117113
n_samples_val: Optional[int] = None,
118-
<<<<<<< HEAD
119114
train_sample_mask_paths: Optional[Tuple[str]] = None,
120115
val_sample_mask_paths: Optional[Tuple[str]] = None,
121116
train_background_mask_paths: Optional[Tuple[str]] = None,
122-
=======
123117
train_mask_paths: Optional[Tuple[str]] = None,
124118
val_mask_paths: Optional[Tuple[str]] = None,
125-
>>>>>>> 9c252ed35b26397634947e7bec01ccb222751af6
126119
patch_sampler: Optional[callable] = None,
127120
pseudo_label_sampler: Optional[callable] = None,
128121
device: int = 0,
@@ -170,18 +163,15 @@ def mean_teacher_adaptation(
170163
based on the patch_shape and size of the volumes used for training.
171164
n_samples_val: The number of val samples per epoch. By default this will be estimated
172165
based on the patch_shape and size of the volumes used for validation.
173-
<<<<<<< HEAD
174166
train_sample_mask_paths: Boundary masks used by the patch sampler to accept or reject patches for training.
175167
val_sample_mask_paths: Sample masks used by the patch sampler to accept or reject patches for validation.
176168
train_background_mask_paths: # TODO add description
177169
patch_sampler: A sampler for rejecting patches based on a defined conditon.
178170
pseudo_label_sampler: A sampler for rejecting pseudo-labels based on a defined condition.
179-
=======
180171
train_mask_paths: Sample masks used by the patch sampler to accept or reject patches for training.
181172
val_mask_paths: Sample masks used by the patch sampler to accept or reject patches for validation.
182173
patch_sampler: Accept or reject patches based on a condition.
183174
pseudo_label_sampler: Mask out regions of the pseudo labels where the teacher is not confident before updating the gradients.
184-
>>>>>>> 9c252ed35b26397634947e7bec01ccb222751af6
185175
device: GPU ID for training.
186176
"""
187177
assert (supervised_train_paths is None) == (supervised_val_paths is None)
@@ -216,23 +206,16 @@ def mean_teacher_adaptation(
216206

217207
loss = self_training.DefaultSelfTrainingLoss()
218208
loss_and_metric = self_training.DefaultSelfTrainingLossAndMetric()
219-
<<<<<<< HEAD
220-
221-
=======
222-
223-
>>>>>>> 9c252ed35b26397634947e7bec01ccb222751af6
209+
224210
unsupervised_train_loader = get_unsupervised_loader(
225211
data_paths=unsupervised_train_paths,
226212
raw_key=raw_key,
227213
patch_shape=patch_shape,
228214
batch_size=batch_size,
229215
n_samples=n_samples_train,
230-
<<<<<<< HEAD
231216
sample_mask_paths=train_sample_mask_paths,
232217
background_mask_paths=train_background_mask_paths,
233-
=======
234218
sample_mask_paths=train_mask_paths,
235-
>>>>>>> 9c252ed35b26397634947e7bec01ccb222751af6
236219
sampler=patch_sampler
237220
)
238221
unsupervised_val_loader = get_unsupervised_loader(
@@ -241,12 +224,9 @@ def mean_teacher_adaptation(
241224
patch_shape=patch_shape,
242225
batch_size=batch_size,
243226
n_samples=n_samples_val,
244-
<<<<<<< HEAD
245227
sample_mask_paths=val_sample_mask_paths,
246228
background_mask_paths=None,
247-
=======
248229
sample_mask_paths=val_mask_paths,
249-
>>>>>>> 9c252ed35b26397634947e7bec01ccb222751af6
250230
sampler=patch_sampler
251231
)
252232

@@ -289,12 +269,7 @@ def mean_teacher_adaptation(
289269
sampler=pseudo_label_sampler,
290270
)
291271
trainer.fit(n_iterations)
292-
<<<<<<< HEAD
293272

294-
=======
295-
296-
297-
>>>>>>> 9c252ed35b26397634947e7bec01ccb222751af6
298273
# TODO patch shapes for other models
299274
PATCH_SHAPES = {
300275
"vesicles_3d": [48, 256, 256],

synapse_net/training/semisupervised_training.py

Lines changed: 0 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,6 @@ def weak_augmentations(p: float = 0.75) -> callable:
3030
])
3131
return torch_em.transform.raw.get_raw_transform(normalizer=norm, augmentation1=aug)
3232

33-
<<<<<<< HEAD
3433
class DropChannel:
3534
def __init__(self, channel: int):
3635
self.channel = channel
@@ -59,12 +58,6 @@ def __call__(self, data):
5958
output[self.transform_channel] = self.base_transform(data[self.transform_channel])
6059
return output
6160

62-
=======
63-
def drop_mask_channel(x):
64-
x = x[:1]
65-
return x
66-
67-
>>>>>>> 9c252ed35b26397634947e7bec01ccb222751af6
6861
class ComposedTransform:
6962
def __init__(self, *funcs):
7063
self.funcs = funcs
@@ -74,7 +67,6 @@ def __call__(self, x):
7467
x = f(x)
7568
return x
7669

77-
<<<<<<< HEAD
7870
class ChannelWiseAugmentations:
7971
def __init__(self, base_augmentations: callable, transform_channel: int = 0):
8072
self.base_augmentations = base_augmentations
@@ -91,40 +83,29 @@ def __call__(self, data):
9183
output[self.transform_channel] = self.base_augmentations(data[self.transform_channel])
9284
return output
9385

94-
class ChannelSplitterSampler:
95-
=======
9686
class ChannelSplitterSampler:
97-
>>>>>>> 9c252ed35b26397634947e7bec01ccb222751af6
9887
def __init__(self, sampler):
9988
self.sampler = sampler
10089

10190
def __call__(self, x):
10291
raw, mask = x[0], x[1]
10392
return self.sampler(raw, mask)
104-
<<<<<<< HEAD
10593

10694
def get_stacked_path(inputs: List[np.ndarray]):
10795
stacked = np.stack(inputs, axis=0)
10896
tmp_path = f"/tmp/stacked_{uuid.uuid4().hex}.h5"
10997
with h5py.File(tmp_path, "w") as f:
11098
f.create_dataset("raw", data=stacked, compression="gzip")
11199
return tmp_path
112-
=======
113-
>>>>>>> 9c252ed35b26397634947e7bec01ccb222751af6
114100

115101
def get_unsupervised_loader(
116102
data_paths: Tuple[str],
117103
raw_key: str,
118104
patch_shape: Tuple[int, int, int],
119105
batch_size: int,
120-
<<<<<<< HEAD
121106
n_samples: Optional[int] = None,
122107
sample_mask_paths: Optional[Tuple[str]] = None,
123108
background_mask_paths: Tuple[str] = None,
124-
=======
125-
n_samples: Optional[int],
126-
sample_mask_paths: Optional[Tuple[str]] = None,
127-
>>>>>>> 9c252ed35b26397634947e7bec01ccb222751af6
128109
sampler: Optional[callable] = None,
129110
exclude_top_and_bottom: bool = False,
130111
) -> torch.utils.data.DataLoader:
@@ -142,10 +123,7 @@ def get_unsupervised_loader(
142123
exclude_top_and_bottom: Whether to exluce the five top and bottom slices to
143124
avoid artifacts at the border of tomograms.
144125
sample_mask_paths: The filepaths to the corresponding sample masks for each tomogram.
145-
<<<<<<< HEAD
146126
background_mask_paths: TODO add description
147-
=======
148-
>>>>>>> 9c252ed35b26397634947e7bec01ccb222751af6
149127
sampler: Accept or reject patches based on a condition.
150128
151129
Returns:
@@ -157,40 +135,11 @@ def get_unsupervised_loader(
157135
roi = np.s_[5:-5, :, :]
158136
else:
159137
roi = None
160-
<<<<<<< HEAD
161138

162139
# initialize class instances
163140
base_transform = torch_em.transform.get_raw_transform()
164141
channelwise_raw_transform = ChannelWiseRawTransform(base_transform)
165142
drop_channel = DropChannel(channel = 1)
166-
=======
167-
# stack tomograms and masks and write to temp files to use as input to RawDataset()
168-
if sample_mask_paths is not None:
169-
assert len(data_paths) == len(sample_mask_paths), \
170-
f"Expected equal number of data_paths and and sample_masks_paths, got {len(data_paths)} data paths and {len(sample_mask_paths)} mask paths."
171-
172-
stacked_paths = []
173-
for i, (data_path, mask_path) in enumerate(zip(data_paths, sample_mask_paths)):
174-
raw = read_mrc(data_path)[0]
175-
mask = read_mrc(mask_path)[0]
176-
stacked = np.stack([raw, mask], axis=0)
177-
178-
tmp_path = f"/tmp/stacked{i}_{uuid.uuid4().hex}.h5"
179-
with h5py.File(tmp_path, "w") as f:
180-
f.create_dataset("raw", data=stacked, compression="gzip")
181-
stacked_paths.append(tmp_path)
182-
183-
# update variables for RawDataset()
184-
data_paths = tuple(stacked_paths)
185-
base_transform = torch_em.transform.get_raw_transform()
186-
raw_transform = ComposedTransform(base_transform, drop_mask_channel)
187-
sampler = ChannelSplitterSampler(sampler)
188-
with_channels = True
189-
else:
190-
raw_transform = torch_em.transform.get_raw_transform()
191-
with_channels = False
192-
sampler = None
193-
>>>>>>> 9c252ed35b26397634947e7bec01ccb222751af6
194143

195144
# get configurations
196145
has_sample_mask = sample_mask_paths is not None
@@ -256,11 +205,6 @@ def get_unsupervised_loader(
256205
else:
257206
n_samples_per_ds = int(n_samples / len(data_paths))
258207

259-
<<<<<<< HEAD
260-
=======
261-
augmentations = (weak_augmentations(), weak_augmentations())
262-
263-
>>>>>>> 9c252ed35b26397634947e7bec01ccb222751af6
264208
datasets = [
265209
torch_em.data.RawDataset(path, raw_key, patch_shape, raw_transform, transform, roi=roi,
266210
n_samples=n_samples_per_ds, sampler=sampler, ndim=ndim, with_channels=with_channels, augmentations=augmentations)

0 commit comments

Comments
 (0)