|
| 1 | +import os |
| 2 | +import argparse |
| 3 | +from glob import glob |
| 4 | +import json |
| 5 | +from sklearn.model_selection import train_test_split |
| 6 | +import sys |
| 7 | +sys.path.append('/user/muth9/u12095/synapse-net') |
| 8 | +from synapse_net.training.domain_adaptation import mean_teacher_adaptation |
| 9 | + |
# Root directory where the per-dataset split files (split-<dataset>.json)
# for this semi-supervised domain-adaptation run are stored.
OUTPUT_ROOT = "/mnt/lustre-emmy-hdd/usr/u12095/synapse_net/training/semisupervisedDA_cryo"
| 11 | + |
def _require_train_val_test_split(datasets, train_root, extension="mrc"):
    """Create a train/val/test split file for each dataset, if it does not exist yet.

    For every dataset a JSON file ``split-<dataset>.json`` is written to
    ``OUTPUT_ROOT``, mapping "train", "val" and "test" to lists of file names
    (80/10/10 split). Datasets that already have a split file are skipped, so
    existing splits are never overwritten.

    Args:
        datasets: Iterable of dataset (sub-directory) names under ``train_root``.
        train_root: Root directory that contains one sub-directory per dataset.
        extension: File extension of the volumes to split (without the dot).

    Raises:
        ValueError: If a dataset directory contains no files with ``extension``.
    """
    train_ratio, val_ratio, test_ratio = 0.8, 0.1, 0.1

    def _train_val_test_split(names):
        # First carve off the train partition, then divide the remainder
        # between val and test according to their relative ratios.
        train, rest = train_test_split(names, test_size=1 - train_ratio, shuffle=True)
        _ratio = test_ratio / (test_ratio + val_ratio)
        val, test = train_test_split(rest, test_size=_ratio)
        return train, val, test

    for ds in datasets:
        print(f"Processing dataset: {ds}")
        split_path = os.path.join(OUTPUT_ROOT, f"split-{ds}.json")
        if os.path.exists(split_path):
            print(f"Split file already exists: {split_path}")
            continue

        file_paths = sorted(glob(os.path.join(train_root, ds, f"*.{extension}")))
        file_names = [os.path.basename(path) for path in file_paths]
        # Fail early with a clear message instead of the opaque sklearn error
        # that train_test_split raises on an empty input (e.g. wrong extension
        # or an empty / mistyped dataset directory).
        if not file_names:
            raise ValueError(
                f"No *.{extension} files found in {os.path.join(train_root, ds)}"
            )

        train, val, test = _train_val_test_split(file_names)

        with open(split_path, "w") as f:
            json.dump({"train": train, "val": val, "test": test}, f)
| 35 | + |
def _require_train_val_split(datasets, train_root, extension="mrc"):
    """Create a train/val split file for each dataset, if it does not exist yet.

    For every dataset a JSON file ``split-<dataset>.json`` is written to
    ``OUTPUT_ROOT``, mapping "train" and "val" to lists of file names
    (80/20 split). Datasets that already have a split file are skipped, so
    existing splits are never overwritten. Note that the split file name is
    shared with ``_require_train_val_test_split``, so a given dataset name is
    only ever split once per ``OUTPUT_ROOT``.

    Args:
        datasets: Iterable of dataset (sub-directory) names under ``train_root``.
        train_root: Root directory that contains one sub-directory per dataset.
        extension: File extension of the volumes to split (without the dot).

    Raises:
        ValueError: If a dataset directory contains no files with ``extension``.
    """
    train_ratio, val_ratio = 0.8, 0.2

    def _train_val_split(names):
        train, val = train_test_split(names, test_size=1 - train_ratio, shuffle=True)
        return train, val

    for ds in datasets:
        print(f"Processing dataset: {ds}")
        split_path = os.path.join(OUTPUT_ROOT, f"split-{ds}.json")
        if os.path.exists(split_path):
            print(f"Split file already exists: {split_path}")
            continue

        file_paths = sorted(glob(os.path.join(train_root, ds, f"*.{extension}")))
        file_names = [os.path.basename(path) for path in file_paths]
        # Fail early with a clear message instead of the opaque sklearn error
        # that train_test_split raises on an empty input (e.g. wrong extension
        # or an empty / mistyped dataset directory).
        if not file_names:
            raise ValueError(
                f"No *.{extension} files found in {os.path.join(train_root, ds)}"
            )

        train, val = _train_val_split(file_names)

        with open(split_path, "w") as f:
            json.dump({"train": train, "val": val}, f)
| 57 | + |
def get_paths(split, datasets, train_root, testset=True, extension="mrc"):
    """Return the absolute file paths of one split partition for the given datasets.

    Makes sure a split file exists for every dataset (creating it when needed,
    with or without a test partition depending on ``testset``), then resolves
    the file names stored for ``split`` back to paths under ``train_root``.
    """
    if testset:
        _require_train_val_test_split(datasets, train_root, extension)
    else:
        _require_train_val_split(datasets, train_root, extension)

    all_paths = []
    for dataset in datasets:
        split_file = os.path.join(OUTPUT_ROOT, f"split-{dataset}.json")
        with open(split_file) as f:
            names_for_split = json.load(f)[split]
        ds_paths = [os.path.join(train_root, dataset, name) for name in names_for_split]
        # Catch stale split files: every listed file must still be on disk.
        assert all(os.path.exists(path) for path in ds_paths), f"Some paths do not exist in {ds_paths}"
        all_paths.extend(ds_paths)

    return all_paths
| 74 | + |
def vesicle_domain_adaptation(teacher_model, testset=True):
    """Run mean-teacher domain adaptation for vesicle segmentation on cryo data.

    Args:
        teacher_model: Checkpoint name of the pretrained teacher model
            (resolved relative to the fixed model root).
        testset: Whether a test partition is held out when split files are created.
    """
    # Adjustable parameters.
    patch_shape = [48, 256, 256]
    model_name = "vesicle-semisupervisedDA-cryo-v1"

    model_root = "/mnt/lustre-emmy-hdd/usr/u12095/synaptic_reconstruction/models_v2/checkpoints/"
    checkpoint_path = os.path.join(model_root, teacher_model)

    # Unlabeled cryo-ET volumes driving the unsupervised (teacher) branch.
    unsupervised_train_root = "/mnt/lustre-emmy-hdd/usr/u12095/cryo-et"
    unsupervised_datasets = ["from_portal"]
    unsupervised_train_paths = get_paths(
        "train", datasets=unsupervised_datasets, train_root=unsupervised_train_root, testset=testset
    )
    unsupervised_val_paths = get_paths(
        "val", datasets=unsupervised_datasets, train_root=unsupervised_train_root, testset=testset
    )

    # Labeled volumes (HDF5) driving the supervised branch.
    supervised_train_root = "/mnt/lustre-emmy-hdd/projects/nim00007/data/cryoVesNet"
    supervised_datasets = ["exported"]
    supervised_train_paths = get_paths(
        "train", datasets=supervised_datasets, train_root=supervised_train_root, testset=testset, extension="h5"
    )
    supervised_val_paths = get_paths(
        "val", datasets=supervised_datasets, train_root=supervised_train_root, testset=testset, extension="h5"
    )

    mean_teacher_adaptation(
        name=model_name,
        unsupervised_train_paths=unsupervised_train_paths,
        unsupervised_val_paths=unsupervised_val_paths,
        raw_key="data",
        supervised_train_paths=supervised_train_paths,
        supervised_val_paths=supervised_val_paths,
        raw_key_supervised="raw",
        label_key="/labels/vesicles",
        patch_shape=patch_shape,
        save_root="/mnt/lustre-emmy-hdd/usr/u12095/synapse_net/models/DA",
        source_checkpoint=checkpoint_path,
        confidence_threshold=0.75,
        n_iterations=int(1e3),
    )
| 108 | + |
def main():
    """Parse command-line arguments and launch vesicle domain adaptation."""
    parser = argparse.ArgumentParser()
    parser.add_argument("-m", "--teacher_model", required=True, help="Name of teacher model")
    # store_false means testset defaults to True and *passing* -t/--testset
    # disables the test partition — the help text must make that explicit,
    # since a flag named "--testset" naturally reads as enabling one.
    parser.add_argument(
        "-t", "--testset", action="store_false",
        help="Pass this flag to skip creating a test split (by default a test split is created)",
    )
    args = parser.parse_args()

    vesicle_domain_adaptation(args.teacher_model, args.testset)


if __name__ == "__main__":
    main()
0 commit comments