Skip to content

Commit 8d03fe2

Browse files
Update training scripts
1 parent b497ed4 commit 8d03fe2

File tree

5 files changed

+113
-51
lines changed

5 files changed

+113
-51
lines changed
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
from .util import get_3d_model, get_supervised_loader
2+
from .mean_teacher_training import mean_teacher_training

flamingo_tools/training/domain_adaptation.py renamed to flamingo_tools/training/mean_teacher_training.py

Lines changed: 37 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,7 @@
66
import torch_em.self_training as self_training
77
from torchvision import transforms
88

9-
10-
def get_3d_model(out_channels):
11-
raise NotImplementedError
12-
13-
14-
def get_supervised_loader():
15-
raise NotImplementedError
9+
from .util import get_supervised_loader, get_3d_model
1610

1711

1812
def weak_augmentations(p: float = 0.75) -> callable:
@@ -79,15 +73,17 @@ def get_unsupervised_loader(
7973
return loader
8074

8175

82-
def mean_teacher_adaptation(
76+
def mean_teacher_training(
8377
name: str,
8478
unsupervised_train_paths: Tuple[str],
8579
unsupervised_val_paths: Tuple[str],
8680
patch_shape: Tuple[int, int, int],
8781
save_root: Optional[str] = None,
8882
source_checkpoint: Optional[str] = None,
89-
supervised_train_paths: Optional[Tuple[str]] = None,
90-
supervised_val_paths: Optional[Tuple[str]] = None,
83+
supervised_train_image_paths: Optional[Tuple[str]] = None,
84+
supervised_val_image_paths: Optional[Tuple[str]] = None,
85+
supervised_train_label_paths: Optional[Tuple[str]] = None,
86+
supervised_val_label_paths: Optional[Tuple[str]] = None,
9187
confidence_threshold: float = 0.9,
9288
raw_key: Optional[str] = None,
9389
raw_key_supervised: Optional[str] = None,
@@ -99,14 +95,13 @@ def mean_teacher_adaptation(
9995
n_samples_val: Optional[int] = None,
10096
sampler: Optional[callable] = None,
10197
) -> None:
102-
"""Run domain adapation to transfer a network trained on a source domain for a supervised
103-
segmentation task to perform this task on a different target domain.
98+
"""This function implements network training with a mean teacher approach.
10499
105-
We support different domain adaptation settings:
106-
- unsupervised domain adaptation: the default mode when 'supervised_train_paths' and
107-
'supervised_val_paths' are not given.
108-
- semi-supervised domain adaptation: domain adaptation on unlabeled and labeled data,
109-
when 'supervised_train_paths' and 'supervised_val_paths' are given.
100+
It can be used for semi-supervised learning, unsupervised domain adaptation and supervised domain adaptation.
101+
The different training modes are selected as follows:
102+
- semi-supervised learning: pass 'unsupervised_train/val_paths' and 'supervised_train/val_paths'.
103+
- unsupervised domain adaptation: pass 'unsupervised_train/val_paths' and 'source_checkpoint'.
104+
- supervised domain adaptation: pass 'unsupervised_train/val_paths', 'supervised_train/val_paths', 'source_checkpoint'.
110105
111106
Args:
112107
name: The name for the checkpoint to be trained.
@@ -125,30 +120,38 @@ def mean_teacher_adaptation(
125120
If the checkpoint is not given, then both student and teacher model are initialized
126121
from scratch. In this case `supervised_train_image_paths` and `supervised_val_image_paths` have to
127122
be given in order to provide training data from the source domain.
128-
supervised_train_paths: Filepaths to the hdf5 files for the training data in the source domain.
129-
This training data is optional. If given, it is used for unsupervised learnig and requires labels.
130-
supervised_val_paths: Filepaths to the df5 files for the validation data in the source domain.
131-
This validation data is optional. If given, it is used for unsupervised learnig and requires labels.
123+
supervised_train_image_paths: Paths to the files for the supervised image data; training split.
124+
This training data is optional. If given, it also requires labels.
125+
supervised_val_image_paths: Paths to the files for the supervised image data; validation split.
126+
This validation data is optional. If given, it also requires labels.
127+
supervised_train_label_paths: Filepaths to the files for the supervised label masks; training split.
128+
This training data is optional.
129+
supervised_val_label_paths: Filepaths to the files for the supervised label masks; validation split.
130+
This validation data is optional.
132131
confidence_threshold: The threshold for filtering data in the unsupervised loss.
133132
The label filtering is done based on the uncertainty of network predictions, and only
134133
the data with higher certainty than this threshold is used for training.
135-
raw_key: The key that holds the raw data inside of the hdf5 or similar files.
134+
raw_key: The key that holds the raw data inside of the hdf5 or similar files;
135+
for the unsupervised training data. Set to None for tifs.
136+
raw_key_supervised: The key that holds the raw data inside of the hdf5 or similar files;
137+
for the supervised training data. Set to None for tifs.
136138
label_key: The key that holds the labels inside of the hdf5 files for supervised learning.
137-
This is only required if `supervised_train_paths` and `supervised_val_paths` are given.
139+
This is only required if `supervised_train_label_paths` and `supervised_val_label_paths` are given.
140+
Set to None for tifs.
138141
batch_size: The batch size for training.
139142
lr: The initial learning rate.
140143
n_iterations: The number of iterations to train for.
141144
n_samples_train: The number of train samples per epoch. By default this will be estimated
142145
based on the patch_shape and size of the volumes used for training.
143146
n_samples_val: The number of val samples per epoch. By default this will be estimated
144147
based on the patch_shape and size of the volumes used for validation.
145-
"""
146-
assert (supervised_train_paths is None) == (supervised_val_paths is None)
148+
""" # noqa
149+
assert (supervised_train_image_paths is None) == (supervised_val_image_paths is None)
147150

148151
if source_checkpoint is None:
149-
# training from scratch only makes sense if we have supervised training data
152+
# Training from scratch only makes sense if we have supervised training data
150153
# that's why we have the assertion here.
151-
assert supervised_train_paths is not None
154+
assert supervised_train_image_paths is not None
152155
model = get_3d_model(out_channels=3)
153156
reinit_teacher = True
154157
else:
@@ -174,15 +177,16 @@ def mean_teacher_adaptation(
174177
unsupervised_val_paths, raw_key, patch_shape, batch_size, n_samples=n_samples_val
175178
)
176179

177-
if supervised_train_paths is not None:
178-
assert label_key is not None
180+
if supervised_train_image_paths is not None:
179181
supervised_train_loader = get_supervised_loader(
180-
supervised_train_paths, raw_key_supervised, label_key,
181-
patch_shape, batch_size, n_samples=n_samples_train,
182+
supervised_train_image_paths, supervised_train_label_paths,
183+
patch_shape=patch_shape, batch_size=batch_size, n_samples=n_samples_train,
184+
image_key=raw_key_supervised, label_key=label_key,
182185
)
183186
supervised_val_loader = get_supervised_loader(
184-
supervised_val_paths, raw_key_supervised, label_key,
185-
patch_shape, batch_size, n_samples=n_samples_val,
187+
supervised_val_image_paths, supervised_val_label_paths,
188+
patch_shape=patch_shape, batch_size=batch_size, n_samples=n_samples_val,
189+
image_key=raw_key_supervised, label_key=label_key,
186190
)
187191
else:
188192
supervised_train_loader = None

scripts/training/sgn_domain_adaptation.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
import torch
55
from torch_em.util import load_model
6-
from flamingo_tools.training.domain_adaptation import mean_teacher_adaptation
6+
from flamingo_tools.training import mean_teacher_training
77

88

99
def get_paths():
@@ -21,7 +21,7 @@ def run_training(name):
2121
source_checkpoint = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/trained_models/SGN/cochlea_distance_unet_SGN_March2025Model" # noqa
2222

2323
train_paths, val_paths = get_paths()
24-
mean_teacher_adaptation(
24+
mean_teacher_training(
2525
name=name,
2626
unsupervised_train_paths=train_paths,
2727
unsupervised_val_paths=val_paths,

scripts/training/train_distance_unet.py

Lines changed: 5 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from glob import glob
55

66
import torch_em
7-
from torch_em.model import UNet3d
7+
from flamingo_tools.training import get_supervised_loader, get_3d_model
88

99
ROOT_CLUSTER = "/scratch-grete/usr/nimcpape/data/moser/lightsheet/training"
1010

@@ -67,23 +67,12 @@ def get_loader(root, split, patch_shape, batch_size, filter_empty):
6767
assert len(this_image_paths) == len(this_label_paths)
6868
assert len(this_image_paths) > 0
6969

70-
label_transform = torch_em.transform.label.PerObjectDistanceTransform(
71-
distances=True, boundary_distances=True, foreground=True,
72-
)
73-
7470
if split == "train":
7571
n_samples = 250 * batch_size
7672
elif split == "val":
77-
n_samples = 20 * batch_size
78-
79-
sampler = torch_em.data.sampler.MinInstanceSampler(p_reject=0.8)
80-
loader = torch_em.default_segmentation_loader(
81-
raw_paths=image_paths, raw_key=None, label_paths=label_paths, label_key=None,
82-
batch_size=batch_size, patch_shape=patch_shape, label_transform=label_transform,
83-
n_samples=n_samples, num_workers=4, shuffle=True,
84-
sampler=sampler
85-
)
86-
return loader
73+
n_samples = 16 * batch_size
74+
75+
return get_supervised_loader(this_image_paths, this_label_paths, patch_shape, batch_size, n_samples=n_samples)
8776

8877

8978
def main():
@@ -120,7 +109,7 @@ def main():
120109
patch_shape = (64, 128, 128)
121110

122111
# The U-Net.
123-
model = UNet3d(in_channels=1, out_channels=3, initial_features=32, final_activation="Sigmoid")
112+
model = get_3d_model()
124113

125114
# Create the training loader with train and val set.
126115
train_loader = get_loader(root, "train", patch_shape, batch_size, filter_empty=filter_empty)

test/test_validation.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
import unittest
2+
from shutil import rmtree
3+
4+
import imageio.v3 as imageio
5+
import pandas as pd
6+
from skimage.measure import regionprops_table
7+
from skimage.segmentation import relabel_sequential
8+
9+
10+
class TestValidation(unittest.TestCase):
11+
folder = "./tmp"
12+
13+
def setUp(self):
14+
from flamingo_tools.test_data import get_test_volume_and_segmentation
15+
16+
_, self.seg_path, _ = get_test_volume_and_segmentation(self.folder)
17+
18+
def tearDown(self):
19+
try:
20+
rmtree(self.folder)
21+
except Exception:
22+
pass
23+
24+
def test_compute_scores_for_annotated_slice_2d(self):
25+
from flamingo_tools.validation import compute_scores_for_annotated_slice
26+
27+
segmentation = imageio.imread(self.seg_path)
28+
segmentation = segmentation[segmentation.shape[0] // 2]
29+
segmentation, _, _ = relabel_sequential(segmentation)
30+
31+
properties = ("label", "centroid")
32+
annotations = regionprops_table(segmentation, properties=properties)
33+
annotations = pd.DataFrame(annotations).rename(columns={"centroid-0": "axis-0", "centroid-1": "axis-1"})
34+
annotations = annotations.drop(columns="label")
35+
36+
result = compute_scores_for_annotated_slice(segmentation, annotations)
37+
38+
# Check the results. Note: we actually get 1 FP and 1 FN because 1 of the centroids is outside the object.
39+
self.assertEqual(result["fp"], 1)
40+
self.assertEqual(result["fn"], 1)
41+
self.assertEqual(result["tp"], segmentation.max() - 1)
42+
43+
def test_compute_scores_for_annotated_slice_3d(self):
44+
from flamingo_tools.validation import compute_scores_for_annotated_slice
45+
46+
segmentation = imageio.imread(self.seg_path)
47+
z0, z1 = segmentation.shape[0] // 2 - 2, segmentation.shape[0] // 2 + 2
48+
segmentation = segmentation[z0:z1]
49+
segmentation, _, _ = relabel_sequential(segmentation)
50+
51+
properties = ("label", "centroid")
52+
annotations = regionprops_table(segmentation, properties=properties)
53+
annotations = pd.DataFrame(annotations).rename(
54+
columns={"centroid-0": "axis-0", "centroid-1": "axis-1", "centroid-2": "axis-2"}
55+
)
56+
annotations = annotations.drop(columns="label")
57+
58+
result = compute_scores_for_annotated_slice(segmentation, annotations)
59+
60+
# Check the results. Note: we actually get 1 FP and 1 FN because 1 of the centroids is outside the object.
61+
self.assertEqual(result["fp"], 1)
62+
self.assertEqual(result["fn"], 1)
63+
self.assertEqual(result["tp"], segmentation.max() - 1)
64+
65+
66+
if __name__ == "__main__":
67+
unittest.main()

0 commit comments

Comments
 (0)