Skip to content

Commit 712714f

Browse files
committed
optional background mask for unsupervised training
1 parent 3e454d7 commit 712714f

File tree

2 files changed

+202
-183
lines changed

2 files changed

+202
-183
lines changed
Lines changed: 97 additions & 158 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,88 @@
11
import os
2-
import tempfile
3-
from glob import glob
4-
from pathlib import Path
52
from typing import Optional, Tuple
63

7-
import mrcfile
84
import torch
95
import torch_em
106
import torch_em.self_training as self_training
11-
from elf.io import open_file
12-
from sklearn.model_selection import train_test_split
137

148
from .semisupervised_training import get_unsupervised_loader
15-
from .supervised_training import (
16-
get_2d_model, get_3d_model, get_supervised_loader, _determine_ndim, _derive_key_from_files
17-
)
18-
from ..inference.inference import get_model_path, compute_scale_from_voxel_size
19-
from ..inference.util import _Scaler
9+
from .supervised_training import get_2d_model, get_3d_model, get_supervised_loader, _determine_ndim
10+
11+
class NewPseudoLabeler(self_training.DefaultPseudoLabeler):
    """Compute pseudo labels from model predictions and optionally zero out a background region.

    If `background_mask_channel` is set, the input batch is expected to hold the raw data and a
    background mask in separate channels. Only the raw channel is passed to the teacher, and the
    pseudo labels are set to zero wherever the background mask is non-zero.
    With the default settings the raw data is read from channel 0 and the background mask from channel 1.

    Args:
        activation: Activation function applied to the teacher prediction.
        confidence_threshold: Threshold for computing a mask for filtering the pseudo labels.
            If None is given no mask will be computed.
        threshold_from_both_sides: Whether to include both values bigger than the threshold
            and smaller than 1 - the threshold, or only values bigger than the threshold, in the mask.
            The former should be used for binary labels, the latter for multiclass labels.
        confidence_mask_channel: A specific channel to use for computing the confidence mask.
            By default the confidence mask is computed across all channels independently.
            This is useful, if only one of the channels encodes a probability.
        raw_channel: Index of the input channel that holds the raw data passed to the teacher.
            Only used if `background_mask_channel` is not None.
        background_mask_channel: Index of the input channel that holds the background mask.
            If None, the full input is passed to the teacher and no background is subtracted.
    """
    def __init__(
        self,
        activation: Optional[torch.nn.Module] = None,
        confidence_threshold: Optional[float] = None,
        threshold_from_both_sides: bool = True,
        confidence_mask_channel: Optional[int] = None,
        raw_channel: Optional[int] = 0,
        background_mask_channel: Optional[int] = 1,
    ):
        super().__init__(activation, confidence_threshold, threshold_from_both_sides)
        self.raw_channel = raw_channel
        self.background_mask_channel = background_mask_channel
        self.confidence_mask_channel = confidence_mask_channel

    def _subtract_background(self, pseudo_labels: torch.Tensor, background_mask: torch.Tensor) -> torch.Tensor:
        """Return a copy of the pseudo labels that is zero wherever the background mask is non-zero."""
        bool_mask = background_mask.bool()
        return pseudo_labels.masked_fill(bool_mask, 0)

    def __call__(
        self, teacher: torch.nn.Module, input_: torch.Tensor
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Compute pseudo-labels.

        Args:
            teacher: The teacher model.
            input_: The input for this batch. If `background_mask_channel` is set, it must have
                5 dimensions (B, C, D, H, W) and contain the raw data and the background mask as channels.

        Returns:
            The pseudo-labels and the confidence mask. The mask is None if no confidence threshold is set.

        Raises:
            ValueError: If a background mask channel is set and the input is not 5-dimensional,
                the mask channel index is out of bounds, or `raw_channel` is None.
        """
        background_mask = None
        if self.background_mask_channel is not None:
            if input_.ndim != 5:
                raise ValueError(f"Expect data with 5 dimensions (B, C, D, H, W), got shape {input_.shape}.")

            # Valid channel indices are 0, ..., C - 1, so an index equal to C is already out of bounds.
            if self.background_mask_channel >= input_.shape[1]:
                raise ValueError(
                    f"Channel index {self.background_mask_channel} is out of bounds for shape {input_.shape}."
                )

            # Indexing with None would insert a new axis instead of selecting a channel.
            if self.raw_channel is None:
                raise ValueError("The raw_channel must be given if a background_mask_channel is used.")

            # Select single channels and restore the channel axis, so that both tensors are (B, 1, D, H, W).
            background_mask = input_[:, self.background_mask_channel].unsqueeze(1)
            input_ = input_[:, self.raw_channel].unsqueeze(1)

        pseudo_labels = teacher(input_)

        if self.activation is not None:
            pseudo_labels = self.activation(pseudo_labels)

        if self.confidence_threshold is None:
            label_mask = None
        else:
            if self.confidence_mask_channel is None:
                mask_input = pseudo_labels
            else:
                # Slice along the channel axis (axis 1) and keep that axis for the expansion below.
                channel = self.confidence_mask_channel
                mask_input = pseudo_labels[:, channel:(channel + 1)]

            if self.threshold_from_both_sides:
                label_mask = self._compute_label_mask_both_sides(mask_input)
            else:
                label_mask = self._compute_label_mask_one_side(mask_input)

            if self.confidence_mask_channel is not None:
                # Broadcast the single-channel mask to all channels of the pseudo labels.
                size = (pseudo_labels.shape[0], pseudo_labels.shape[1], *([-1] * (pseudo_labels.ndim - 2)))
                label_mask = label_mask.expand(*size)

        if background_mask is not None:
            pseudo_labels = self._subtract_background(pseudo_labels, background_mask)

        return pseudo_labels, label_mask
85+
2086

2187
def mean_teacher_adaptation(
2288
name: str,
@@ -36,13 +102,14 @@ def mean_teacher_adaptation(
36102
n_iterations: int = int(1e4),
37103
n_samples_train: Optional[int] = None,
38104
n_samples_val: Optional[int] = None,
39-
train_mask_paths: Optional[Tuple[str]] = None,
40-
val_mask_paths: Optional[Tuple[str]] = None,
105+
train_sample_mask_paths: Optional[Tuple[str]] = None,
106+
val_sample_mask_paths: Optional[Tuple[str]] = None,
107+
train_background_mask_paths: Optional[Tuple[str]] = None,
41108
patch_sampler: Optional[callable] = None,
42109
pseudo_label_sampler: Optional[callable] = None,
43110
device: int = 0,
44111
) -> None:
45-
"""Run domain adaptation to transfer a network trained on a source domain for a supervised
112+
"""Run domain adaptation to transfer a network trained on a source domain for a supervised
46113
segmentation task to perform this task on a different target domain.
47114
48115
We support different domain adaptation settings:
@@ -85,10 +152,11 @@ def mean_teacher_adaptation(
85152
based on the patch_shape and size of the volumes used for training.
86153
n_samples_val: The number of val samples per epoch. By default this will be estimated
87154
based on the patch_shape and size of the volumes used for validation.
88-
train_mask_paths: Boundary masks used by the patch sampler to accept or reject patches for training.
89-
val_mask_paths: Sample masks used by the patch sampler to accept or reject patches for validation.
90-
patch_sampler: Accept or reject patches based on a condition.
91-
pseudo_label_sampler: Mask out regions of the pseudo labels where the teacher is not confident before updating the gradients.
155+
train_sample_mask_paths: Boundary masks used by the patch sampler to accept or reject patches for training.
156+
val_sample_mask_paths: Sample masks used by the patch sampler to accept or reject patches for validation.
157+
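train_background_mask_paths: Background masks for the unsupervised training data. If given, they are passed to the training loader and the pseudo labels are set to zero inside the masked regions.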
158+
patch_sampler: A sampler for rejecting patches based on a defined condition.
159+
pseudo_label_sampler: A sampler for rejecting pseudo-labels based on a defined condition.
92160
device: GPU ID for training.
93161
"""
94162
assert (supervised_train_paths is None) == (supervised_val_paths is None)
@@ -109,24 +177,29 @@ def mean_teacher_adaptation(
109177
if os.path.isdir(source_checkpoint):
110178
model = torch_em.util.load_model(source_checkpoint)
111179
else:
112-
model = torch.load(source_checkpoint, weights_only=False)
180+
model = torch.load(source_checkpoint)
113181
reinit_teacher = False
114182

115183
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
116184
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.5, patience=5)
117185

118186
# self training functionality
119-
pseudo_labeler = self_training.DefaultPseudoLabeler(confidence_threshold=confidence_threshold)
187+
if train_background_mask_paths is not None:
188+
pseudo_labeler = NewPseudoLabeler(confidence_threshold=confidence_threshold, background_mask_channel=1)
189+
else:
190+
pseudo_labeler = self_training.DefaultPseudoLabeler(confidence_threshold=confidence_threshold)
191+
120192
loss = self_training.DefaultSelfTrainingLoss()
121193
loss_and_metric = self_training.DefaultSelfTrainingLossAndMetric()
122-
194+
123195
unsupervised_train_loader = get_unsupervised_loader(
124196
data_paths=unsupervised_train_paths,
125197
raw_key=raw_key,
126198
patch_shape=patch_shape,
127199
batch_size=batch_size,
128200
n_samples=n_samples_train,
129-
sample_mask_paths=train_mask_paths,
201+
sample_mask_paths=train_sample_mask_paths,
202+
background_mask_paths=train_background_mask_paths,
130203
sampler=patch_sampler
131204
)
132205
unsupervised_val_loader = get_unsupervised_loader(
@@ -135,7 +208,8 @@ def mean_teacher_adaptation(
135208
patch_shape=patch_shape,
136209
batch_size=batch_size,
137210
n_samples=n_samples_val,
138-
sample_mask_paths=val_mask_paths,
211+
sample_mask_paths=val_sample_mask_paths,
212+
background_mask_paths=None,
139213
sampler=patch_sampler
140214
)
141215

@@ -178,138 +252,3 @@ def mean_teacher_adaptation(
178252
sampler=pseudo_label_sampler,
179253
)
180254
trainer.fit(n_iterations)
181-
182-
183-
# TODO: add patch shapes for the other source models.
# Default patch shapes per source model name.
# The entry for a model is used when no patch shape is passed explicitly.
PATCH_SHAPES = {
    "vesicles_3d": [48, 256, 256],
}
"""@private
"""
189-
190-
191-
def _get_paths(input_folder, pattern, resize_training_data, model_name, tmp_dir, val_fraction):
    """Collect the training files and split them into training and validation paths.

    Args:
        input_folder: The folder that is searched recursively for training files.
        pattern: The glob pattern for selecting files in the input folder.
        resize_training_data: Whether to rescale the data based on the voxel size read from the file.
        model_name: The name of the source model, used to compute the scale factor.
        tmp_dir: The folder where resaved data is written.
        val_fraction: The fraction of the data to use for validation.

    Returns:
        The list of training paths and the list of validation paths.

    Raises:
        ValueError: If no file matches the pattern.
    """
    files = sorted(glob(os.path.join(input_folder, "**", pattern), recursive=True))
    if not files:
        raise ValueError(f"Could not load any files from {input_folder} with pattern {pattern}")

    # Heuristic: if we have less than 4 files then we crop a part of the volumes for validation.
    # And resave the volumes.
    resave_val_crops = len(files) < 4

    # We only resave the data if we resave val crops or resize the training data.
    resave_data = resave_val_crops or resize_training_data
    if not resave_data:
        train_paths, val_paths = train_test_split(files, test_size=val_fraction)
        return train_paths, val_paths

    train_paths, val_paths = [], []
    for file_path in files:
        # Derive output names from the file stem only, so that the folder part of the path is never rewritten.
        stem = Path(file_path).stem

        # Load the full volume and close the input file right away.
        with open_file(file_path, mode="r") as f:
            data = f["data"][:]

        if resize_training_data:
            with mrcfile.open(file_path) as f:
                voxel_size = f.voxel_size
            # NOTE(review): the division by 10 presumably converts Angstrom to nanometer - confirm.
            voxel_size = {ax: vox_size / 10.0 for ax, vox_size in zip("xyz", voxel_size.item())}
            scale = compute_scale_from_voxel_size(voxel_size, model_name)
            scaler = _Scaler(scale, verbose=False)
            data = scaler.scale_input(data)

        if resave_val_crops:
            # Use the last slices along the first axis for validation and the rest for training.
            n_slices = data.shape[0]
            val_slice = int((1.0 - val_fraction) * n_slices)
            train_data, val_data = data[:val_slice], data[val_slice:]

            train_path = os.path.join(tmp_dir, f"{stem}_train.h5")
            with open_file(train_path, mode="w") as f:
                f.create_dataset("data", data=train_data, compression="lzf")
            train_paths.append(train_path)

            val_path = os.path.join(tmp_dir, f"{stem}_val.h5")
            with open_file(val_path, mode="w") as f:
                f.create_dataset("data", data=val_data, compression="lzf")
            val_paths.append(val_path)

        else:
            output_path = os.path.join(tmp_dir, f"{stem}.h5")
            with open_file(output_path, mode="w") as f:
                f.create_dataset("data", data=data, compression="lzf")
            train_paths.append(output_path)

    # Without validation crops the resaved files themselves are split into training and validation.
    if not resave_val_crops:
        train_paths, val_paths = train_test_split(train_paths, test_size=val_fraction)

    return train_paths, val_paths
244-
245-
246-
def _parse_patch_shape(patch_shape, model_name):
    """Return the given patch shape, or the entry of PATCH_SHAPES for the model if it is None."""
    return PATCH_SHAPES[model_name] if patch_shape is None else patch_shape
250-
251-
def main():
    """@private
    """
    import argparse

    parser = argparse.ArgumentParser(
        description="Adapt a model to data from a different domain using unsupervised domain adaptation.\n\n"
        "You can use this function to adapt the SynapseNet model for vesicle segmentation like this:\n"
        "synapse_net.run_domain_adaptation -n adapted_model -i /path/to/data --file_pattern *.mrc --source_model vesicles_3d\n"  # noqa
        "The trained model will be saved in the folder 'checkpoints/adapted_model' (or whichever name you pass to the '-n' argument)."  # noqa
        "You can then use this model for segmentation with the SynapseNet GUI or CLI. "
        "Check out the information below for details on the arguments of this function.",
        formatter_class=argparse.RawTextHelpFormatter
    )
    parser.add_argument("--name", "-n", required=True, help="The name of the model to be trained. ")
    parser.add_argument("--input_folder", "-i", required=True, help="The folder with the training data.")
    parser.add_argument("--file_pattern", default="*",
                        help="The pattern for selecting files for training. For example '*.mrc' to select mrc files.")
    parser.add_argument("--key", help="The internal file path for the training data. Will be derived from the file extension by default.")  # noqa
    parser.add_argument(
        "--source_model",
        default="vesicles_3d",
        help="The source model used for weight initialization of teacher and student model. "
        "By default the model 'vesicles_3d' for vesicle segmentation in volumetric data is used."
    )
    parser.add_argument(
        "--resize_training_data", action="store_true",
        help="Whether to resize the training data to fit the voxel size of the source model's trainign data."
    )
    parser.add_argument("--n_iterations", type=int, default=int(1e4), help="The number of iterations for training.")
    parser.add_argument(
        "--patch_shape", nargs=3, type=int,
        help="The patch shape for training. By default the patch shape the source model was trained with is used."
    )

    # More optional argument:
    parser.add_argument("--batch_size", type=int, default=1, help="The batch size for training.")
    parser.add_argument("--n_samples_train", type=int, help="The number of samples per epoch for training. If not given will be derived from the data size.")  # noqa
    parser.add_argument("--n_samples_val", type=int, help="The number of samples per epoch for validation. If not given will be derived from the data size.")  # noqa
    parser.add_argument("--val_fraction", type=float, default=0.15, help="The fraction of the data to use for validation. This has no effect if 'val_folder' and 'val_label_folder' were passed.")  # noqa
    parser.add_argument("--check", action="store_true", help="Visualize samples from the data loaders to ensure correct data instead of running training.")  # noqa

    args = parser.parse_args()

    source_checkpoint = get_model_path(args.source_model)
    patch_shape = _parse_patch_shape(args.patch_shape, args.source_model)

    # The training runs inside the context, because the data may be resaved to the temporary folder.
    with tempfile.TemporaryDirectory() as tmp_dir:
        # argparse stores the values under the long option names: 'input_folder' and 'file_pattern'.
        unsupervised_train_paths, unsupervised_val_paths = _get_paths(
            args.input_folder, args.file_pattern, args.resize_training_data,
            args.source_model, tmp_dir, args.val_fraction,
        )
        unsupervised_train_paths, raw_key = _derive_key_from_files(unsupervised_train_paths, args.key)

        mean_teacher_adaptation(
            name=args.name,
            unsupervised_train_paths=unsupervised_train_paths,
            unsupervised_val_paths=unsupervised_val_paths,
            patch_shape=patch_shape,
            source_checkpoint=source_checkpoint,
            raw_key=raw_key,
            n_iterations=args.n_iterations,
            batch_size=args.batch_size,
            n_samples_train=args.n_samples_train,
            n_samples_val=args.n_samples_val,
            check=args.check,
        )

0 commit comments

Comments
 (0)