Skip to content

Commit 48227af

Browse files
committed
Semi-supervised training for IHC
1 parent 2ad5065 commit 48227af

File tree

1 file changed

+86
-0
lines changed

1 file changed

+86
-0
lines changed
Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
import os
2+
from glob import glob
3+
4+
import torch
5+
from torch_em.util import load_model
6+
from flamingo_tools.training import mean_teacher_training
7+
8+
9+
def get_paths():
10+
root = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/training_data/IHC/2025-05-IHC_semi-supervised"
11+
annotated_folders = ["annotated_train", "empty"]
12+
train_image = []
13+
train_label = []
14+
for folder in annotated_folders:
15+
with os.scandir(os.path.join(root, folder)) as direc:
16+
for entry in direc:
17+
if "annotations" not in entry.name and entry.is_file():
18+
basename = os.path.basename(entry.name)
19+
name_no_extension = ".".join(basename.split(".")[:-1])
20+
label_name = name_no_extension + "_annotations.tif"
21+
train_image.extend(glob(os.path.join(root, folder, entry.name)))
22+
train_label.extend(glob(os.path.join(root, folder, label_name)))
23+
24+
annotated_folders = ["annotated_val"]
25+
val_image = []
26+
val_label = []
27+
for folder in annotated_folders:
28+
with os.scandir(os.path.join(root, folder)) as direc:
29+
for entry in direc:
30+
if "annotations" not in entry.name and entry.is_file():
31+
basename = os.path.basename(entry.name)
32+
name_no_extension = ".".join(basename.split(".")[:-1])
33+
label_name = name_no_extension + "_annotations.tif"
34+
val_image.extend(glob(os.path.join(root, folder, entry.name)))
35+
val_label.extend(glob(os.path.join(root, folder, label_name)))
36+
37+
domain_folders = ["domain_Aleyna", "domain_Lennart"]
38+
paths_domain = []
39+
for folder in domain_folders:
40+
paths_domain.extend(glob(os.path.join(root, folder, "*.tif")))
41+
42+
return train_image, train_label, val_image, val_label, paths_domain[:-2], paths_domain[-2:]
43+
44+
45+
def run_training(name):
46+
patch_shape = (64, 128, 128)
47+
batch_size = 8
48+
49+
super_train_img, super_train_label, super_val_img, super_val_label, unsuper_train, unsuper_val = get_paths()
50+
51+
mean_teacher_training(
52+
name=name,
53+
unsupervised_train_paths=unsuper_train,
54+
unsupervised_val_paths=unsuper_val,
55+
patch_shape=patch_shape,
56+
supervised_train_image_paths=super_train_img,
57+
supervised_val_image_paths=super_val_img,
58+
supervised_train_label_paths=super_train_label,
59+
supervised_val_label_paths=super_val_label,
60+
batch_size=batch_size,
61+
n_iterations=int(1e5),
62+
n_samples_train=1000,
63+
n_samples_val=80,
64+
)
65+
66+
67+
def export_model(name, export_path):
68+
model = load_model(os.path.join("checkpoints", name), state_key="teacher")
69+
torch.save(model, export_path)
70+
71+
72+
def main():
73+
import argparse
74+
75+
parser = argparse.ArgumentParser()
76+
parser.add_argument("--export_path")
77+
args = parser.parse_args()
78+
name = "IHC_semi-supervised_2025-05-22"
79+
if args.export_path is None:
80+
run_training(name)
81+
else:
82+
export_model(name, args.export_path)
83+
84+
85+
if __name__ == "__main__":
86+
main()

0 commit comments

Comments
 (0)