Skip to content

Commit 70fec20

Browse files
authored
Update finetuning example (#338)
Update finetuning example for hela
1 parent 46191c3 commit 70fec20

File tree

2 files changed

+63
-32
lines changed

2 files changed

+63
-32
lines changed

examples/finetuning/finetune_hela.py

Lines changed: 62 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,22 @@
11
import os
2-
32
import numpy as np
3+
44
import torch
5+
56
import torch_em
7+
from torch_em.model import UNETR
8+
from torch_em.loss import DiceBasedDistanceLoss
9+
from torch_em.transform.label import PerObjectDistanceTransform
610

711
import micro_sam.training as sam_training
8-
from micro_sam.sample_data import fetch_tracking_example_data, fetch_tracking_segmentation_data
912
from micro_sam.util import export_custom_sam_model
13+
from micro_sam.sample_data import fetch_tracking_example_data, fetch_tracking_segmentation_data
14+
1015

1116
DATA_FOLDER = "data"
1217

1318

14-
def get_dataloader(split, patch_shape, batch_size):
19+
def get_dataloader(split, patch_shape, batch_size, train_instance_segmentation):
1520
"""Return train or val data loader for finetuning SAM.
1621
1722
The data loader must be a torch data loader that returns `x, y` tensors,
@@ -52,18 +57,27 @@ def get_dataloader(split, patch_shape, batch_size):
5257
else:
5358
roi = np.s_[70:, :, :]
5459

60+
if train_instance_segmentation:
61+
# Computes the distance transform for objects to perform end-to-end automatic instance segmentation.
62+
label_transform = PerObjectDistanceTransform(
63+
distances=True, boundary_distances=True, directed_distances=False,
64+
foreground=True, instances=True, min_size=25
65+
)
66+
else:
67+
label_transform = torch_em.transform.label.connected_components
68+
5569
loader = torch_em.default_segmentation_loader(
5670
raw_paths=image_dir, raw_key=raw_key,
5771
label_paths=segmentation_dir, label_key=label_key,
5872
patch_shape=patch_shape, batch_size=batch_size,
5973
ndim=2, is_seg_dataset=True, rois=roi,
60-
label_transform=torch_em.transform.label.connected_components,
74+
label_transform=label_transform,
6175
num_workers=8, shuffle=True, raw_transform=sam_training.identity,
6276
)
6377
return loader
6478

6579

66-
def run_training(checkpoint_name, model_type):
80+
def run_training(checkpoint_name, model_type, train_instance_segmentation):
6781
"""Run the actual model training."""
6882

6983
# All hyperparameters for training.
@@ -74,37 +88,51 @@ def run_training(checkpoint_name, model_type):
7488
n_iterations = 10000 # how long we train (in iterations)
7589

7690
# Get the dataloaders.
77-
train_loader = get_dataloader("train", patch_shape, batch_size)
78-
val_loader = get_dataloader("val", patch_shape, batch_size)
91+
train_loader = get_dataloader("train", patch_shape, batch_size, train_instance_segmentation)
92+
val_loader = get_dataloader("val", patch_shape, batch_size, train_instance_segmentation)
7993

80-
# Get the segment anything model, the optimizer and the LR scheduler
94+
# Get the segment anything model
8195
model = sam_training.get_trainable_sam_model(model_type=model_type, device=device)
82-
optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)
83-
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.9, patience=10, verbose=True)
8496

8597
# This class creates all the training data for a batch (inputs, prompts and labels).
86-
convert_inputs = sam_training.ConvertToSamInputs()
98+
convert_inputs = sam_training.ConvertToSamInputs(transform=model.transform, box_distortion_factor=0.025)
99+
100+
# Get the optimizer and the LR scheduler
101+
if train_instance_segmentation:
102+
# for instance segmentation, we use the UNETR model configuration.
103+
unetr = UNETR(
104+
backbone="sam", encoder=model.sam.image_encoder, out_channels=3, use_sam_stats=True,
105+
final_activation="Sigmoid", use_skip_connection=False, resize_input=True,
106+
)
107+
# let's get the parameters for SAM and the decoder from UNETR
108+
joint_model_params = [params for params in model.parameters()] # sam parameters
109+
for name, params in unetr.named_parameters(): # unetr's decoder parameters
110+
if not name.startswith("encoder"):
111+
joint_model_params.append(params)
112+
unetr.to(device)
113+
optimizer = torch.optim.Adam(joint_model_params, lr=1e-5)
114+
else:
115+
optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)
116+
117+
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.9, patience=10, verbose=True)
87118

88119
# the trainer which performs training and validation (implemented using "torch_em")
89-
trainer = sam_training.SamTrainer(
90-
name=checkpoint_name,
91-
train_loader=train_loader,
92-
val_loader=val_loader,
93-
model=model,
94-
optimizer=optimizer,
95-
# currently we compute loss batch-wise, else we pass channelwise True
96-
loss=torch_em.loss.DiceLoss(channelwise=False),
97-
metric=torch_em.loss.DiceLoss(),
98-
device=device,
99-
lr_scheduler=scheduler,
100-
logger=sam_training.SamLogger,
101-
log_image_interval=100,
102-
mixed_precision=True,
103-
convert_inputs=convert_inputs,
104-
n_objects_per_batch=n_objects_per_batch,
105-
n_sub_iteration=8,
106-
compile_model=False
107-
)
120+
if train_instance_segmentation:
121+
instance_seg_loss = DiceBasedDistanceLoss(mask_distances_in_bg=True)
122+
trainer = sam_training.JointSamTrainer(
123+
name=checkpoint_name, train_loader=train_loader, val_loader=val_loader, model=model,
124+
optimizer=optimizer, device=device, lr_scheduler=scheduler, logger=sam_training.JointSamLogger,
125+
log_image_interval=100, mixed_precision=True, convert_inputs=convert_inputs,
126+
n_objects_per_batch=n_objects_per_batch, n_sub_iteration=8, compile_model=False, unetr=unetr,
127+
instance_loss=instance_seg_loss, instance_metric=instance_seg_loss
128+
)
129+
else:
130+
trainer = sam_training.SamTrainer(
131+
name=checkpoint_name, train_loader=train_loader, val_loader=val_loader, model=model,
132+
optimizer=optimizer, device=device, lr_scheduler=scheduler, logger=sam_training.SamLogger,
133+
log_image_interval=100, mixed_precision=True, convert_inputs=convert_inputs,
134+
n_objects_per_batch=n_objects_per_batch, n_sub_iteration=8, compile_model=False
135+
)
108136
trainer.fit(n_iterations)
109137

110138

@@ -133,7 +161,10 @@ def main():
133161
# The name of the checkpoint. The checkpoints will be stored in './checkpoints/<checkpoint_name>'
134162
checkpoint_name = "sam_hela"
135163

136-
run_training(checkpoint_name, model_type)
164+
# Train an additional convolutional decoder for end-to-end automatic instance segmentation
165+
train_instance_segmentation = False
166+
167+
run_training(checkpoint_name, model_type, train_instance_segmentation)
137168
export_model(checkpoint_name, model_type)
138169

139170

micro_sam/sample_data.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -260,7 +260,7 @@ def fetch_tracking_example_data(save_directory: Union[str, os.PathLike]) -> str:
260260
fname = "DIC-C2DH-HeLa.zip"
261261
pooch.retrieve(
262262
url="http://data.celltrackingchallenge.net/training-datasets/DIC-C2DH-HeLa.zip", # 37 MB
263-
known_hash="fac24746fa0ad5ddf6f27044c785edef36bfa39f7917da4ad79730a7748787af",
263+
known_hash="832fed2d05bb7488cf9c51a2994b75f8f3f53b3c3098856211f2d39023c34e1a",
264264
fname=fname,
265265
path=save_directory,
266266
progressbar=True,

0 commit comments

Comments
 (0)