Skip to content

Commit 9aa455b

Browse files
authored
Add support scripts for training custom model (#1086)
1 parent c99df52 commit 9aa455b

File tree

1 file changed

+67
-0
lines changed

1 file changed

+67
-0
lines changed
Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
import os
2+
from glob import glob
3+
from natsort import natsorted
4+
5+
import torch
6+
7+
import micro_sam.training as sam_training
8+
from micro_sam.util import export_custom_sam_model
9+
10+
11+
def train_embl_alm_data(checkpoint_name):
12+
"""Training a MicroSAM model for https://github.com/computational-cell-analytics/micro-sam/issues/1084.
13+
"""
14+
# All hyperparameters for training.
15+
batch_size = 1
16+
patch_shape = (512, 512)
17+
n_objects_per_batch = 25
18+
device = torch.device("cuda")
19+
20+
# Get the filepaths to images and corresponding labels.
21+
image_paths = natsorted(glob(os.path.join(os.getcwd(), "data_same_size", "*.tif")))
22+
label_paths = natsorted(glob(os.path.join(os.getcwd(), "masks_same_size", "*.tif")))
23+
24+
# Next, prepare the dataloaders.
25+
kwargs = {
26+
"batch_size": batch_size,
27+
"patch_shape": patch_shape,
28+
"with_segmentation_decoder": True,
29+
"num_workers": 16,
30+
"shuffle": True,
31+
}
32+
33+
train_loader = sam_training.default_sam_loader(
34+
raw_paths=image_paths[:-5], raw_key=None, label_paths=label_paths[:-5], label_key=None, **kwargs,
35+
)
36+
val_loader = sam_training.default_sam_loader(
37+
raw_paths=image_paths[-5:], raw_key=None, label_paths=label_paths[-5:], label_key=None, **kwargs,
38+
)
39+
40+
# Run training.
41+
sam_training.train_sam(
42+
name=checkpoint_name,
43+
model_type="vit_b_lm",
44+
train_loader=train_loader,
45+
val_loader=val_loader,
46+
n_epochs=10,
47+
n_objects_per_batch=n_objects_per_batch,
48+
with_segmentation_decoder=True,
49+
device=device,
50+
)
51+
52+
53+
def main():
54+
checkpoint_name = "sam_embl_alm_fluo" # Name of the checkpoint, stored at "./checkpoints/<CHECKPOINT_NAME>"
55+
56+
train_embl_alm_data(checkpoint_name)
57+
58+
# Export the trained model.
59+
export_custom_sam_model(
60+
checkpoint_path=os.path.join("checkpoints", checkpoint_name, "best.pt"),
61+
model_type="vit_b",
62+
save_path="./finetuned_embl_alm_fluo_model.pth",
63+
)
64+
65+
66+
if __name__ == "__main__":
67+
main()

0 commit comments

Comments
 (0)