
Commit 39a8493

1st implementation of semisupervised DA
1 parent 9221c9e commit 39a8493

1 file changed: 118 additions (+), 0 deletions (−)
import os
import argparse
import json
import sys
from glob import glob

from sklearn.model_selection import train_test_split

sys.path.append('/user/muth9/u12095/synapse-net')
from synapse_net.training.domain_adaptation import mean_teacher_adaptation

OUTPUT_ROOT = "/mnt/lustre-emmy-hdd/usr/u12095/synapse_net/training/semisupervisedDA_cryo"


def _require_train_val_test_split(datasets, train_root, extension="mrc"):
    """Create a train/val/test split file for each dataset, unless it already exists."""
    train_ratio, val_ratio, test_ratio = 0.8, 0.1, 0.1

    def _train_val_test_split(names):
        # First split off 20% of the names, then divide that part evenly into val and test (10% each).
        train, test = train_test_split(names, test_size=1 - train_ratio, shuffle=True)
        _ratio = test_ratio / (test_ratio + val_ratio)
        val, test = train_test_split(test, test_size=_ratio)
        return train, val, test

    for ds in datasets:
        print(f"Processing dataset: {ds}")
        split_path = os.path.join(OUTPUT_ROOT, f"split-{ds}.json")
        if os.path.exists(split_path):
            print(f"Split file already exists: {split_path}")
            continue

        file_paths = sorted(glob(os.path.join(train_root, ds, f"*.{extension}")))
        file_names = [os.path.basename(path) for path in file_paths]

        train, val, test = _train_val_test_split(file_names)

        # The split is stored as {"train": [...], "val": [...], "test": [...]}.
        with open(split_path, "w") as f:
            json.dump({"train": train, "val": val, "test": test}, f)


def _require_train_val_split(datasets, train_root, extension="mrc"):
    """Create a train/val split file for each dataset, unless it already exists."""
    train_ratio, val_ratio = 0.8, 0.2

    def _train_val_split(names):
        train, val = train_test_split(names, test_size=1 - train_ratio, shuffle=True)
        return train, val

    for ds in datasets:
        print(f"Processing dataset: {ds}")
        split_path = os.path.join(OUTPUT_ROOT, f"split-{ds}.json")
        if os.path.exists(split_path):
            print(f"Split file already exists: {split_path}")
            continue

        file_paths = sorted(glob(os.path.join(train_root, ds, f"*.{extension}")))
        file_names = [os.path.basename(path) for path in file_paths]

        train, val = _train_val_split(file_names)

        with open(split_path, "w") as f:
            json.dump({"train": train, "val": val}, f)


def get_paths(split, datasets, train_root, testset=True, extension="mrc"):
    """Return the file paths belonging to the given split ("train", "val" or "test") for the given datasets."""
    if testset:
        _require_train_val_test_split(datasets, train_root, extension)
    else:
        _require_train_val_split(datasets, train_root, extension)

    paths = []
    for ds in datasets:
        split_path = os.path.join(OUTPUT_ROOT, f"split-{ds}.json")
        with open(split_path) as f:
            names = json.load(f)[split]
        ds_paths = [os.path.join(train_root, ds, name) for name in names]
        assert all(os.path.exists(path) for path in ds_paths), f"Some paths do not exist in {ds_paths}"
        paths.extend(ds_paths)

    return paths


def vesicle_domain_adaptation(teacher_model, testset=True):
    """Run semi-supervised domain adaptation for vesicle segmentation via mean teacher training."""
    # Adjustable parameters.
    patch_shape = [48, 256, 256]
    model_name = "vesicle-semisupervisedDA-cryo-v1"
    model_root = "/mnt/lustre-emmy-hdd/usr/u12095/synaptic_reconstruction/models_v2/checkpoints/"
    checkpoint_path = os.path.join(model_root, teacher_model)

    unsupervised_train_root = "/mnt/lustre-emmy-hdd/usr/u12095/cryo-et"
    supervised_train_root = "/mnt/lustre-emmy-hdd/projects/nim00007/data/cryoVesNet"

    # Unlabeled cryo-ET tomograms (mrc) for the unsupervised part of the training.
    unsupervised_datasets = ["from_portal"]
    unsupervised_train_paths = get_paths(
        "train", datasets=unsupervised_datasets, train_root=unsupervised_train_root, testset=testset
    )
    unsupervised_val_paths = get_paths(
        "val", datasets=unsupervised_datasets, train_root=unsupervised_train_root, testset=testset
    )

    # Annotated data (h5 files with vesicle labels) for the supervised part of the training.
    supervised_datasets = ["exported"]
    supervised_train_paths = get_paths(
        "train", datasets=supervised_datasets, train_root=supervised_train_root, testset=testset, extension="h5"
    )
    supervised_val_paths = get_paths(
        "val", datasets=supervised_datasets, train_root=supervised_train_root, testset=testset, extension="h5"
    )

    mean_teacher_adaptation(
        name=model_name,
        unsupervised_train_paths=unsupervised_train_paths,
        unsupervised_val_paths=unsupervised_val_paths,
        raw_key="data",
        supervised_train_paths=supervised_train_paths,
        supervised_val_paths=supervised_val_paths,
        raw_key_supervised="raw",
        label_key="/labels/vesicles",
        patch_shape=patch_shape,
        save_root="/mnt/lustre-emmy-hdd/usr/u12095/synapse_net/models/DA",
        source_checkpoint=checkpoint_path,
        confidence_threshold=0.75,
        n_iterations=int(1e3),
    )


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("-m", "--teacher_model", required=True, help="Name of the teacher model checkpoint.")
    parser.add_argument(
        "-t", "--testset", action="store_false",
        help="Pass this flag to skip creating a test split (a test split is created by default)."
    )
    args = parser.parse_args()

    vesicle_domain_adaptation(args.teacher_model, args.testset)


if __name__ == "__main__":
    main()
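
A usage sketch: the script's file name is not shown on this page, so `semisupervised_da_cryo.py` below is a placeholder, as is the teacher model name `vesicle-teacher`; the value passed via `-m`/`--teacher_model` is joined onto `model_root`, so it should name a checkpoint under that directory.

    python semisupervised_da_cryo.py -m vesicle-teacher       # create train/val/test splits, then train
    python semisupervised_da_cryo.py -m vesicle-teacher -t    # -t (store_false) skips the test split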
