Skip to content

Commit b497ed4

Browse files
Add domain adaptation functionality
1 parent 1481244 commit b497ed4

File tree

2 files changed

+215
-0
lines changed

2 files changed

+215
-0
lines changed

flamingo_tools/training/__init__.py

Whitespace-only changes.
Lines changed: 215 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,215 @@
1+
import os
2+
from typing import Optional, Tuple
3+
4+
import torch
5+
import torch_em
6+
import torch_em.self_training as self_training
7+
from torchvision import transforms
8+
9+
10+
def get_3d_model(out_channels):
11+
raise NotImplementedError
12+
13+
14+
def get_supervised_loader():
15+
raise NotImplementedError
16+
17+
18+
def weak_augmentations(p: float = 0.75) -> callable:
19+
"""The weak augmentations used in the unsupervised data loader.
20+
21+
Args:
22+
p: The probability for applying one of the augmentations.
23+
24+
Returns:
25+
The transformation function applying the augmentation.
26+
"""
27+
norm = torch_em.transform.raw.standardize
28+
aug = transforms.Compose([
29+
norm,
30+
transforms.RandomApply([torch_em.transform.raw.GaussianBlur()], p=p),
31+
transforms.RandomApply([torch_em.transform.raw.AdditiveGaussianNoise(
32+
scale=(0, 0.15), clip_kwargs=False)], p=p
33+
),
34+
])
35+
return torch_em.transform.raw.get_raw_transform(normalizer=norm, augmentation1=aug)
36+
37+
38+
def get_unsupervised_loader(
39+
data_paths: Tuple[str],
40+
raw_key: Optional[str],
41+
patch_shape: Tuple[int, int, int],
42+
batch_size: int,
43+
n_samples: Optional[int],
44+
) -> torch.utils.data.DataLoader:
45+
"""Get a dataloader for unsupervised segmentation training.
46+
47+
Args:
48+
data_paths: The filepaths to the hdf5 files containing the training data.
49+
raw_key: The key that holds the raw data inside of the hdf5.
50+
patch_shape: The patch shape used for a training example.
51+
In order to run 2d training pass a patch shape with a singleton in the z-axis,
52+
e.g. 'patch_shape = [1, 512, 512]'.
53+
batch_size: The batch size for training.
54+
n_samples: The number of samples per epoch. By default this will be estimated
55+
based on the patch_shape and size of the volumes used for training.
56+
57+
Returns:
58+
The PyTorch dataloader.
59+
"""
60+
raw_transform = torch_em.transform.get_raw_transform()
61+
transform = torch_em.transform.get_augmentations(ndim=3)
62+
63+
if n_samples is None:
64+
n_samples_per_ds = None
65+
else:
66+
n_samples_per_ds = int(n_samples / len(data_paths))
67+
68+
augmentations = (weak_augmentations(), weak_augmentations())
69+
datasets = [
70+
torch_em.data.RawDataset(path, raw_key, patch_shape, raw_transform, transform,
71+
augmentations=augmentations, ndim=3, n_samples=n_samples_per_ds)
72+
for path in data_paths
73+
]
74+
ds = torch.utils.data.ConcatDataset(datasets)
75+
76+
# num_workers = 4 * batch_size
77+
num_workers = batch_size
78+
loader = torch_em.segmentation.get_data_loader(ds, batch_size=batch_size, num_workers=num_workers, shuffle=True)
79+
return loader
80+
81+
82+
def mean_teacher_adaptation(
83+
name: str,
84+
unsupervised_train_paths: Tuple[str],
85+
unsupervised_val_paths: Tuple[str],
86+
patch_shape: Tuple[int, int, int],
87+
save_root: Optional[str] = None,
88+
source_checkpoint: Optional[str] = None,
89+
supervised_train_paths: Optional[Tuple[str]] = None,
90+
supervised_val_paths: Optional[Tuple[str]] = None,
91+
confidence_threshold: float = 0.9,
92+
raw_key: Optional[str] = None,
93+
raw_key_supervised: Optional[str] = None,
94+
label_key: Optional[str] = None,
95+
batch_size: int = 1,
96+
lr: float = 1e-4,
97+
n_iterations: int = int(1e4),
98+
n_samples_train: Optional[int] = None,
99+
n_samples_val: Optional[int] = None,
100+
sampler: Optional[callable] = None,
101+
) -> None:
102+
"""Run domain adapation to transfer a network trained on a source domain for a supervised
103+
segmentation task to perform this task on a different target domain.
104+
105+
We support different domain adaptation settings:
106+
- unsupervised domain adaptation: the default mode when 'supervised_train_paths' and
107+
'supervised_val_paths' are not given.
108+
- semi-supervised domain adaptation: domain adaptation on unlabeled and labeled data,
109+
when 'supervised_train_paths' and 'supervised_val_paths' are given.
110+
111+
Args:
112+
name: The name for the checkpoint to be trained.
113+
unsupervsied_train_paths: Filepaths to the hdf5 files or similar file formats
114+
for the training data in the target domain.
115+
This training data is used for unsupervised learning, so it does not require labels.
116+
unsupervised_val_paths: Filepaths to the hdf5 files or similar file formats
117+
for the validation data in the target domain.
118+
This validation data is used for unsupervised learning, so it does not require labels.
119+
patch_shape: The patch shape used for a training example.
120+
In order to run 2d training pass a patch shape with a singleton in the z-axis,
121+
e.g. 'patch_shape = [1, 512, 512]'.
122+
save_root: Folder where the checkpoint will be saved.
123+
source_checkpoint: Checkpoint to the initial model trained on the source domain.
124+
This is used to initialize the teacher model.
125+
If the checkpoint is not given, then both student and teacher model are initialized
126+
from scratch. In this case `supervised_train_paths` and `supervised_val_paths` have to
127+
be given in order to provide training data from the source domain.
128+
supervised_train_paths: Filepaths to the hdf5 files for the training data in the source domain.
129+
This training data is optional. If given, it is used for unsupervised learnig and requires labels.
130+
supervised_val_paths: Filepaths to the df5 files for the validation data in the source domain.
131+
This validation data is optional. If given, it is used for unsupervised learnig and requires labels.
132+
confidence_threshold: The threshold for filtering data in the unsupervised loss.
133+
The label filtering is done based on the uncertainty of network predictions, and only
134+
the data with higher certainty than this threshold is used for training.
135+
raw_key: The key that holds the raw data inside of the hdf5 or similar files.
136+
label_key: The key that holds the labels inside of the hdf5 files for supervised learning.
137+
This is only required if `supervised_train_paths` and `supervised_val_paths` are given.
138+
batch_size: The batch size for training.
139+
lr: The initial learning rate.
140+
n_iterations: The number of iterations to train for.
141+
n_samples_train: The number of train samples per epoch. By default this will be estimated
142+
based on the patch_shape and size of the volumes used for training.
143+
n_samples_val: The number of val samples per epoch. By default this will be estimated
144+
based on the patch_shape and size of the volumes used for validation.
145+
"""
146+
assert (supervised_train_paths is None) == (supervised_val_paths is None)
147+
148+
if source_checkpoint is None:
149+
# training from scratch only makes sense if we have supervised training data
150+
# that's why we have the assertion here.
151+
assert supervised_train_paths is not None
152+
model = get_3d_model(out_channels=3)
153+
reinit_teacher = True
154+
else:
155+
print("Mean teacehr training initialized from source model:", source_checkpoint)
156+
if os.path.isdir(source_checkpoint):
157+
model = torch_em.util.load_model(source_checkpoint)
158+
else:
159+
model = torch.load(source_checkpoint, weights_only=False)
160+
reinit_teacher = False
161+
162+
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
163+
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.5, patience=5)
164+
165+
# self training functionality
166+
pseudo_labeler = self_training.DefaultPseudoLabeler(confidence_threshold=confidence_threshold, mask_channel=0)
167+
loss = self_training.DefaultSelfTrainingLoss()
168+
loss_and_metric = self_training.DefaultSelfTrainingLossAndMetric()
169+
170+
unsupervised_train_loader = get_unsupervised_loader(
171+
unsupervised_train_paths, raw_key, patch_shape, batch_size, n_samples=n_samples_train
172+
)
173+
unsupervised_val_loader = get_unsupervised_loader(
174+
unsupervised_val_paths, raw_key, patch_shape, batch_size, n_samples=n_samples_val
175+
)
176+
177+
if supervised_train_paths is not None:
178+
assert label_key is not None
179+
supervised_train_loader = get_supervised_loader(
180+
supervised_train_paths, raw_key_supervised, label_key,
181+
patch_shape, batch_size, n_samples=n_samples_train,
182+
)
183+
supervised_val_loader = get_supervised_loader(
184+
supervised_val_paths, raw_key_supervised, label_key,
185+
patch_shape, batch_size, n_samples=n_samples_val,
186+
)
187+
else:
188+
supervised_train_loader = None
189+
supervised_val_loader = None
190+
191+
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
192+
trainer = self_training.MeanTeacherTrainer(
193+
name=name,
194+
model=model,
195+
optimizer=optimizer,
196+
lr_scheduler=scheduler,
197+
pseudo_labeler=pseudo_labeler,
198+
unsupervised_loss=loss,
199+
unsupervised_loss_and_metric=loss_and_metric,
200+
supervised_train_loader=supervised_train_loader,
201+
unsupervised_train_loader=unsupervised_train_loader,
202+
supervised_val_loader=supervised_val_loader,
203+
unsupervised_val_loader=unsupervised_val_loader,
204+
supervised_loss=loss,
205+
supervised_loss_and_metric=loss_and_metric,
206+
logger=self_training.SelfTrainingTensorboardLogger,
207+
mixed_precision=True,
208+
log_image_interval=100,
209+
compile_model=False,
210+
device=device,
211+
reinit_teacher=reinit_teacher,
212+
save_root=save_root,
213+
sampler=sampler,
214+
)
215+
trainer.fit(n_iterations)

0 commit comments

Comments
 (0)