@@ -1,17 +1,22 @@
 import os
-
 import numpy as np
+
 import torch
+
 import torch_em
+from torch_em.model import UNETR
+from torch_em.loss import DiceBasedDistanceLoss
+from torch_em.transform.label import PerObjectDistanceTransform
 
 import micro_sam.training as sam_training
-from micro_sam.sample_data import fetch_tracking_example_data, fetch_tracking_segmentation_data
 from micro_sam.util import export_custom_sam_model
+from micro_sam.sample_data import fetch_tracking_example_data, fetch_tracking_segmentation_data
+
 
 DATA_FOLDER = "data"
 
 
-def get_dataloader(split, patch_shape, batch_size):
+def get_dataloader(split, patch_shape, batch_size, train_instance_segmentation):
     """Return train or val data loader for finetuning SAM.

     The data loader must be a torch data loader that returns `x, y` tensors,
@@ -52,18 +57,27 @@ def get_dataloader(split, patch_shape, batch_size):
     else:
         roi = np.s_[70:, :, :]
 
+    if train_instance_segmentation:
+        # Computes the distance transform for objects to perform end-to-end automatic instance segmentation.
+        label_transform = PerObjectDistanceTransform(
+            distances=True, boundary_distances=True, directed_distances=False,
+            foreground=True, instances=True, min_size=25
+        )
+    else:
+        label_transform = torch_em.transform.label.connected_components
+
     loader = torch_em.default_segmentation_loader(
         raw_paths=image_dir, raw_key=raw_key,
         label_paths=segmentation_dir, label_key=label_key,
         patch_shape=patch_shape, batch_size=batch_size,
         ndim=2, is_seg_dataset=True, rois=roi,
-        label_transform=torch_em.transform.label.connected_components,
+        label_transform=label_transform,
         num_workers=8, shuffle=True, raw_transform=sam_training.identity,
     )
     return loader
 
 
-def run_training(checkpoint_name, model_type):
+def run_training(checkpoint_name, model_type, train_instance_segmentation):
     """Run the actual model training."""
 
     # All hyperparameters for training.
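
Note: a quick way to check the `x, y` contract described in the docstring of get_dataloader is to pull a single batch from the loader. This is a sketch, not part of the commit; it assumes the example data has already been fetched into DATA_FOLDER (see main() below) and guesses a patch shape of (1, 512, 512) for the value defined in the elided hyperparameter block:

# Sketch only: inspect one batch from the finetuning loader.
loader = get_dataloader("train", patch_shape=(1, 512, 512), batch_size=1, train_instance_segmentation=True)
x, y = next(iter(loader))
print(x.shape, y.shape)  # raw image batch and the stacked foreground/distance label channels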
@@ -74,37 +88,51 @@ def run_training(checkpoint_name, model_type):
     n_iterations = 10000  # how long we train (in iterations)
 
     # Get the dataloaders.
-    train_loader = get_dataloader("train", patch_shape, batch_size)
-    val_loader = get_dataloader("val", patch_shape, batch_size)
+    train_loader = get_dataloader("train", patch_shape, batch_size, train_instance_segmentation)
+    val_loader = get_dataloader("val", patch_shape, batch_size, train_instance_segmentation)
 
-    # Get the segment anything model, the optimizer and the LR scheduler
+    # Get the segment anything model
     model = sam_training.get_trainable_sam_model(model_type=model_type, device=device)
-    optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)
-    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.9, patience=10, verbose=True)
 
     # This class creates all the training data for a batch (inputs, prompts and labels).
-    convert_inputs = sam_training.ConvertToSamInputs()
+    convert_inputs = sam_training.ConvertToSamInputs(transform=model.transform, box_distortion_factor=0.025)
+
+    # Get the optimizer and the LR scheduler
+    if train_instance_segmentation:
+        # For instance segmentation, we use the UNETR model configuration.
+        unetr = UNETR(
+            backbone="sam", encoder=model.sam.image_encoder, out_channels=3, use_sam_stats=True,
+            final_activation="Sigmoid", use_skip_connection=False, resize_input=True,
+        )
+        # Let's get the parameters for SAM and the decoder from UNETR.
+        joint_model_params = [params for params in model.parameters()]  # sam parameters
+        for name, params in unetr.named_parameters():  # unetr's decoder parameters
+            if not name.startswith("encoder"):
+                joint_model_params.append(params)
+        unetr.to(device)
+        optimizer = torch.optim.Adam(joint_model_params, lr=1e-5)
+    else:
+        optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)
+
+    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.9, patience=10, verbose=True)
 
     # the trainer which performs training and validation (implemented using "torch_em")
-    trainer = sam_training.SamTrainer(
-        name=checkpoint_name,
-        train_loader=train_loader,
-        val_loader=val_loader,
-        model=model,
-        optimizer=optimizer,
-        # currently we compute loss batch-wise, else we pass channelwise True
-        loss=torch_em.loss.DiceLoss(channelwise=False),
-        metric=torch_em.loss.DiceLoss(),
-        device=device,
-        lr_scheduler=scheduler,
-        logger=sam_training.SamLogger,
-        log_image_interval=100,
-        mixed_precision=True,
-        convert_inputs=convert_inputs,
-        n_objects_per_batch=n_objects_per_batch,
-        n_sub_iteration=8,
-        compile_model=False
-    )
+    if train_instance_segmentation:
+        instance_seg_loss = DiceBasedDistanceLoss(mask_distances_in_bg=True)
+        trainer = sam_training.JointSamTrainer(
+            name=checkpoint_name, train_loader=train_loader, val_loader=val_loader, model=model,
+            optimizer=optimizer, device=device, lr_scheduler=scheduler, logger=sam_training.JointSamLogger,
+            log_image_interval=100, mixed_precision=True, convert_inputs=convert_inputs,
+            n_objects_per_batch=n_objects_per_batch, n_sub_iteration=8, compile_model=False, unetr=unetr,
+            instance_loss=instance_seg_loss, instance_metric=instance_seg_loss
+        )
+    else:
+        trainer = sam_training.SamTrainer(
+            name=checkpoint_name, train_loader=train_loader, val_loader=val_loader, model=model,
+            optimizer=optimizer, device=device, lr_scheduler=scheduler, logger=sam_training.SamLogger,
+            log_image_interval=100, mixed_precision=True, convert_inputs=convert_inputs,
+            n_objects_per_batch=n_objects_per_batch, n_sub_iteration=8, compile_model=False
+        )
     trainer.fit(n_iterations)
 
 
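Note: the UNETR above reuses the SAM image encoder, so the joint parameter list adds only the decoder weights on top of the SAM parameters. A small sketch, using only names already defined in run_training, to confirm what the joint optimizer updates:

# Sketch only: count the parameters seen by the joint optimizer.
n_sam = sum(p.numel() for p in model.parameters())
n_decoder = sum(p.numel() for n, p in unetr.named_parameters() if not n.startswith("encoder"))
print(f"sam: {n_sam}, decoder: {n_decoder}, joint total: {n_sam + n_decoder}")
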
@@ -133,7 +161,10 @@ def main():
     # The name of the checkpoint. The checkpoints will be stored in './checkpoints/<checkpoint_name>'
     checkpoint_name = "sam_hela"
 
-    run_training(checkpoint_name, model_type)
+    # Train an additional convolutional decoder for end-to-end automatic instance segmentation
+    train_instance_segmentation = False
+
+    run_training(checkpoint_name, model_type, train_instance_segmentation)
     export_model(checkpoint_name, model_type)
 
 
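Note: export_model (elided here) wraps export_custom_sam_model, which converts the finetuned checkpoint back into the original SAM format. A hedged usage sketch for loading the exported weights; it assumes micro_sam.util.get_sam_model accepts a checkpoint_path argument, and the save path and model_type below are placeholders that must match whatever export_model actually used:

# Sketch only: path and model_type are assumptions, adjust to the actual export.
from micro_sam.util import get_sam_model
predictor = get_sam_model(model_type="vit_b", checkpoint_path="finetuned_hela_model.pth")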