Skip to content

Commit c34d219

Browse files
Fix bug in micro_sam.util and refactor finetuning example
1 parent 6f40355 commit c34d219

File tree

2 files changed

+24
-11
lines changed

2 files changed

+24
-11
lines changed

examples/finetuning/finetune_hela.py

Lines changed: 22 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -62,12 +62,8 @@ def get_dataloader(split, patch_shape, batch_size):
6262
return loader
6363

6464

65-
def main():
66-
"""Finetune a Segment Anything model.
67-
68-
This example uses image data and segmentations from the cell tracking challenge,
69-
but can easily be adapted for other data (including data you have annoated with micro_sam beforehand).
70-
"""
65+
def run_training(checkpoint_name, model_type):
66+
"""Run the actual model training."""
7167

7268
# All hyperparameters for training.
7369
batch_size = 1 # the training batch size
@@ -76,9 +72,6 @@ def main():
7672
device = torch.device("cuda") # the device/GPU used for training
7773
n_iterations = 10000 # how long we train (in iterations)
7874

79-
model_type = "vit_b" # the name of the model which is used to initialize the weights that are finetuned
80-
# We use vit_b here because it can be trained faster. Note that vit_h usually yields higher quality results.
81-
8275
# Get the dataloaders.
8376
train_loader = get_dataloader("train", patch_shape, batch_size)
8477
val_loader = get_dataloader("val", patch_shape, batch_size)
@@ -91,7 +84,6 @@ def main():
9184
# This class creates all the training data for a batch (inputs, prompts and labels).
9285
convert_inputs = sam_training.ConvertToSamInputs()
9386

94-
checkpoint_name = "sam_hela"
9587
# the trainer which performs training and validation (implemented using "torch_em")
9688
trainer = sam_training.SamTrainer(
9789
name=checkpoint_name,
@@ -114,6 +106,9 @@ def main():
114106
)
115107
trainer.fit(n_iterations)
116108

109+
110+
def export_model(checkpoint_name, model_type):
111+
"""Export the trained model."""
117112
# export the model after training so that it can be used by the rest of the micro_sam library
118113
export_path = "./finetuned_hela_model.pth"
119114
checkpoint_path = os.path.join("checkpoints", checkpoint_name, "best.pt")
@@ -124,5 +119,22 @@ def main():
124119
)
125120

126121

122+
def main():
123+
"""Finetune a Segment Anything model.
124+
125+
This example uses image data and segmentations from the cell tracking challenge,
126+
but can easily be adapted for other data (including data you have annoated with micro_sam beforehand).
127+
"""
128+
# The model_type determines which base model is used to initialize the weights that are finetuned.
129+
# We use vit_b here because it can be trained faster. Note that vit_h usually yields higher quality results.
130+
model_type = "vit_b"
131+
132+
# The name of the checkpoint. The checkpoints will be stored in './checkpoints/<checkpoint_name>'
133+
checkpoint_name = "sam_hela"
134+
135+
run_training(checkpoint_name, model_type)
136+
export_model(checkpoint_name, model_type)
137+
138+
127139
if __name__ == "__main__":
128140
main()

micro_sam/util.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -191,7 +191,8 @@ def get_custom_sam_model(
191191
custom_pickle = pickle
192192
custom_pickle.Unpickler = _CustomUnpickler
193193

194-
device = "cuda" if torch.cuda.is_available() else "cpu"
194+
if device is None:
195+
device = "cuda" if torch.cuda.is_available() else "cpu"
195196
sam = sam_model_registry[model_type]()
196197

197198
# load the model state, ignoring any attributes that can't be found by pickle

0 commit comments

Comments
 (0)