|
6 | 6 |
|
7 | 7 | import torch.utils.data as data_util |
8 | 8 | from torch_em.data.datasets import get_lizard_dataset |
| 9 | +from torch_em.data.sampler import MinInstanceSampler |
9 | 10 | from micro_sam.util import export_custom_sam_model |
10 | 11 |
|
11 | 12 |
|
12 | 13 | # TODO: support datasets other than Lizard |
13 | 14 | def get_dataloaders(patch_shape, data_path): |
14 | 15 | label_transform = torch_em.transform.label.label_consecutive # to ensure consecutive IDs |
15 | | - dataset = get_lizard_dataset(path=data_path, patch_shape=patch_shape, label_transform=label_transform) |
| 16 | + sampler = MinInstanceSampler(min_num_instances=5) |
| 17 | + dataset = get_lizard_dataset( |
| 18 | + path=data_path, download=True, patch_shape=patch_shape, label_transform=label_transform, |
| 19 | + sampler=sampler, |
| 20 | + ) |
16 | 21 | train_ds, val_ds = data_util.random_split(dataset, [0.9, 0.1]) |
17 | | - train_loader = torch_em.get_data_loader(train_ds, batch_size=2) |
| 22 | + train_loader = torch_em.get_data_loader(train_ds, batch_size=1) |
18 | 23 | val_loader = torch_em.get_data_loader(val_ds, batch_size=1) |
19 | 24 | return train_loader, val_loader |
20 | 25 |
|
21 | 26 |
|
22 | | -def finetune_histopatho(input_path, export_path, model_type="vit_h", iterations=int(2.5e4), save_root=None): |
| 27 | +def finetune_histopatho(input_path, export_path, model_type="vit_h", iterations=int(2e4), save_root=None): |
23 | 28 | """Example code for finetuning SAM on histopathology data (the Lizard dataset)""" |
24 | 29 |
|
25 | 30 | # training settings: |
26 | 31 | checkpoint_path = None # override this to start training from a custom checkpoint |
27 | 32 | device = "cuda" # override this if you have some more complex set-up and need to specify the exact gpu |
28 | 33 | patch_shape = (512, 512) # the patch shape for training |
29 | | - n_objects_per_batch = 25 # this is the number of objects per batch that will be sampled |
| 34 | + n_objects_per_batch = 50 # this is the number of objects per batch that will be sampled |
| 35 | + |
| 36 | + train_loader, val_loader = get_dataloaders(patch_shape=patch_shape, data_path=input_path) |
30 | 37 |
|
31 | 38 | # get the trainable segment anything model |
32 | 39 | model = sam_training.get_trainable_sam_model(model_type, checkpoint_path, device=device) |
33 | 40 |
|
34 | 41 | # all the stuff we need for training |
35 | 42 | optimizer = torch.optim.Adam(model.parameters(), lr=1e-5) |
36 | 43 | scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.9, patience=10, verbose=True) |
37 | | - train_loader, val_loader = get_dataloaders(patch_shape=patch_shape, data_path=input_path) |
38 | 44 |
|
39 | 45 | # this class creates all the training data for a batch (inputs, prompts and labels) |
40 | 46 | convert_inputs = sam_training.ConvertToSamInputs() |
@@ -74,7 +80,7 @@ def finetune_histopatho(input_path, export_path, model_type="vit_h", iterations= |
74 | 80 |
|
75 | 81 |
|
76 | 82 | def main(): |
77 | | - input_path = "" |
| 83 | + input_path = "/scratch-grete/projects/nim00007/data/lizard" |
78 | 84 | export_path = "./sam-vith-histopatho-v1.pth" |
79 | 85 | finetune_histopatho(input_path, export_path) |
80 | 86 |
|
|
0 commit comments