Skip to content

Commit 3ec31ad

Browse files
Update training scripts
1 parent d3b59cd commit 3ec31ad

File tree

2 files changed

+30
-15
lines changed

2 files changed

+30
-15
lines changed

flamingo_tools/training/util.py

Lines changed: 13 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -1,12 +1,14 @@
1-
from typing import Optional, Sequence, Tuple
1+
from typing import List, Optional, Sequence, Tuple
22

33
import torch.nn as nn
44
import torch_em
5-
from torch_em.model import UNet3d
5+
from torch_em.model import UNet3d, AnisotropicUNet
66
from torch.utils.data import DataLoader
77

88

9-
def get_3d_model(out_channels: int = 3, final_activation: Optional[str] = "Sigmoid") -> nn.Module:
9+
def get_3d_model(
10+
out_channels: int = 3, final_activation: Optional[str] = "Sigmoid", scale_factors: Optional[List[List[int]]] = None
11+
) -> nn.Module:
1012
"""Get a 3D U-Net for segmentation or detection tasks.
1113
1214
Args:
@@ -17,7 +19,13 @@ def get_3d_model(out_channels: int = 3, final_activation: Optional[str] = "Sigmo
1719
Returns:
1820
The 3D U-Net.
1921
"""
20-
return UNet3d(in_channels=1, out_channels=out_channels, initial_features=32, final_activation=final_activation)
22+
if scale_factors is None:
23+
return UNet3d(in_channels=1, out_channels=out_channels, initial_features=32, final_activation=final_activation)
24+
else:
25+
return AnisotropicUNet(
26+
in_channels=1, out_channels=out_channels, initial_features=32,
27+
final_activation=final_activation, scale_factors=scale_factors
28+
)
2129

2230

2331
def get_supervised_loader(
@@ -57,6 +65,6 @@ def get_supervised_loader(
5765
loader = torch_em.default_segmentation_loader(
5866
raw_paths=image_paths, raw_key=image_key, label_paths=label_paths, label_key=label_key,
5967
batch_size=batch_size, patch_shape=patch_shape, label_transform=label_transform,
60-
n_samples=n_samples, num_workers=4, shuffle=True, sampler=sampler
68+
n_samples=n_samples, num_workers=8, shuffle=True, sampler=sampler
6169
)
6270
return loader

scripts/training/train_distance_unet.py

Lines changed: 17 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -38,18 +38,19 @@ def get_image_and_label_paths(root):
3838
label_paths.append(label_path)
3939

4040
assert len(image_paths) == len(label_paths)
41-
return image_paths, label_paths
41+
return image_paths, label_paths, None
4242

4343

4444
def get_image_and_label_paths_sep_folders(root):
4545
image_paths = sorted(glob(os.path.join(root, "images", "**", "*.tif"), recursive=True))
4646
label_paths = sorted(glob(os.path.join(root, "labels", "**", "*.tif"), recursive=True))
4747
assert len(image_paths) == len(label_paths)
4848

49-
return image_paths, label_paths
49+
stratify = [os.path.basename(os.path.dirname(f)) for f in image_paths]
50+
return image_paths, label_paths, stratify
5051

5152

52-
def select_paths(image_paths, label_paths, split, filter_empty, random_split=True):
53+
def select_paths(image_paths, label_paths, split, filter_empty, stratify, random_split=True):
5354
if filter_empty:
5455
image_paths = [imp for imp in image_paths if "empty" not in imp]
5556
label_paths = [imp for imp in label_paths if "empty" not in imp]
@@ -60,12 +61,16 @@ def select_paths(image_paths, label_paths, split, filter_empty, random_split=Tru
6061

6162
n_train = int(train_fraction * n_files)
6263
if split == "train" and random_split:
63-
image_paths, _, label_paths, _ = train_test_split(image_paths, label_paths, train_size=n_train, random_state=42)
64+
image_paths, _, label_paths, _ = train_test_split(
65+
image_paths, label_paths, train_size=n_train, random_state=42, stratify=stratify
66+
)
6467
elif split == "train":
6568
image_paths = image_paths[:n_train]
6669
label_paths = label_paths[:n_train]
6770
elif split == "val" and random_split:
68-
_, image_paths, _, label_paths = train_test_split(image_paths, label_paths, train_size=n_train, random_state=42)
71+
_, image_paths, _, label_paths = train_test_split(
72+
image_paths, label_paths, train_size=n_train, random_state=42, stratify=stratify
73+
)
6974
elif split == "val":
7075
image_paths = image_paths[n_train:]
7176
label_paths = label_paths[n_train:]
@@ -75,13 +80,14 @@ def select_paths(image_paths, label_paths, split, filter_empty, random_split=Tru
7580

7681
def get_loader(root, split, patch_shape, batch_size, filter_empty, separate_folders, anisotropy):
7782
if separate_folders:
78-
image_paths, label_paths = get_image_and_label_paths_sep_folders(root)
83+
image_paths, label_paths, stratify = get_image_and_label_paths_sep_folders(root)
7984
else:
80-
image_paths, label_paths = get_image_and_label_paths(root)
81-
this_image_paths, this_label_paths = select_paths(image_paths, label_paths, split, filter_empty)
85+
image_paths, label_paths, stratify = get_image_and_label_paths(root)
86+
this_image_paths, this_label_paths = select_paths(image_paths, label_paths, split, filter_empty, stratify=stratify)
8287

8388
assert len(this_image_paths) == len(this_label_paths)
8489
assert len(this_image_paths) > 0
90+
print(split, ":", len(this_image_paths), "image crops")
8591

8692
if split == "train":
8793
n_samples = 250 * batch_size
@@ -133,10 +139,11 @@ def main():
133139

134140
# Parameters for training on A100.
135141
n_iterations = int(1e5)
136-
patch_shape = (48, 128, 128)
142+
patch_shape = (48, 128, 128) if anisotropy is None else (24, 128, 128)
137143

138144
# The U-Net.
139-
model = get_3d_model()
145+
scale_factors = None if args.anisotropy is None else [[1, 2, 2], [2, 2, 2], [2, 2, 2], [2, 2, 2]]
146+
model = get_3d_model(scale_factors=scale_factors)
140147

141148
# Create the training loader with train and val set.
142149
train_loader, train_images, train_labels = get_loader(

0 commit comments

Comments
 (0)