Skip to content

Commit 539d2c1

Browse files
committed
Semi-supervised training for SGN
1 parent 48227af commit 539d2c1

File tree

1 file changed

+100
-0
lines changed

1 file changed

+100
-0
lines changed
Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
import os
2+
from glob import glob
3+
4+
import torch
5+
from torch_em.util import load_model
6+
from flamingo_tools.training import mean_teacher_training
7+
8+
9+
def get_paths():
10+
root = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/training_data/SGN/2025-05_semi-supervised"
11+
annotated_folders = ["annotated_2025-02", "annotated_2025-05", "empty_2025-02", "empty_2025-05"]
12+
train_image = []
13+
train_label = []
14+
for folder in annotated_folders:
15+
with os.scandir(os.path.join(root, folder)) as direc:
16+
for entry in direc:
17+
if "annotations" not in entry.name and entry.is_file():
18+
basename = os.path.basename(entry.name)
19+
name_no_extension = ".".join(basename.split(".")[:-1])
20+
label_name = name_no_extension + "_annotations.tif"
21+
train_image.extend(glob(os.path.join(root, folder, entry.name)))
22+
train_label.extend(glob(os.path.join(root, folder, label_name)))
23+
24+
annotated_folders = ["val_data"]
25+
val_image = []
26+
val_label = []
27+
for folder in annotated_folders:
28+
with os.scandir(os.path.join(root, folder)) as direc:
29+
for entry in direc:
30+
if "annotations" not in entry.name and entry.is_file():
31+
basename = os.path.basename(entry.name)
32+
name_no_extension = ".".join(basename.split(".")[:-1])
33+
label_name = name_no_extension + "_annotations.tif"
34+
val_image.extend(glob(os.path.join(root, folder, entry.name)))
35+
val_label.extend(glob(os.path.join(root, folder, label_name)))
36+
37+
domain_folders = ["domain"]
38+
paths_domain = []
39+
for folder in domain_folders:
40+
paths_domain.extend(glob(os.path.join(root, folder, "*.tif")))
41+
42+
return train_image, train_label, val_image, val_label, paths_domain[:-1], paths_domain[-1:]
43+
44+
45+
def run_training(name):
46+
patch_shape = (64, 128, 128)
47+
batch_size = 20
48+
49+
super_train_img, super_train_label, super_val_img, super_val_label, unsuper_train, unsuper_val = get_paths()
50+
51+
print("super_train", len(super_train_img))
52+
print("super_train", len(super_train_label))
53+
54+
print("super_val", len(super_val_img))
55+
print("super_val", len(super_val_label))
56+
57+
print("unsuper",len(unsuper_train))
58+
print("unsuper",len(unsuper_train))
59+
60+
mean_teacher_training(
61+
name=name,
62+
unsupervised_train_paths=unsuper_train,
63+
unsupervised_val_paths=unsuper_val,
64+
patch_shape=patch_shape,
65+
supervised_train_image_paths=super_train_img,
66+
supervised_val_image_paths=super_val_img,
67+
supervised_train_label_paths=super_train_label,
68+
supervised_val_label_paths=super_val_label,
69+
batch_size=batch_size,
70+
n_iterations=int(1e5),
71+
n_samples_train=1000,
72+
n_samples_val=80,
73+
)
74+
75+
76+
def export_model(name, export_path):
77+
model = load_model(os.path.join("checkpoints", name), state_key="teacher")
78+
torch.save(model, export_path)
79+
80+
81+
def main():
82+
import argparse
83+
84+
parser = argparse.ArgumentParser()
85+
parser.add_argument("--export_path")
86+
parser.add_argument("--model_name", default=None)
87+
args = parser.parse_args()
88+
if args.model_name is None:
89+
name = "SGN_semi-supervised"
90+
else:
91+
name = args.model_name
92+
93+
if args.export_path is None:
94+
run_training(name)
95+
else:
96+
export_model(name, args.export_path)
97+
98+
99+
if __name__ == "__main__":
100+
main()

0 commit comments

Comments
 (0)