import argparse
from datetime import datetime

import numpy as np
from micro_sam.training import default_sam_loader, train_sam
from train_distance_unet import get_image_and_label_paths, select_paths

ROOT_CLUSTER = "/scratch-grete/usr/nimcpape/data/moser/lightsheet/training"


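# Normalize the raw intensities with a 1st/99th percentile stretch and rescale
# to [0, 255], matching the 8-bit value range that SAM expects as input.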
def raw_transform(x):
    x = x.astype("float32")
    min_, max_ = np.percentile(x, 1), np.percentile(x, 99)
    # Shift by the lower percentile, scale by the percentile range and clip,
    # so that outliers beyond the percentiles are saturated.
    x -= min_
    x /= (max_ - min_)
    x = np.clip(x, 0, 1)
    return x * 255


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--root", "-i", help="The root folder with the annotated training crops.",
        default=ROOT_CLUSTER,
    )
    parser.add_argument(
        "--name", help="Optional name for the model to be trained. If not given, the current date is used."
    )
    parser.add_argument(
        "--n_objects_per_batch", "-n", type=int, default=15,
        help="The number of objects to use during training. Set it to a lower value if you run out of GPU memory. "
        "The default value is 15."
    )
    args = parser.parse_args()

    root = args.root
    run_name = datetime.now().strftime("%Y%m%d") if args.name is None else args.name
    name = f"cochlea_micro_sam_{run_name}"
    n_objects_per_batch = args.n_objects_per_batch

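    # Collect the image and label paths and split them into train and val sets;
    # filter_empty=True drops crops without any annotated objects.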
    image_paths, label_paths = get_image_and_label_paths(root)
    train_image_paths, train_label_paths = select_paths(image_paths, label_paths, split="train", filter_empty=True)
    val_image_paths, val_label_paths = select_paths(image_paths, label_paths, split="val", filter_empty=True)

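    # Train on 2D patches: the leading 1 in the patch shape samples single
    # z-slices of 256 x 256 pixels from the volumetric crops.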
    patch_shape = (1, 256, 256)
    max_sampling_attempts = 2500

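    # with_segmentation_decoder=True also creates the targets for micro_sam's
    # additional instance segmentation decoder, which is trained jointly with
    # the prompt-based SAM segmentation.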
    train_loader = default_sam_loader(
        raw_paths=train_image_paths, raw_key=None, label_paths=train_label_paths, label_key=None,
        patch_shape=patch_shape, with_segmentation_decoder=True,
        raw_transform=raw_transform,
        num_workers=6, batch_size=1, is_train=True,
        max_sampling_attempts=max_sampling_attempts,
    )
    val_loader = default_sam_loader(
        raw_paths=val_image_paths, raw_key=None, label_paths=val_label_paths, label_key=None,
        patch_shape=patch_shape, with_segmentation_decoder=True,
        raw_transform=raw_transform,
        num_workers=6, batch_size=1, is_train=False,
        max_sampling_attempts=max_sampling_attempts,
    )

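    # Fine-tune vit_b_lm, the ViT-B model with micro_sam's light microscopy
    # generalist weights; checkpoints are saved under save_root.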
    train_sam(
        name=name, model_type="vit_b_lm", train_loader=train_loader, val_loader=val_loader,
        n_epochs=50, n_objects_per_batch=n_objects_per_batch,
        save_root=".",
    )


if __name__ == "__main__":
    main()