|
2 | 2 | from datetime import datetime |
3 | 3 |
|
4 | 4 | import numpy as np |
5 | | -import torch_em |
6 | 5 | from micro_sam.training import default_sam_loader, train_sam |
7 | 6 | from train_distance_unet import get_image_and_label_paths, select_paths |
8 | 7 |
|
@@ -36,35 +35,35 @@ def main(): |
36 | 35 |
|
37 | 36 | root = args.root |
38 | 37 | run_name = datetime.now().strftime("%Y%m%d") if args.name is None else args.name |
39 | | - name = f"cochlea_distance_unet_{run_name}" |
| 38 | + name = f"cochlea_micro_sam_{run_name}" |
40 | 39 | n_objects_per_batch = args.n_objects_per_batch |
41 | 40 |
|
42 | 41 | image_paths, label_paths = get_image_and_label_paths(root) |
43 | 42 | train_image_paths, train_label_paths = select_paths(image_paths, label_paths, split="train", filter_empty=True) |
44 | 43 | val_image_paths, val_label_paths = select_paths(image_paths, label_paths, split="val", filter_empty=True) |
45 | 44 |
|
46 | 45 | patch_shape = (1, 256, 256) |
47 | | - sampler = torch_em.data.sampler.MinInstanceSampler(2, min_size=10) |
48 | 46 | max_sampling_attempts = 2500 |
49 | 47 |
|
50 | 48 | train_loader = default_sam_loader( |
51 | 49 | raw_paths=train_image_paths, raw_key=None, label_paths=train_label_paths, label_key=None, |
52 | 50 | patch_shape=patch_shape, with_segmentation_decoder=True, |
53 | | - raw_transform=raw_transform, sampler=sampler, min_size=10, |
| 51 | + raw_transform=raw_transform, |
54 | 52 | num_workers=6, batch_size=1, is_train=True, |
55 | 53 | max_sampling_attempts=max_sampling_attempts, |
56 | 54 | ) |
57 | 55 | val_loader = default_sam_loader( |
58 | 56 | raw_paths=val_image_paths, raw_key=None, label_paths=val_label_paths, label_key=None, |
59 | 57 | patch_shape=patch_shape, with_segmentation_decoder=True, |
60 | | - raw_transform=raw_transform, sampler=sampler, min_size=10, |
| 58 | + raw_transform=raw_transform, |
61 | 59 | num_workers=6, batch_size=1, is_train=False, |
62 | 60 | max_sampling_attempts=max_sampling_attempts, |
63 | 61 | ) |
64 | 62 |
|
65 | 63 | train_sam( |
66 | 64 | name=name, model_type="vit_b_lm", train_loader=train_loader, val_loader=val_loader, |
67 | 65 | n_epochs=50, n_objects_per_batch=n_objects_per_batch, |
| 66 | + save_root=".", |
68 | 67 | ) |
69 | 68 |
|
70 | 69 |
|
|
0 commit comments