
Commit 663d82e

Improve train-val splits
1 parent 43bfda9 commit 663d82e

2 files changed: +32 lines, −8 lines

flamingo_tools/training/util.py

Lines changed: 3 additions & 1 deletion

@@ -28,6 +28,7 @@ def get_supervised_loader(
     image_key: Optional[str] = None,
     label_key: Optional[str] = None,
     n_samples: Optional[int] = None,
+    raw_transform: Optional[callable] = None,
 ) -> DataLoader:
     """Get a data loader for a supervised segmentation task.

@@ -39,6 +40,7 @@ def get_supervised_loader(
         image_key: Internal path for the image data. This is only required for hdf5/zarr/n5 data.
         label_key: Internal path for the label masks. This is only required for hdf5/zarr/n5 data.
         n_samples: The number of samples to use for training.
+        raw_transform: Optional transformation for the raw data.

     Returns:
         The data loader.
@@ -52,6 +54,6 @@ def get_supervised_loader(
     loader = torch_em.default_segmentation_loader(
         raw_paths=image_paths, raw_key=image_key, label_paths=label_paths, label_key=label_key,
         batch_size=batch_size, patch_shape=patch_shape, label_transform=label_transform,
-        n_samples=n_samples, num_workers=4, shuffle=True, sampler=sampler
+        n_samples=n_samples, num_workers=4, shuffle=True, sampler=sampler, raw_transform=raw_transform,
     )
     return loader
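
A minimal usage sketch of the new raw_transform argument (not part of this commit): the normalization helper and the file paths, patch shape, and batch size are hypothetical placeholders, and the positional call mirrors the one in scripts/training/train_distance_unet.py.

import numpy as np
from flamingo_tools.training import get_supervised_loader

def normalize_raw(raw):
    # Hypothetical percentile normalization applied to each raw patch.
    lower, upper = np.percentile(raw, (1, 99))
    return ((raw - lower) / (upper - lower + 1e-6)).astype("float32")

# Placeholder paths and shapes, for illustration only.
loader = get_supervised_loader(
    ["images/vol_00.tif"], ["labels/vol_00.tif"], (32, 128, 128), 1,
    raw_transform=normalize_raw,
)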

scripts/training/train_distance_unet.py

Lines changed: 29 additions & 7 deletions

@@ -1,10 +1,12 @@
 import argparse
+import json
 import os
 from datetime import datetime
 from glob import glob

 import torch_em
 from flamingo_tools.training import get_supervised_loader, get_3d_model
+from sklearn.model_selection import train_test_split

 ROOT_CLUSTER = "/scratch-grete/usr/nimcpape/data/moser/lightsheet/training"

@@ -54,7 +56,7 @@ def get_image_and_label_paths_sep_folders(root):
     return image_paths, label_paths


-def select_paths(image_paths, label_paths, split, filter_empty):
+def select_paths(image_paths, label_paths, split, filter_empty, random_split=True):
     if filter_empty:
         image_paths = [imp for imp in image_paths if "empty" not in imp]
         label_paths = [imp for imp in label_paths if "empty" not in imp]
@@ -64,10 +66,13 @@ def select_paths(image_paths, label_paths, split, filter_empty):
     train_fraction = 0.85

     n_train = int(train_fraction * n_files)
-    if split == "train":
+    if split == "train" and random_split:
+        image_paths, _, label_paths, _ = train_test_split(image_paths, label_paths, train_size=n_train, random_state=42)
+    elif split == "train":
         image_paths = image_paths[:n_train]
         label_paths = label_paths[:n_train]
-
+    elif split == "val" and random_split:
+        _, image_paths, _, label_paths = train_test_split(image_paths, label_paths, train_size=n_train, random_state=42)
     elif split == "val":
         image_paths = image_paths[n_train:]
         label_paths = label_paths[n_train:]
@@ -90,7 +95,11 @@ def get_loader(root, split, patch_shape, batch_size, filter_empty, separate_fold
     elif split == "val":
         n_samples = 16 * batch_size

-    return get_supervised_loader(this_image_paths, this_label_paths, patch_shape, batch_size, n_samples=n_samples)
+    return (
+        get_supervised_loader(this_image_paths, this_label_paths, patch_shape, batch_size, n_samples=n_samples),
+        this_image_paths,
+        this_label_paths
+    )


 def main():
@@ -131,10 +140,10 @@ def main():
     model = get_3d_model()

     # Create the training loader with train and val set.
-    train_loader = get_loader(
+    train_loader, train_images, train_labels = get_loader(
         root, "train", patch_shape, batch_size, filter_empty=filter_empty, separate_folders=args.separate_folders
     )
-    val_loader = get_loader(
+    val_loader, val_images, val_labels = get_loader(
         root, "val", patch_shape, batch_size, filter_empty=filter_empty, separate_folders=args.separate_folders
     )

@@ -146,8 +155,21 @@ def main():

     loss = torch_em.loss.distance_based.DiceBasedDistanceLoss(mask_distances_in_bg=True)

-    # Create the trainer.
+    # Serialize the train test split.
     name = f"cochlea_distance_unet_{run_name}"
+    ckpt_folder = os.path.join("checkpoints", name)
+    os.makedirs(ckpt_folder, exist_ok=True)
+    split_file = os.path.join(ckpt_folder, "split.json")
+    with open(split_file, "w") as f:
+        json.dump(
+            {
+                "train": {"images": train_images, "labels": train_labels},
+                "val": {"images": val_images, "labels": val_labels},
+            },
+            f, sort_keys=True, indent=2
+        )
+
+    # Create the trainer.
     trainer = torch_em.default_segmentation_trainer(
         name=name,
         model=model,
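
A small standalone sketch (with invented placeholder paths) of the behaviour select_paths now relies on: calling sklearn's train_test_split twice with the same inputs, train_size, and random_state yields reproducible, complementary train and val lists, and the split.json written next to the checkpoint records exactly which files ended up in each split.

from sklearn.model_selection import train_test_split

image_paths = [f"im_{i}.tif" for i in range(20)]   # placeholder file names
label_paths = [f"lab_{i}.tif" for i in range(20)]
n_train = int(0.85 * len(image_paths))

train_im, _, train_lab, _ = train_test_split(image_paths, label_paths, train_size=n_train, random_state=42)
_, val_im, _, val_lab = train_test_split(image_paths, label_paths, train_size=n_train, random_state=42)

assert set(train_im).isdisjoint(val_im)                 # no overlap between train and val
assert set(train_im) | set(val_im) == set(image_paths)  # together they cover all files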
