Commit 42cdad1

Add README for unet training
1 parent 2a27aa1 commit 42cdad1

2 files changed: +22 −3 lines changed

scripts/training/README.md

Lines changed: 14 additions & 0 deletions

@@ -0,0 +1,14 @@
+# 3D U-Net Training for Cochlea Data
+
+This folder contains the scripts for training a 3D U-Net for cell segmentation in the cochlea data.
+It contains two relevant scripts:
+- `check_training_data.py`, which visualizes the training data and annotations in napari.
+- `train_distance_unet.py`, which trains the 3D U-Net.
+
+Both scripts accept the argument `-i /path/to/data` to specify the root folder with the training data; for example, run `python train_distance_unet.py -i /path/to/data` for training. The scripts consider all tif files in the sub-folders of the root folder for training.
+They load the **image data** according to the following rules:
+- Files ending in `_annotations.tif` or `_cp_masks.tif` are not considered image data.
+- All other files are considered image data if a corresponding file ending in `_annotations.tif` can be found. Otherwise the file is excluded; the scripts print the names of all excluded files.
+
+The training script saves the trained model in `checkpoints/cochlea_distance_unet_<CURRENT_DATE>`, e.g. `checkpoints/cochlea_distance_unet_20250115`.
+For further options, run `python check_training_data.py -h` / `python train_distance_unet.py -h`.
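The file-pairing rules described in the README can be sketched as follows. This is a hypothetical helper (the function name and structure are not from the commit); it only mirrors the stated rules: skip annotation and cellpose-mask files, pair each remaining tif with its `_annotations.tif` counterpart, and report excluded files.

```python
import os
from glob import glob


def pair_images_and_labels(root):
    """Pair image tifs with their _annotations.tif labels, per the README rules."""
    image_paths, label_paths = [], []
    # Consider all tif files in the sub-folders of the root folder.
    for path in sorted(glob(os.path.join(root, "**", "*.tif"), recursive=True)):
        # Annotation and cellpose mask files are never image data.
        if path.endswith("_annotations.tif") or path.endswith("_cp_masks.tif"):
            continue
        label_path = path.replace(".tif", "_annotations.tif")
        if os.path.exists(label_path):
            image_paths.append(path)
            label_paths.append(label_path)
        else:
            # Files without a matching annotation file are excluded and reported.
            print(f"Excluding {path}: no corresponding annotation file was found.")
    return image_paths, label_paths
```

A file `crops/a.tif` is kept only if `crops/a_annotations.tif` exists next to it; `crops/b.tif` without annotations is printed and skipped.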

scripts/training/train_distance_unet.py

Lines changed: 8 additions & 3 deletions

@@ -6,7 +6,7 @@
 import torch_em
 from torch_em.model import UNet3d
 
-ROOT_CLUSTER = "/scratch-grete/usr/nimcpape/data/moser/lightsheet"
+ROOT_CLUSTER = "/scratch-grete/usr/nimcpape/data/moser/lightsheet/training"
 
 
 def get_image_and_label_paths(root):
@@ -77,7 +77,7 @@ def get_loader(root, split, patch_shape, batch_size, filter_empty):
     elif split == "val":
         n_samples = 20 * batch_size
 
-    sampler = torch_em.data.sampler.MinInstanceSampler(p_reject=0.95)
+    sampler = torch_em.data.sampler.MinInstanceSampler(p_reject=0.8)
     loader = torch_em.default_segmentation_loader(
         raw_paths=image_paths, raw_key=None, label_paths=label_paths, label_key=None,
         batch_size=batch_size, patch_shape=patch_shape, label_transform=label_transform,
@@ -93,6 +93,11 @@ def main():
         "--root", "-i", help="The root folder with the annotated training crops.",
         default=ROOT_CLUSTER,
     )
+    parser.add_argument(
+        "--batch_size", "-b", help="The batch size for training. Set to 8 by default. "
+        "You may need to choose a smaller batch size to train on your GPU.",
+        default=8, type=int,
+    )
     parser.add_argument(
         "--check_loaders", "-l", action="store_true",
         help="Visualize the data loader output instead of starting a training run."
@@ -106,13 +111,13 @@ def main():
     )
     args = parser.parse_args()
     root = args.root
+    batch_size = args.batch_size
     check_loaders = args.check_loaders
     filter_empty = args.filter_empty
     run_name = datetime.now().strftime("%Y%m%d") if args.name is None else args.name
 
     # Parameters for training on A100.
     n_iterations = 1e5
-    batch_size = 8
     patch_shape = (64, 128, 128)
 
     # The U-Net.
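The `p_reject` change above tunes rejection sampling: a patch that fails the sampler's minimum-instance criterion is rejected with probability `p_reject`, so lowering it from 0.95 to 0.8 lets more sparse patches into training. A minimal, library-free sketch of that logic (hypothetical; `torch_em`'s `MinInstanceSampler` performs the real check on label patches, and the threshold name below is assumed):

```python
import random


def accept_patch(n_instances, min_num_instances=2, p_reject=0.8, rng=random):
    """Accept a patch if it has enough instances; otherwise reject it
    with probability p_reject (i.e. still accept with probability 1 - p_reject)."""
    if n_instances >= min_num_instances:
        return True
    return rng.random() >= p_reject


def acceptance_rate(n_instances, p_reject, trials=100_000, seed=0):
    """Estimate how often a patch with the given instance count is accepted."""
    rng = random.Random(seed)
    hits = sum(accept_patch(n_instances, p_reject=p_reject, rng=rng) for _ in range(trials))
    return hits / trials
```

For a sparse patch (a single instance), `p_reject=0.8` accepts roughly 20% of draws, versus about 5% with the previous `p_reject=0.95`; patches that meet the instance minimum are always accepted.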
