Skip to content

Commit 2fd0344

Browse files
Implement micro-sam training script
1 parent 170e449 commit 2fd0344

File tree

3 files changed

+75
-1
lines changed

3 files changed

+75
-1
lines changed

scripts/training/README.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ This folder contains the scripts for training a 3D U-Net for cell segmentation i
44
It contains three relevant scripts:
55
- `check_training_data.py`, which visualizes the training data and annotations in napari.
66
- `train_distance_unet.py`, which trains the 3D U-Net.
7+
- `train_micro_sam.py`, which fine-tunes a micro-sam model on the data.
78

89
All three scripts accept the argument `-i /path/to/data` to specify the root folder with the training data. For example, run `python train_distance_unet.py -i /path/to/data` for training. The scripts will consider all tif files in the sub-folders of the root folder for training.
910
They will load the **image data** according to the following rules:
@@ -12,3 +13,5 @@ They will load the **image data** according to the following rules:
1213

1314
The training script will save the trained model in `checkpoints/cochlea_distance_unet_<CURRENT_DATE>`, e.g. `checkpoints/cochlea_distance_unet_20250115`.
1415
For further options for the scripts run `python check_training_data.py -h` / `python train_distance_unet.py -h`.
16+
17+
The script `train_micro_sam.py` works similarly to the U-Net training script. It saves the fine-tuned model for annotation with `micro_sam` to `checkpoints/`.

scripts/training/train_distance_unet.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,6 @@ def select_paths(image_paths, label_paths, split, filter_empty):
4646
assert len(image_paths) == len(label_paths)
4747

4848
n_files = len(image_paths)
49-
5049
train_fraction = 0.85
5150

5251
n_train = int(train_fraction * n_files)
Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
import argparse
2+
from datetime import datetime
3+
4+
import numpy as np
5+
import torch_em
6+
from micro_sam.training import default_sam_loader, train_sam
7+
from train_distance_unet import get_image_and_label_paths, select_paths
8+
9+
# Default root folder with the training data; used when --root / -i is not given.
ROOT_CLUSTER = "/scratch-grete/usr/nimcpape/data/moser/lightsheet/training"
10+
11+
12+
def raw_transform(x):
    """Normalize the raw data to the range [0, 255] based on percentiles.

    The 1st percentile of the input is mapped to 0 and the 99th percentile
    is mapped to 255; values outside of this range are clipped.

    Args:
        x: The input array. It is not modified, a float32 copy is normalized.

    Returns:
        The normalized float32 array with values in [0, 255].
    """
    x = x.astype("float32")
    lower, upper = np.percentile(x, [1, 99])
    x -= lower
    # Divide by the percentile range rather than by the upper percentile:
    # after the shift by `lower` the upper percentile sits at `upper - lower`.
    scale = upper - lower
    # Skip the division for (almost) constant inputs to avoid NaN / inf values;
    # the shifted data is then clipped to the valid range as is.
    if scale > 0:
        x /= scale
    x = np.clip(x, 0, 1)
    return x * 255
19+
20+
21+
def main():
    """Fine-tune a micro-sam model on the annotated training crops.

    Parses the command line arguments, collects the image and label paths
    for the train and validation split from the root folder, builds the
    corresponding data loaders and runs `train_sam` for the `vit_b_lm` model.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--root", "-i", help="The root folder with the annotated training crops.",
        default=ROOT_CLUSTER,
    )
    parser.add_argument(
        "--name", help="Optional name for the model to be trained. If not given the current date is used."
    )
    parser.add_argument(
        "--n_objects_per_batch", "-n", type=int, default=15,
        help="The number of objects to use during training. Set it to a lower value if you run out of GPU memory. "
        "The default value is 15."
    )
    args = parser.parse_args()

    root = args.root
    run_name = datetime.now().strftime("%Y%m%d") if args.name is None else args.name
    # Use a prefix specific to this script. The U-Net training script names its
    # checkpoint `cochlea_distance_unet_<DATE>`; re-using that prefix here would
    # presumably make both trainings share one checkpoint name on the same day.
    name = f"cochlea_micro_sam_{run_name}"
    n_objects_per_batch = args.n_objects_per_batch

    image_paths, label_paths = get_image_and_label_paths(root)
    train_image_paths, train_label_paths = select_paths(image_paths, label_paths, split="train", filter_empty=True)
    val_image_paths, val_label_paths = select_paths(image_paths, label_paths, split="val", filter_empty=True)

    patch_shape = (1, 256, 256)
    sampler = torch_em.data.sampler.MinInstanceSampler(2, min_size=10)
    max_sampling_attempts = 2500

    # Settings shared by the training and the validation loader.
    loader_kwargs = dict(
        raw_key=None, label_key=None,
        patch_shape=patch_shape, with_segmentation_decoder=True,
        raw_transform=raw_transform, sampler=sampler, min_size=10,
        num_workers=6, batch_size=1,
        max_sampling_attempts=max_sampling_attempts,
    )
    train_loader = default_sam_loader(
        raw_paths=train_image_paths, label_paths=train_label_paths, is_train=True, **loader_kwargs
    )
    val_loader = default_sam_loader(
        raw_paths=val_image_paths, label_paths=val_label_paths, is_train=False, **loader_kwargs
    )

    train_sam(
        name=name, model_type="vit_b_lm", train_loader=train_loader, val_loader=val_loader,
        n_epochs=50, n_objects_per_batch=n_objects_per_batch,
    )


if __name__ == "__main__":
    main()

0 commit comments

Comments
 (0)