Skip to content

Commit 476e1fe

Browse files
committed
Domain adaptations for SGN segmentation in gerbil
1 parent c3374f7 commit 476e1fe

File tree

2 files changed

+116
-0
lines changed

2 files changed

+116
-0
lines changed
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
import os
2+
from glob import glob
3+
4+
import torch
5+
from torch_em.util import load_model
6+
from flamingo_tools.training import mean_teacher_training
7+
8+
9+
def get_paths():
10+
root = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/training_data/SGN"
11+
folders = ["2025-06-SGN_domain_gerbil_CR"]
12+
train_paths = []
13+
val_paths = []
14+
for folder in folders:
15+
train_paths.extend(glob(os.path.join(root, folder, "train", "*.tif")))
16+
val_paths.extend(glob(os.path.join(root, folder, "val", "*.tif")))
17+
return train_paths, val_paths
18+
19+
20+
def run_training(name):
21+
patch_shape = (64, 128, 128)
22+
batch_size = 8
23+
source_checkpoint = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/trained_models/SGN/v2_cochlea_distance_unet_SGN_supervised_2025-05-27"
24+
25+
train_paths, val_paths = get_paths()
26+
mean_teacher_training(
27+
name=name,
28+
unsupervised_train_paths=train_paths,
29+
unsupervised_val_paths=val_paths,
30+
patch_shape=patch_shape,
31+
source_checkpoint=source_checkpoint,
32+
batch_size=batch_size,
33+
n_iterations=int(2.5e4),
34+
n_samples_train=1000,
35+
n_samples_val=80,
36+
)
37+
38+
39+
def export_model(name, export_path):
40+
model = load_model(os.path.join("checkpoints", name), state_key="teacher")
41+
torch.save(model, export_path)
42+
43+
44+
def main():
45+
import argparse
46+
47+
parser = argparse.ArgumentParser()
48+
parser.add_argument("--export_path")
49+
args = parser.parse_args()
50+
name = "sgn-adapted-model_gerbil_CR"
51+
if args.export_path is None:
52+
run_training(name)
53+
else:
54+
export_model(name, args.export_path)
55+
56+
57+
if __name__ == "__main__":
58+
main()
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
import os
2+
from glob import glob
3+
4+
import torch
5+
from torch_em.util import load_model
6+
from flamingo_tools.training import mean_teacher_training
7+
8+
9+
def get_paths():
10+
root = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/training_data/SGN"
11+
folders = ["2025-06-SGN_domain_gerbil_PV"]
12+
train_paths = []
13+
val_paths = []
14+
for folder in folders:
15+
train_paths.extend(glob(os.path.join(root, folder, "train", "*.tif")))
16+
val_paths.extend(glob(os.path.join(root, folder, "val", "*.tif")))
17+
return train_paths, val_paths
18+
19+
20+
def run_training(name):
21+
patch_shape = (64, 128, 128)
22+
batch_size = 8
23+
source_checkpoint = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/trained_models/SGN/v2_cochlea_distance_unet_SGN_supervised_2025-05-27"
24+
25+
train_paths, val_paths = get_paths()
26+
mean_teacher_training(
27+
name=name,
28+
unsupervised_train_paths=train_paths,
29+
unsupervised_val_paths=val_paths,
30+
patch_shape=patch_shape,
31+
source_checkpoint=source_checkpoint,
32+
batch_size=batch_size,
33+
n_iterations=int(2.5e4),
34+
n_samples_train=1000,
35+
n_samples_val=80,
36+
)
37+
38+
39+
def export_model(name, export_path):
40+
model = load_model(os.path.join("checkpoints", name), state_key="teacher")
41+
torch.save(model, export_path)
42+
43+
44+
def main():
45+
import argparse
46+
47+
parser = argparse.ArgumentParser()
48+
parser.add_argument("--export_path")
49+
args = parser.parse_args()
50+
name = "sgn-adapted-model_gerbil_PV"
51+
if args.export_path is None:
52+
run_training(name)
53+
else:
54+
export_model(name, args.export_path)
55+
56+
57+
if __name__ == "__main__":
58+
main()

0 commit comments

Comments
 (0)