import argparse
from datetime import datetime

import numpy as np
import torch_em
from micro_sam.training import default_sam_loader, train_sam
from train_distance_unet import get_image_and_label_paths, select_paths

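# Default data root: the annotated training crops on the compute cluster.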
ROOT_CLUSTER = "/scratch-grete/usr/nimcpape/data/moser/lightsheet/training"


def raw_transform(x):
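    """Normalize the raw intensity values to the range SAM expects.

    The input is normalized by its 1st and 99th intensity percentiles,
    clipped to [0, 1], and then rescaled to [0, 255], matching the value
    range of the natural images SAM was trained on.
    """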
    x = x.astype("float32")
    min_, max_ = np.percentile(x, 1), np.percentile(x, 99)
    x -= min_
    # Divide by the percentile range (not just the upper percentile),
    # so that values are mapped to [0, 1] before clipping.
    x /= (max_ - min_)
    x = np.clip(x, 0, 1)
    return x * 255


def main():
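    # Parse the command-line arguments: data root, optional model name,
    # and the number of objects per batch.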
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--root", "-i", help="The root folder with the annotated training crops.",
        default=ROOT_CLUSTER,
    )
    parser.add_argument(
        "--name", help="Optional name for the model to be trained. If not given, the current date is used."
    )
    parser.add_argument(
        "--n_objects_per_batch", "-n", type=int, default=15,
        help="The number of objects per batch to use during training. Set it to a lower value if you run out of"
        " GPU memory. The default value is 15."
    )
    args = parser.parse_args()

    root = args.root
    run_name = datetime.now().strftime("%Y%m%d") if args.name is None else args.name
    name = f"cochlea_sam_{run_name}"
    n_objects_per_batch = args.n_objects_per_batch

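    # Collect the annotated crops and split them into train / val,
    # re-using the helper functions from the distance U-Net training script.
    # Crops with empty labels are filtered out.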
    image_paths, label_paths = get_image_and_label_paths(root)
    train_image_paths, train_label_paths = select_paths(image_paths, label_paths, split="train", filter_empty=True)
    val_image_paths, val_label_paths = select_paths(image_paths, label_paths, split="val", filter_empty=True)

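    # Train on individual 2D slices sampled from the 3D crops. The sampler
    # only accepts patches that contain at least 2 instances with a minimal
    # size of 10 pixels, retrying up to max_sampling_attempts times.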
    patch_shape = (1, 256, 256)
    sampler = torch_em.data.sampler.MinInstanceSampler(2, min_size=10)
    max_sampling_attempts = 2500

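    # The loaders yield the normalized image patches together with the
    # distance-based targets for the additional instance segmentation decoder
    # (enabled via with_segmentation_decoder=True).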
    train_loader = default_sam_loader(
        raw_paths=train_image_paths, raw_key=None, label_paths=train_label_paths, label_key=None,
        patch_shape=patch_shape, with_segmentation_decoder=True,
        raw_transform=raw_transform, sampler=sampler, min_size=10,
        num_workers=6, batch_size=1, is_train=True,
        max_sampling_attempts=max_sampling_attempts,
    )
    val_loader = default_sam_loader(
        raw_paths=val_image_paths, raw_key=None, label_paths=val_label_paths, label_key=None,
        patch_shape=patch_shape, with_segmentation_decoder=True,
        raw_transform=raw_transform, sampler=sampler, min_size=10,
        num_workers=6, batch_size=1, is_train=False,
        max_sampling_attempts=max_sampling_attempts,
    )

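    # Fine-tune 'vit_b_lm', the micro_sam ViT-B model that was already
    # fine-tuned on light microscopy data, together with the segmentation decoder.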
    train_sam(
        name=name, model_type="vit_b_lm", train_loader=train_loader, val_loader=val_loader,
        n_epochs=50, n_objects_per_batch=n_objects_per_batch,
    )


if __name__ == "__main__":
    main()
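
# Example invocation (the script filename here is an assumption):
#   python train_sam.py --root /path/to/training/crops -n 10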