Skip to content

Commit 1481244

Browse files
Add domain adaptation script
1 parent aee6121 commit 1481244

File tree

1 file changed

+56
-0
lines changed

1 file changed

+56
-0
lines changed
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
import os
2+
from glob import glob
3+
4+
import torch
5+
from torch_em.util import load_model
6+
from flamingo_tools.training.domain_adaptation import mean_teacher_adaptation
7+
8+
9+
def get_paths():
10+
root = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/LS_sampleprepcomparison_crops"
11+
folders = ["fHC", "iDISCO", "microwave-fHC", "microwave-iDISCO"]
12+
paths = []
13+
for folder in folders:
14+
paths.extend(glob(os.path.join(root, folder, "*.tif")))
15+
return paths[:-1], paths[-1:]
16+
17+
18+
def run_training(name):
19+
patch_shape = (64, 128, 128)
20+
batch_size = 8
21+
source_checkpoint = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/trained_models/SGN/cochlea_distance_unet_SGN_March2025Model" # noqa
22+
23+
train_paths, val_paths = get_paths()
24+
mean_teacher_adaptation(
25+
name=name,
26+
unsupervised_train_paths=train_paths,
27+
unsupervised_val_paths=val_paths,
28+
patch_shape=patch_shape,
29+
source_checkpoint=source_checkpoint,
30+
batch_size=batch_size,
31+
n_iterations=int(2.5e4),
32+
n_samples_train=1000,
33+
n_samples_val=80,
34+
)
35+
36+
37+
def export_model(name, export_path):
38+
model = load_model(os.path.join("checkpoints", name), state_key="teacher")
39+
torch.save(model, export_path)
40+
41+
42+
def main():
43+
import argparse
44+
45+
parser = argparse.ArgumentParser()
46+
parser.add_argument("--export_path")
47+
args = parser.parse_args()
48+
name = "sgn-adapted-model"
49+
if args.export_path is None:
50+
run_training(name)
51+
else:
52+
export_model(name, args.export_path)
53+
54+
55+
if __name__ == "__main__":
56+
main()

0 commit comments

Comments
 (0)