
Commit 0ec2643

Merge pull request #193 from computational-cell-analytics/histopatho-model
Histopatho model
2 parents eb331b9 + 3473402 commit 0ec2643

File tree

2 files changed (+100, -0 lines)
train_histopathology_generalist.py
Lines changed: 89 additions & 0 deletions
@@ -0,0 +1,89 @@
import os

import micro_sam.training as sam_training
import torch
import torch_em

import torch.utils.data as data_util
from torch_em.data.datasets import get_lizard_dataset
from torch_em.data.sampler import MinInstanceSampler
from micro_sam.util import export_custom_sam_model


# TODO use other datasets than lizard
def get_dataloaders(patch_shape, data_path):
    label_transform = torch_em.transform.label.label_consecutive  # to ensure consecutive IDs
    sampler = MinInstanceSampler(min_num_instances=5)
    dataset = get_lizard_dataset(
        path=data_path, download=True, patch_shape=patch_shape, label_transform=label_transform,
        sampler=sampler,
    )
    train_ds, val_ds = data_util.random_split(dataset, [0.9, 0.1])
    train_loader = torch_em.get_data_loader(train_ds, batch_size=1)
    val_loader = torch_em.get_data_loader(val_ds, batch_size=1)
    return train_loader, val_loader


def finetune_histopatho(input_path, export_path, model_type="vit_h", iterations=int(2e4), save_root=None):
    """Example code for finetuning SAM on the Lizard histopathology dataset."""

    # training settings:
    checkpoint_path = None  # override this to start training from a custom checkpoint
    device = "cuda"  # override this if you have a more complex set-up and need to specify the exact gpu
    patch_shape = (512, 512)  # the patch shape for training
    n_objects_per_batch = 50  # this is the number of objects per batch that will be sampled

    train_loader, val_loader = get_dataloaders(patch_shape=patch_shape, data_path=input_path)

    # get the trainable segment anything model
    model = sam_training.get_trainable_sam_model(model_type, checkpoint_path, device=device)

    # all the stuff we need for training
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.9, patience=10, verbose=True)

    # this class creates all the training data for a batch (inputs, prompts and labels)
    convert_inputs = sam_training.ConvertToSamInputs()

    checkpoint_name = "sam-histopatho-v1"
    # the trainer which performs training and validation (implemented using "torch_em")
    trainer = sam_training.SamTrainer(
        name=checkpoint_name,
        save_root=save_root,
        train_loader=train_loader,
        val_loader=val_loader,
        model=model,
        optimizer=optimizer,
        # currently we compute the loss batch-wise; pass channelwise=True to compute it per channel instead
        loss=torch_em.loss.DiceLoss(channelwise=False),
        metric=torch_em.loss.DiceLoss(),
        device=device,
        lr_scheduler=scheduler,
        logger=sam_training.SamLogger,
        log_image_interval=10,
        mixed_precision=True,
        convert_inputs=convert_inputs,
        n_objects_per_batch=n_objects_per_batch,
        n_sub_iteration=8,
        compile_model=False,
    )
    trainer.fit(iterations)
    if export_path is not None:
        checkpoint_path = os.path.join(
            "" if save_root is None else save_root, "checkpoints", checkpoint_name, "best.pt"
        )
        export_custom_sam_model(
            checkpoint_path=checkpoint_path,
            model_type=model_type,
            save_path=export_path,
        )


def main():
    input_path = "/scratch-grete/projects/nim00007/data/lizard"
    export_path = "./sam-vith-histopatho-v1.pth"
    finetune_histopatho(input_path, export_path)


if __name__ == "__main__":
    main()
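
Once training has finished, the exported file (here ./sam-vith-histopatho-v1.pth) can be loaded back into micro_sam for inference or annotation. The snippet below is a minimal sketch and not part of the commit; it assumes that micro_sam.util.get_sam_model accepts model_type and checkpoint_path keyword arguments, which may differ between micro_sam versions.

from micro_sam.util import get_sam_model

# load the finetuned histopathology model exported above
# NOTE: illustrative sketch only; check the installed micro_sam version
# for the exact signature of get_sam_model
predictor = get_sam_model(
    model_type="vit_h",                              # must match the model type used for finetuning
    checkpoint_path="./sam-vith-histopatho-v1.pth",  # file written by export_custom_sam_model
)
# the returned predictor can then be used with micro_sam's annotation tools
# or prompt-based segmentation utilities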
Lines changed: 11 additions & 0 deletions
@@ -0,0 +1,11 @@
#! /bin/bash
#SBATCH -c 16
#SBATCH --mem 128G
#SBATCH -t 2800
#SBATCH -p grete:shared
#SBATCH -G A100:1
#SBATCH -A nim00007
#SBATCH --constraint=80gb

source activate sam
python train_histopathology_generalist.py $@
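
This SLURM script requests 16 CPU cores, 128 GB of RAM, one A100 GPU with the 80gb constraint (i.e. an 80 GB A100 node), and a 2800 minute time limit on the grete:shared partition under account nim00007. It activates the sam conda environment and forwards any extra command-line arguments to the training script via $@, although the current main() does not read any. Assuming the script is saved, for example, as train_histopathology_generalist.sbatch (a hypothetical name), it would be submitted with "sbatch train_histopathology_generalist.sbatch".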
