|
| 1 | +import os |
| 2 | +from glob import glob |
| 3 | +from natsort import natsorted |
| 4 | + |
| 5 | +import torch |
| 6 | + |
| 7 | +import micro_sam.training as sam_training |
| 8 | +from micro_sam.util import export_custom_sam_model |
| 9 | + |
| 10 | + |
| 11 | +def train_embl_alm_data(checkpoint_name): |
| 12 | + """Training a MicroSAM model for https://github.com/computational-cell-analytics/micro-sam/issues/1084. |
| 13 | + """ |
| 14 | + # All hyperparameters for training. |
| 15 | + batch_size = 1 |
| 16 | + patch_shape = (512, 512) |
| 17 | + n_objects_per_batch = 25 |
| 18 | + device = torch.device("cuda") |
| 19 | + |
| 20 | + # Get the filepaths to images and corresponding labels. |
| 21 | + image_paths = natsorted(glob(os.path.join(os.getcwd(), "data_same_size", "*.tif"))) |
| 22 | + label_paths = natsorted(glob(os.path.join(os.getcwd(), "masks_same_size", "*.tif"))) |
| 23 | + |
| 24 | + # Next, prepare the dataloaders. |
| 25 | + kwargs = { |
| 26 | + "batch_size": batch_size, |
| 27 | + "patch_shape": patch_shape, |
| 28 | + "with_segmentation_decoder": True, |
| 29 | + "num_workers": 16, |
| 30 | + "shuffle": True, |
| 31 | + } |
| 32 | + |
| 33 | + train_loader = sam_training.default_sam_loader( |
| 34 | + raw_paths=image_paths[:-5], raw_key=None, label_paths=label_paths[:-5], label_key=None, **kwargs, |
| 35 | + ) |
| 36 | + val_loader = sam_training.default_sam_loader( |
| 37 | + raw_paths=image_paths[-5:], raw_key=None, label_paths=label_paths[-5:], label_key=None, **kwargs, |
| 38 | + ) |
| 39 | + |
| 40 | + # Run training. |
| 41 | + sam_training.train_sam( |
| 42 | + name=checkpoint_name, |
| 43 | + model_type="vit_b_lm", |
| 44 | + train_loader=train_loader, |
| 45 | + val_loader=val_loader, |
| 46 | + n_epochs=10, |
| 47 | + n_objects_per_batch=n_objects_per_batch, |
| 48 | + with_segmentation_decoder=True, |
| 49 | + device=device, |
| 50 | + ) |
| 51 | + |
| 52 | + |
| 53 | +def main(): |
| 54 | + checkpoint_name = "sam_embl_alm_fluo" # Name of the checkpoint, stored at "./checkpoints/<CHECKPOINT_NAME>" |
| 55 | + |
| 56 | + train_embl_alm_data(checkpoint_name) |
| 57 | + |
| 58 | + # Export the trained model. |
| 59 | + export_custom_sam_model( |
| 60 | + checkpoint_path=os.path.join("checkpoints", checkpoint_name, "best.pt"), |
| 61 | + model_type="vit_b", |
| 62 | + save_path="./finetuned_embl_alm_fluo_model.pth", |
| 63 | + ) |
| 64 | + |
| 65 | + |
| 66 | +if __name__ == "__main__": |
| 67 | + main() |
0 commit comments