Skip to content

Commit 3473402

Browse files
Update histopatho training
1 parent 9a86d55 commit 3473402

File tree

2 files changed

+23
-6
lines changed

2 files changed

+23
-6
lines changed

finetuning/generalists/training/train_histopathology_generalist.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,35 +6,41 @@
66

77
import torch.utils.data as data_util
88
from torch_em.data.datasets import get_lizard_dataset
9+
from torch_em.data.sampler import MinInstanceSampler
910
from micro_sam.util import export_custom_sam_model
1011

1112

1213
# TODO use other datasets than lizard
1314
def get_dataloaders(patch_shape, data_path):
    """Create the training and validation data loaders for the Lizard dataset.

    The dataset is built with `get_lizard_dataset` (with `download=True`) and a
    `MinInstanceSampler(min_num_instances=5)` is passed as its sampler. It is
    then split 90% / 10% into a training and a validation subset with
    `random_split`.

    Args:
        patch_shape: Patch shape forwarded to `get_lizard_dataset`.
        data_path: Location of the data, forwarded to `get_lizard_dataset` as `path`.

    Returns:
        The tuple `(train_loader, val_loader)`; both loaders use a batch size of 1.
    """
    lizard_kwargs = dict(
        path=data_path,
        download=True,
        patch_shape=patch_shape,
        # relabel the annotations to ensure consecutive IDs
        label_transform=torch_em.transform.label.label_consecutive,
        # NOTE(review): presumably skips patches with fewer than 5 instances -- confirm in torch_em
        sampler=MinInstanceSampler(min_num_instances=5),
    )
    full_dataset = get_lizard_dataset(**lizard_kwargs)

    # random_split returns the subsets in the order of the given fractions: train first, then val
    subsets = data_util.random_split(full_dataset, [0.9, 0.1])
    train_loader, val_loader = (
        torch_em.get_data_loader(subset, batch_size=1) for subset in subsets
    )
    return train_loader, val_loader
2025

2126

22-
def finetune_histopatho(input_path, export_path, model_type="vit_h", iterations=int(2.5e4), save_root=None):
27+
def finetune_histopatho(input_path, export_path, model_type="vit_h", iterations=int(2e4), save_root=None):
2328
"""Example code for finetuning SAM on histopathology data (the Lizard dataset)"""
2429

2530
# training settings:
2631
checkpoint_path = None # override this to start training from a custom checkpoint
2732
device = "cuda" # override this if you have some more complex set-up and need to specify the exact gpu
2833
patch_shape = (512, 512) # the patch shape for training
29-
n_objects_per_batch = 25 # this is the number of objects per batch that will be sampled
34+
n_objects_per_batch = 50 # this is the number of objects per batch that will be sampled
35+
36+
train_loader, val_loader = get_dataloaders(patch_shape=patch_shape, data_path=input_path)
3037

3138
# get the trainable segment anything model
3239
model = sam_training.get_trainable_sam_model(model_type, checkpoint_path, device=device)
3340

3441
# all the stuff we need for training
3542
optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)
3643
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.9, patience=10, verbose=True)
37-
train_loader, val_loader = get_dataloaders(patch_shape=patch_shape, data_path=input_path)
3844

3945
# this class creates all the training data for a batch (inputs, prompts and labels)
4046
convert_inputs = sam_training.ConvertToSamInputs()
@@ -74,7 +80,7 @@ def finetune_histopatho(input_path, export_path, model_type="vit_h", iterations=
7480

7581

7682
def main():
    """Parse the command line and run the histopathology finetuning.

    The batch script that launches this file forwards its arguments to it, so
    they are parsed here instead of being ignored. Without any arguments the
    behavior is unchanged: the previously hard-coded input and export paths are
    used and `finetune_histopatho` runs with its own defaults.
    """
    import argparse

    parser = argparse.ArgumentParser(
        description="Finetune Segment Anything on histopathology data (Lizard)."
    )
    parser.add_argument(
        "-i", "--input_path", default="/scratch-grete/projects/nim00007/data/lizard",
        help="Path to the Lizard data.",
    )
    parser.add_argument(
        "-e", "--export_path", default="./sam-vith-histopatho-v1.pth",
        help="Where to write the exported model.",
    )
    parser.add_argument("-m", "--model_type", default=None, help="Passed to finetune_histopatho as model_type.")
    parser.add_argument("--iterations", type=int, default=None, help="Passed to finetune_histopatho as iterations.")
    parser.add_argument("--save_root", default=None, help="Passed to finetune_histopatho as save_root.")
    args = parser.parse_args()

    # Forward only the options that were given explicitly, so the defaults in
    # finetune_histopatho's signature remain the single source of truth.
    overrides = {
        name: value
        for name, value in (
            ("model_type", args.model_type),
            ("iterations", args.iterations),
            ("save_root", args.save_root),
        )
        if value is not None
    }
    finetune_histopatho(args.input_path, args.export_path, **overrides)
8086

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
#! /bin/bash
# Slurm batch script that runs train_histopathology_generalist.py.
# Resource request (see the sbatch documentation for the option meanings):
# 16 CPUs per task, 128G of memory, a time limit of 2800 (minutes), one A100
# GPU on the grete:shared partition, account nim00007, 80gb constraint.
#SBATCH -c 16
#SBATCH --mem 128G
#SBATCH -t 2800
#SBATCH -p grete:shared
#SBATCH -G A100:1
#SBATCH -A nim00007
#SBATCH --constraint=80gb

source activate sam
# "$@" is quoted so that every argument is forwarded unchanged; an unquoted $@
# would be re-split on whitespace and glob-expanded.
python train_histopathology_generalist.py "$@"

0 commit comments

Comments
 (0)