@@ -62,12 +62,8 @@ def get_dataloader(split, patch_shape, batch_size):
     return loader
 
 
-def main():
-    """Finetune a Segment Anything model.
-
-    This example uses image data and segmentations from the cell tracking challenge,
-    but can easily be adapted for other data (including data you have annotated with micro_sam beforehand).
-    """
+def run_training(checkpoint_name, model_type):
+    """Run the actual model training."""
 
     # All hyperparameters for training.
     batch_size = 1  # the training batch size
@@ -76,9 +72,6 @@ def main():
     device = torch.device("cuda")  # the device/GPU used for training
     n_iterations = 10000  # how long we train (in iterations)
 
-    model_type = "vit_b"  # the name of the model which is used to initialize the weights that are finetuned
-    # We use vit_b here because it can be trained faster. Note that vit_h usually yields higher quality results.
-
     # Get the dataloaders.
     train_loader = get_dataloader("train", patch_shape, batch_size)
     val_loader = get_dataloader("val", patch_shape, batch_size)
@@ -91,7 +84,6 @@ def main():
9184 # This class creates all the training data for a batch (inputs, prompts and labels).
9285 convert_inputs = sam_training .ConvertToSamInputs ()
9386
94- checkpoint_name = "sam_hela"
9587 # the trainer which performs training and validation (implemented using "torch_em")
9688 trainer = sam_training .SamTrainer (
9789 name = checkpoint_name ,
@@ -114,6 +106,9 @@ def main():
     )
     trainer.fit(n_iterations)
 
+
+def export_model(checkpoint_name, model_type):
+    """Export the trained model."""
     # export the model after training so that it can be used by the rest of the micro_sam library
     export_path = "./finetuned_hela_model.pth"
     checkpoint_path = os.path.join("checkpoints", checkpoint_name, "best.pt")
@@ -124,5 +119,22 @@ def main():
     )
 
 
+def main():
+    """Finetune a Segment Anything model.
+
+    This example uses image data and segmentations from the cell tracking challenge,
+    but can easily be adapted for other data (including data you have annotated with micro_sam beforehand).
+    """
+    # The model_type determines which base model is used to initialize the weights that are finetuned.
+    # We use vit_b here because it can be trained faster. Note that vit_h usually yields higher quality results.
+    model_type = "vit_b"
+
+    # The name of the checkpoint. The checkpoints will be stored in './checkpoints/<checkpoint_name>'.
+    checkpoint_name = "sam_hela"
+
+    run_training(checkpoint_name, model_type)
+    export_model(checkpoint_name, model_type)
+
+
 if __name__ == "__main__":
     main()
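
For reference, a minimal sketch of how the checkpoint exported by this change could be used afterwards. It assumes micro_sam's get_sam_model helper (which returns a segment_anything SamPredictor) accepts a checkpoint_path; the input image and the point prompt below are placeholders, not part of this commit.

import numpy as np
from micro_sam.util import get_sam_model

# Load a predictor from the checkpoint written by export_model.
# model_type must match the one used for finetuning ("vit_b" above).
predictor = get_sam_model(
    model_type="vit_b",
    checkpoint_path="./finetuned_hela_model.pth",
)

# Placeholder input: replace with a real microscopy image (HxWx3, uint8),
# loaded e.g. via imageio or tifffile.
image = np.random.randint(0, 256, (512, 512, 3), dtype="uint8")

# Standard segment_anything prompting: one positive point prompt at (x, y).
predictor.set_image(image)
masks, scores, _ = predictor.predict(
    point_coords=np.array([[256, 256]]),
    point_labels=np.array([1]),
    multimask_output=False,
)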