Skip to content

Commit 2a27aa1

Browse files
Update training scripts
1 parent 87688cd commit 2a27aa1

File tree

4 files changed

+116
-66
lines changed

4 files changed

+116
-66
lines changed

scripts/README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ conda install -c conda-forge mobie_utils
1313
## Training
1414

1515
Contains the scripts for training a U-Net that predicts foreground probabilities and normalized object distances.
16+
It also contains documentation for how to run training on new annotated data.
1617

1718

1819
## Prediction

scripts/data_transfer/README.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,3 +33,10 @@ Try to automate via https://github.com/jborean93/smbprotocol see `sync_smb.py` f
3333

3434
For transferring back MoBIE results.
3535
...
36+
37+
# Data Transfer Huisken
38+
39+
See "Transfer via smbclient" above:
40+
```
41+
smbclient \\\\wfs-biologie-spezial.top.gwdg.de\\UBM1-all\$\\ -U GWDG\\pape41
42+
```
Lines changed: 37 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,42 +1,57 @@
1+
import argparse
12
import os
2-
from glob import glob
33

44
import imageio.v3 as imageio
55
import napari
66
import numpy as np
77

8-
root = "/home/pape/Work/data/moser/lightsheet"
8+
from train_distance_unet import get_image_and_label_paths
9+
from tqdm import tqdm
910

11+
# Root folder on my laptop.
12+
# This is just for convenience, so that I don't have to pass
13+
# the root argument during development.
14+
ROOT_CP = "/home/pape/Work/data/moser/lightsheet"
1015

11-
def check_visually(check_downsampled=False):
12-
if check_downsampled:
13-
images = sorted(glob(os.path.join(root, "images_s2", "*.tif")))
14-
masks = sorted(glob(os.path.join(root, "masks_s2", "*.tif")))
15-
else:
16-
images = sorted(glob(os.path.join(root, "images", "*.tif")))
17-
masks = sorted(glob(os.path.join(root, "masks", "*.tif")))
18-
assert len(images) == len(masks)
1916

20-
for im, mask in zip(images, masks):
21-
print(im)
17+
def check_visually(images, labels):
18+
for im, label in tqdm(zip(images, labels), total=len(images)):
2219

2320
vol = imageio.imread(im)
24-
seg = imageio.imread(mask).astype("uint32")
21+
seg = imageio.imread(label).astype("uint32")
2522

2623
v = napari.Viewer()
27-
v.add_image(vol)
28-
v.add_labels(seg)
24+
v.add_image(vol, name="pv-channel")
25+
v.add_labels(seg, name="annotations")
26+
folder, name = os.path.split(im)
27+
folder = os.path.basename(folder)
28+
v.title = f"{folder}/{name}"
2929
napari.run()
3030

3131

32-
def check_labels():
33-
masks = sorted(glob(os.path.join(root, "masks", "*.tif")))
34-
for mask_path in masks:
35-
labels = imageio.imread(mask_path)
32+
def check_labels(images, labels):
33+
for label_path in labels:
34+
labels = imageio.imread(label_path)
3635
n_labels = len(np.unique(labels))
37-
print(mask_path, n_labels)
36+
print(label_path, n_labels)
37+
38+
39+
def main():
40+
parser = argparse.ArgumentParser()
41+
parser.add_argument(
42+
"--root", "-i", help="The root folder with the annotated training crops.",
43+
default=ROOT_CP,
44+
)
45+
parser.add_argument("--check_labels", "-l", action="store_true")
46+
args = parser.parse_args()
47+
root = args.root
48+
49+
images, labels = get_image_and_label_paths(root)
50+
51+
check_visually(images, labels)
52+
if args.check_labels:
53+
check_labels(images, labels)
3854

3955

4056
if __name__ == "__main__":
41-
check_visually(True)
42-
# check_labels()
57+
main()
Lines changed: 71 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -1,54 +1,72 @@
1+
import argparse
12
import os
3+
from datetime import datetime
24
from glob import glob
35

46
import torch_em
5-
67
from torch_em.model import UNet3d
78

8-
# DATA_ROOT = "/home/pape/Work/data/moser/lightsheet"
9-
DATA_ROOT = "/scratch-grete/usr/nimcpape/data/moser/lightsheet"
9+
ROOT_CLUSTER = "/scratch-grete/usr/nimcpape/data/moser/lightsheet"
10+
11+
12+
def get_image_and_label_paths(root):
13+
exclude_names = ["annotations", "cp_masks"]
14+
all_image_paths = sorted(glob(os.path.join(root, "**/**.tif"), recursive=True))
15+
all_image_paths = [
16+
path for path in all_image_paths if not any(exclude in path for exclude in exclude_names)
17+
]
18+
19+
image_paths, label_paths = [], []
20+
label_extensions = ["_annotations.tif"]
21+
for path in all_image_paths:
22+
folder, fname = os.path.split(path)
23+
fname = os.path.splitext(fname)[0]
24+
label_path = None
25+
for ext in label_extensions:
26+
candidate_label_path = os.path.join(folder, f"{fname}{ext}")
27+
if os.path.exists(candidate_label_path):
28+
label_path = candidate_label_path
29+
break
30+
31+
if label_path is None:
32+
print("Did not find annotations for", path)
33+
print("This image will not be used for training.")
34+
else:
35+
image_paths.append(path)
36+
label_paths.append(label_path)
37+
38+
assert len(image_paths) == len(label_paths)
39+
return image_paths, label_paths
1040

1141

12-
def get_paths(image_paths, label_paths, split, filter_empty):
42+
def select_paths(image_paths, label_paths, split, filter_empty):
1343
if filter_empty:
1444
image_paths = [imp for imp in image_paths if "empty" not in imp]
1545
label_paths = [imp for imp in label_paths if "empty" not in imp]
1646
assert len(image_paths) == len(label_paths)
1747

1848
n_files = len(image_paths)
1949

20-
train_fraction = 0.8
21-
val_fraction = 0.1
50+
train_fraction = 0.85
2251

2352
n_train = int(train_fraction * n_files)
24-
n_val = int(val_fraction * n_files)
2553
if split == "train":
2654
image_paths = image_paths[:n_train]
2755
label_paths = label_paths[:n_train]
2856

2957
elif split == "val":
30-
image_paths = image_paths[n_train:(n_train + n_val)]
31-
label_paths = label_paths[n_train:(n_train + n_val)]
58+
image_paths = image_paths[n_train:]
59+
label_paths = label_paths[n_train:]
3260

3361
return image_paths, label_paths
3462

3563

36-
def get_loader(split, patch_shape, batch_size, filter_empty, train_on=["default"]):
37-
image_paths, label_paths = [], []
38-
39-
if "default" in train_on:
40-
all_image_paths = sorted(glob(os.path.join(DATA_ROOT, "images", "*.tif")))
41-
all_label_paths = sorted(glob(os.path.join(DATA_ROOT, "masks", "*.tif")))
42-
this_image_paths, this_label_paths = get_paths(all_image_paths, all_label_paths, split, filter_empty)
43-
image_paths.extend(this_image_paths)
44-
label_paths.extend(this_label_paths)
64+
def get_loader(root, split, patch_shape, batch_size, filter_empty):
65+
image_paths, label_paths = get_image_and_label_paths(root)
66+
this_image_paths, this_label_paths = select_paths(image_paths, label_paths, split, filter_empty)
4567

46-
if "downsampled" in train_on:
47-
all_image_paths = sorted(glob(os.path.join(DATA_ROOT, "images_s2", "*.tif")))
48-
all_label_paths = sorted(glob(os.path.join(DATA_ROOT, "masks_s2", "*.tif")))
49-
this_image_paths, this_label_paths = get_paths(all_image_paths, all_label_paths, split, filter_empty)
50-
image_paths.extend(this_image_paths)
51-
label_paths.extend(this_label_paths)
68+
assert len(this_image_paths) == len(this_label_paths)
69+
assert len(this_image_paths) > 0
5270

5371
label_transform = torch_em.transform.label.PerObjectDistanceTransform(
5472
distances=True, boundary_distances=True, foreground=True,
@@ -69,26 +87,40 @@ def get_loader(split, patch_shape, batch_size, filter_empty, train_on=["default"
6987
return loader
7088

7189

72-
def main(check_loaders=False):
73-
# Parameters for training:
90+
def main():
91+
parser = argparse.ArgumentParser()
92+
parser.add_argument(
93+
"--root", "-i", help="The root folder with the annotated training crops.",
94+
default=ROOT_CLUSTER,
95+
)
96+
parser.add_argument(
97+
"--check_loaders", "-l", action="store_true",
98+
help="Visualize the data loader output instead of starting a training run."
99+
)
100+
parser.add_argument(
101+
"--filter_empty", "-f", action="store_true",
102+
help="Whether to exclude blocks with empty annotations from the training process."
103+
)
104+
parser.add_argument(
105+
"--name", help="Optional name for the model to be trained. If not given the current date is used."
106+
)
107+
args = parser.parse_args()
108+
root = args.root
109+
check_loaders = args.check_loaders
110+
filter_empty = args.filter_empty
111+
run_name = datetime.now().strftime("%Y%m%d") if args.name is None else args.name
112+
113+
# Parameters for training on A100.
74114
n_iterations = 1e5
75115
batch_size = 8
76-
filter_empty = False
77-
train_on = ["downsampled"]
78-
# train_on = ["downsampled", "default"]
79-
80-
patch_shape = (32, 128, 128) if "downsampled" in train_on else (64, 128, 128)
116+
patch_shape = (64, 128, 128)
81117

82118
# The U-Net.
83119
model = UNet3d(in_channels=1, out_channels=3, initial_features=32, final_activation="Sigmoid")
84120

85121
# Create the training loader with train and val set.
86-
train_loader = get_loader(
87-
"train", patch_shape, batch_size, filter_empty=filter_empty, train_on=train_on
88-
)
89-
val_loader = get_loader(
90-
"val", patch_shape, batch_size, filter_empty=filter_empty, train_on=train_on
91-
)
122+
train_loader = get_loader(root, "train", patch_shape, batch_size, filter_empty=filter_empty)
123+
val_loader = get_loader(root, "val", patch_shape, batch_size, filter_empty=filter_empty)
92124

93125
if check_loaders:
94126
from torch_em.util.debug import check_loader
@@ -99,12 +131,7 @@ def main(check_loaders=False):
99131
loss = torch_em.loss.distance_based.DiceBasedDistanceLoss(mask_distances_in_bg=True)
100132

101133
# Create the trainer.
102-
name = "cochlea_distance_unet"
103-
if filter_empty:
104-
name += "-filter-empty"
105-
if train_on == ["downsampled"]:
106-
name += "-train-downsampled"
107-
134+
name = f"cochlea_distance_unet_{run_name}"
108135
trainer = torch_em.default_segmentation_trainer(
109136
name=name,
110137
model=model,
@@ -123,4 +150,4 @@ def main(check_loaders=False):
123150

124151

125152
if __name__ == "__main__":
126-
main(check_loaders=False)
153+
main()

0 commit comments

Comments
 (0)