Commit 534ab38

Merge pull request #31 from computational-cell-analytics/domain-adaptation
Add domain adaptation script

Domain adaptation has been implemented for unsupervised and semi-supervised learning. Semi-supervised learning has been tested for IHC and was slightly worse than the supervised approach; the reason for this needs to be investigated.
2 parents c549e65 + 539d2c1 commit 534ab38

File tree

8 files changed: +592 -16 lines changed

Lines changed: 2 additions & 0 deletions
from .util import get_3d_model, get_supervised_loader
from .mean_teacher_training import mean_teacher_training
Lines changed: 219 additions & 0 deletions
import os
from typing import Optional, Tuple

import torch
import torch_em
import torch_em.self_training as self_training
from torchvision import transforms

from .util import get_supervised_loader, get_3d_model


def weak_augmentations(p: float = 0.75) -> callable:
    """The weak augmentations used in the unsupervised data loader.

    Args:
        p: The probability for applying one of the augmentations.

    Returns:
        The transformation function applying the augmentations.
    """
    norm = torch_em.transform.raw.standardize
    aug = transforms.Compose([
        norm,
        transforms.RandomApply([torch_em.transform.raw.GaussianBlur()], p=p),
        transforms.RandomApply([torch_em.transform.raw.AdditiveGaussianNoise(
            scale=(0, 0.15), clip_kwargs=False)], p=p
        ),
    ])
    return torch_em.transform.raw.get_raw_transform(normalizer=norm, augmentation1=aug)


def get_unsupervised_loader(
    data_paths: Tuple[str],
    raw_key: Optional[str],
    patch_shape: Tuple[int, int, int],
    batch_size: int,
    n_samples: Optional[int],
) -> torch.utils.data.DataLoader:
    """Get a dataloader for unsupervised segmentation training.

    Args:
        data_paths: The filepaths to the hdf5 files containing the training data.
        raw_key: The key that holds the raw data inside of the hdf5.
        patch_shape: The patch shape used for a training example.
            In order to run 2d training pass a patch shape with a singleton in the z-axis,
            e.g. 'patch_shape = [1, 512, 512]'.
        batch_size: The batch size for training.
        n_samples: The number of samples per epoch. By default this will be estimated
            based on the patch_shape and size of the volumes used for training.

    Returns:
        The PyTorch dataloader.
    """
    raw_transform = torch_em.transform.get_raw_transform()
    transform = torch_em.transform.get_augmentations(ndim=3)

    if n_samples is None:
        n_samples_per_ds = None
    else:
        n_samples_per_ds = int(n_samples / len(data_paths))

    augmentations = (weak_augmentations(), weak_augmentations())
    datasets = [
        torch_em.data.RawDataset(path, raw_key, patch_shape, raw_transform, transform,
                                 augmentations=augmentations, ndim=3, n_samples=n_samples_per_ds)
        for path in data_paths
    ]
    ds = torch.utils.data.ConcatDataset(datasets)

    # num_workers = 4 * batch_size
    num_workers = batch_size
    loader = torch_em.segmentation.get_data_loader(ds, batch_size=batch_size, num_workers=num_workers, shuffle=True)
    return loader


def mean_teacher_training(
    name: str,
    unsupervised_train_paths: Tuple[str],
    unsupervised_val_paths: Tuple[str],
    patch_shape: Tuple[int, int, int],
    save_root: Optional[str] = None,
    source_checkpoint: Optional[str] = None,
    supervised_train_image_paths: Optional[Tuple[str]] = None,
    supervised_val_image_paths: Optional[Tuple[str]] = None,
    supervised_train_label_paths: Optional[Tuple[str]] = None,
    supervised_val_label_paths: Optional[Tuple[str]] = None,
    confidence_threshold: float = 0.9,
    raw_key: Optional[str] = None,
    raw_key_supervised: Optional[str] = None,
    label_key: Optional[str] = None,
    batch_size: int = 1,
    lr: float = 1e-4,
    n_iterations: int = int(1e4),
    n_samples_train: Optional[int] = None,
    n_samples_val: Optional[int] = None,
    sampler: Optional[callable] = None,
) -> None:
    """This function implements network training with a mean teacher approach.

    It can be used for semi-supervised learning, unsupervised domain adaptation and supervised domain adaptation.
    These different training modes can be used as follows:
    - semi-supervised learning: pass 'unsupervised_train/val_paths' and 'supervised_train/val_paths'.
    - unsupervised domain adaptation: pass 'unsupervised_train/val_paths' and 'source_checkpoint'.
    - supervised domain adaptation: pass 'unsupervised_train/val_paths', 'supervised_train/val_paths', 'source_checkpoint'.

    Args:
        name: The name for the checkpoint to be trained.
        unsupervised_train_paths: Filepaths to the hdf5 files or similar file formats
            for the training data in the target domain.
            This training data is used for unsupervised learning, so it does not require labels.
        unsupervised_val_paths: Filepaths to the hdf5 files or similar file formats
            for the validation data in the target domain.
            This validation data is used for unsupervised learning, so it does not require labels.
        patch_shape: The patch shape used for a training example.
            In order to run 2d training pass a patch shape with a singleton in the z-axis,
            e.g. 'patch_shape = [1, 512, 512]'.
        save_root: Folder where the checkpoint will be saved.
        source_checkpoint: Checkpoint of the initial model trained on the source domain.
            This is used to initialize the teacher model.
            If the checkpoint is not given, then both student and teacher model are initialized
            from scratch. In this case `supervised_train_image_paths` and `supervised_val_image_paths`
            (together with the corresponding label paths) have to be given in order to provide
            training data from the source domain.
        supervised_train_image_paths: Filepaths to the files for the supervised image data; training split.
            This training data is optional. If given, it also requires labels.
        supervised_val_image_paths: Filepaths to the files for the supervised image data; validation split.
            This validation data is optional. If given, it also requires labels.
        supervised_train_label_paths: Filepaths to the files for the supervised label masks; training split.
            This training data is optional.
        supervised_val_label_paths: Filepaths to the files for the supervised label masks; validation split.
            This validation data is optional.
        confidence_threshold: The threshold for filtering data in the unsupervised loss.
            The label filtering is done based on the uncertainty of network predictions, and only
            the data with higher certainty than this threshold is used for training.
        raw_key: The key that holds the raw data inside of the hdf5 or similar files
            for the unsupervised training data. Set to None for tifs.
        raw_key_supervised: The key that holds the raw data inside of the hdf5 or similar files
            for the supervised training data. Set to None for tifs.
        label_key: The key that holds the labels inside of the hdf5 files for supervised learning.
            This is only required if `supervised_train_label_paths` and `supervised_val_label_paths` are given.
            Set to None for tifs.
        batch_size: The batch size for training.
        lr: The initial learning rate.
        n_iterations: The number of iterations to train for.
        n_samples_train: The number of train samples per epoch. By default this will be estimated
            based on the patch_shape and size of the volumes used for training.
        n_samples_val: The number of val samples per epoch. By default this will be estimated
            based on the patch_shape and size of the volumes used for validation.
        sampler: An optional sampler for selecting the batches used during training.
    """  # noqa
    assert (supervised_train_image_paths is None) == (supervised_val_image_paths is None)

    if source_checkpoint is None:
        # Training from scratch only makes sense if we have supervised training data,
        # that's why we have the assertion here.
        assert supervised_train_image_paths is not None
        model = get_3d_model(out_channels=3)
        reinit_teacher = True
    else:
        print("Mean teacher training initialized from source model:", source_checkpoint)
        if os.path.isdir(source_checkpoint):
            model = torch_em.util.load_model(source_checkpoint)
        else:
            model = torch.load(source_checkpoint, weights_only=False)
        reinit_teacher = False

    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.5, patience=5)

    # Self training functionality.
    pseudo_labeler = self_training.DefaultPseudoLabeler(confidence_threshold=confidence_threshold, mask_channel=0)
    loss = self_training.DefaultSelfTrainingLoss()
    loss_and_metric = self_training.DefaultSelfTrainingLossAndMetric()

    unsupervised_train_loader = get_unsupervised_loader(
        unsupervised_train_paths, raw_key, patch_shape, batch_size, n_samples=n_samples_train
    )
    unsupervised_val_loader = get_unsupervised_loader(
        unsupervised_val_paths, raw_key, patch_shape, batch_size, n_samples=n_samples_val
    )

    if supervised_train_image_paths is not None:
        supervised_train_loader = get_supervised_loader(
            supervised_train_image_paths, supervised_train_label_paths,
            patch_shape=patch_shape, batch_size=batch_size, n_samples=n_samples_train,
            image_key=raw_key_supervised, label_key=label_key,
        )
        supervised_val_loader = get_supervised_loader(
            supervised_val_image_paths, supervised_val_label_paths,
            patch_shape=patch_shape, batch_size=batch_size, n_samples=n_samples_val,
            image_key=raw_key_supervised, label_key=label_key,
        )
    else:
        supervised_train_loader = None
        supervised_val_loader = None

    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    trainer = self_training.MeanTeacherTrainer(
        name=name,
        model=model,
        optimizer=optimizer,
        lr_scheduler=scheduler,
        pseudo_labeler=pseudo_labeler,
        unsupervised_loss=loss,
        unsupervised_loss_and_metric=loss_and_metric,
        supervised_train_loader=supervised_train_loader,
        unsupervised_train_loader=unsupervised_train_loader,
        supervised_val_loader=supervised_val_loader,
        unsupervised_val_loader=unsupervised_val_loader,
        supervised_loss=loss,
        supervised_loss_and_metric=loss_and_metric,
        logger=self_training.SelfTrainingTensorboardLogger,
        mixed_precision=True,
        log_image_interval=100,
        compile_model=False,
        device=device,
        reinit_teacher=reinit_teacher,
        save_root=save_root,
        sampler=sampler,
    )
    trainer.fit(n_iterations)
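
For reference, a minimal usage sketch of the unsupervised domain adaptation mode described in the docstring above. The data paths and checkpoint below are hypothetical placeholders, not files from this repository:

from flamingo_tools.training import mean_teacher_training

# Unlabeled target-domain volumes (hypothetical paths) and a source-domain checkpoint.
target_train = ["/data/target/train_vol1.h5", "/data/target/train_vol2.h5"]
target_val = ["/data/target/val_vol1.h5"]

mean_teacher_training(
    name="unsupervised-da-example",
    unsupervised_train_paths=target_train,
    unsupervised_val_paths=target_val,
    patch_shape=(64, 128, 128),
    source_checkpoint="/checkpoints/source_model",  # initializes both student and teacher
    raw_key="raw",  # internal hdf5 key; set to None for tif volumes
    batch_size=1,
    n_iterations=int(2.5e4),
)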

flamingo_tools/training/util.py

Lines changed: 57 additions & 0 deletions
from typing import Optional, Sequence, Tuple

import torch.nn as nn
import torch_em
from torch_em.model import UNet3d
from torch.utils.data import DataLoader


def get_3d_model(out_channels: int = 3, final_activation: Optional[str] = "Sigmoid") -> nn.Module:
    """Get a 3D U-Net for segmentation or detection tasks.

    Args:
        out_channels: The number of output channels of the network.
        final_activation: The activation applied to the last layer.
            Set to 'None' for no activation; by default this applies a Sigmoid activation.

    Returns:
        The 3D U-Net.
    """
    return UNet3d(in_channels=1, out_channels=out_channels, initial_features=32, final_activation=final_activation)


def get_supervised_loader(
    image_paths: Sequence[str],
    label_paths: Sequence[str],
    patch_shape: Tuple[int, int, int],
    batch_size: int,
    image_key: Optional[str] = None,
    label_key: Optional[str] = None,
    n_samples: Optional[int] = None,
) -> DataLoader:
    """Get a data loader for a supervised segmentation task.

    Args:
        image_paths: The filepaths to the image data. These can be stored either in tif or in hdf5/zarr/n5.
        label_paths: The filepaths to the label masks. These can be stored either in tif or in hdf5/zarr/n5.
        patch_shape: The 3D patch shape for training.
        batch_size: The batch size for training.
        image_key: Internal path for the image data. This is only required for hdf5/zarr/n5 data.
        label_key: Internal path for the label masks. This is only required for hdf5/zarr/n5 data.
        n_samples: The number of samples to use for training.

    Returns:
        The data loader.
    """
    assert len(image_paths) == len(label_paths)
    assert len(image_paths) > 0
    label_transform = torch_em.transform.label.PerObjectDistanceTransform(
        distances=True, boundary_distances=True, foreground=True,
    )
    sampler = torch_em.data.sampler.MinInstanceSampler(p_reject=0.8)
    loader = torch_em.default_segmentation_loader(
        raw_paths=image_paths, raw_key=image_key, label_paths=label_paths, label_key=label_key,
        batch_size=batch_size, patch_shape=patch_shape, label_transform=label_transform,
        n_samples=n_samples, num_workers=4, shuffle=True, sampler=sampler,
    )
    return loader
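
As a quick illustration of get_supervised_loader on its own, a sketch for tif image/label pairs; the file paths are hypothetical:

from flamingo_tools.training import get_supervised_loader

# Hypothetical tif volumes with matching "_annotations" label masks.
images = ["/data/ihc/vol1.tif", "/data/ihc/vol2.tif"]
labels = ["/data/ihc/vol1_annotations.tif", "/data/ihc/vol2_annotations.tif"]

loader = get_supervised_loader(images, labels, patch_shape=(64, 128, 128), batch_size=1)
x, y = next(iter(loader))
# x: batch of raw patches; y: distance-based targets from PerObjectDistanceTransform.
print(x.shape, y.shape)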
Lines changed: 86 additions & 0 deletions
import os
from glob import glob

import torch
from torch_em.util import load_model
from flamingo_tools.training import mean_teacher_training


def get_paths():
    root = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/training_data/IHC/2025-05-IHC_semi-supervised"
    annotated_folders = ["annotated_train", "empty"]
    train_image = []
    train_label = []
    for folder in annotated_folders:
        with os.scandir(os.path.join(root, folder)) as direc:
            for entry in direc:
                # Image files are paired with label masks that carry the "_annotations" suffix.
                if "annotations" not in entry.name and entry.is_file():
                    basename = os.path.basename(entry.name)
                    name_no_extension = ".".join(basename.split(".")[:-1])
                    label_name = name_no_extension + "_annotations.tif"
                    train_image.extend(glob(os.path.join(root, folder, entry.name)))
                    train_label.extend(glob(os.path.join(root, folder, label_name)))

    annotated_folders = ["annotated_val"]
    val_image = []
    val_label = []
    for folder in annotated_folders:
        with os.scandir(os.path.join(root, folder)) as direc:
            for entry in direc:
                if "annotations" not in entry.name and entry.is_file():
                    basename = os.path.basename(entry.name)
                    name_no_extension = ".".join(basename.split(".")[:-1])
                    label_name = name_no_extension + "_annotations.tif"
                    val_image.extend(glob(os.path.join(root, folder, entry.name)))
                    val_label.extend(glob(os.path.join(root, folder, label_name)))

    domain_folders = ["domain_Aleyna", "domain_Lennart"]
    paths_domain = []
    for folder in domain_folders:
        paths_domain.extend(glob(os.path.join(root, folder, "*.tif")))

    # The unlabeled domain volumes are split so that the last two are used for unsupervised validation.
    return train_image, train_label, val_image, val_label, paths_domain[:-2], paths_domain[-2:]


def run_training(name):
    patch_shape = (64, 128, 128)
    batch_size = 8

    super_train_img, super_train_label, super_val_img, super_val_label, unsuper_train, unsuper_val = get_paths()

    mean_teacher_training(
        name=name,
        unsupervised_train_paths=unsuper_train,
        unsupervised_val_paths=unsuper_val,
        patch_shape=patch_shape,
        supervised_train_image_paths=super_train_img,
        supervised_val_image_paths=super_val_img,
        supervised_train_label_paths=super_train_label,
        supervised_val_label_paths=super_val_label,
        batch_size=batch_size,
        n_iterations=int(1e5),
        n_samples_train=1000,
        n_samples_val=80,
    )


def export_model(name, export_path):
    # Export the teacher network, which is typically the stronger model after mean teacher training.
    model = load_model(os.path.join("checkpoints", name), state_key="teacher")
    torch.save(model, export_path)


def main():
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--export_path")
    args = parser.parse_args()
    name = "IHC_semi-supervised_2025-05-22"
    if args.export_path is None:
        run_training(name)
    else:
        export_model(name, args.export_path)


if __name__ == "__main__":
    main()
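
After export, the saved file contains the full teacher network, so it can be reloaded directly; a brief sketch with a hypothetical export path (loading mirrors how mean_teacher_training loads non-directory checkpoints):

import torch

model = torch.load("/models/ihc_semi_supervised.pt", weights_only=False)
model.eval()
# The same file can also be passed as `source_checkpoint` to mean_teacher_training
# for a further domain adaptation run.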
