Skip to content

Commit 923bd0e

Browse files
committed
Domain adaptation for IHC segmentation of MLR215R (CR)
1 parent e17d4f4 commit 923bd0e

File tree

2 files changed

+114
-0
lines changed

2 files changed

+114
-0
lines changed
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
[
2+
{
3+
"cochlea": "M_LR_000215_R",
4+
"image_channel": [
5+
"CR"
6+
],
7+
"segmentation_channel": "IHC_v4b",
8+
"type": "ihc",
9+
"n_blocks": 6,
10+
"halo_size": [
11+
256,
12+
256,
13+
256
14+
],
15+
"component_list": [
16+
7,
17+
4,
18+
1,
19+
11
20+
],
21+
"crop_centers": [
22+
[
23+
727,
24+
1300,
25+
685
26+
],
27+
[
28+
1123,
29+
931,
30+
870
31+
],
32+
[
33+
812,
34+
475,
35+
1063
36+
],
37+
[
38+
296,
39+
645,
40+
906
41+
],
42+
[
43+
503,
44+
918,
45+
518
46+
],
47+
[
48+
529,
49+
1064,
50+
408
51+
]
52+
]
53+
}
54+
]
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
import os
2+
from glob import glob
3+
4+
import torch
5+
from torch_em.util import load_model
6+
from flamingo_tools.training import mean_teacher_training
7+
8+
9+
def get_paths():
10+
root = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/training_data/IHC"
11+
folders = ["2025-08-IHC_domain_CR"]
12+
train_paths = []
13+
val_paths = []
14+
for folder in folders:
15+
train_paths.extend(glob(os.path.join(root, folder, "train", "*.tif")))
16+
val_paths.extend(glob(os.path.join(root, folder, "val", "*.tif")))
17+
return train_paths, val_paths
18+
19+
20+
def run_training(name):
21+
patch_shape = (64, 128, 128)
22+
batch_size = 8
23+
source_checkpoint = os.path.join("/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet",
24+
"trained_models/IHC",
25+
"v4_cochlea_distance_unet_IHC_supervised_2025-07-14")
26+
27+
train_paths, val_paths = get_paths()
28+
mean_teacher_training(
29+
name=name,
30+
unsupervised_train_paths=train_paths,
31+
unsupervised_val_paths=val_paths,
32+
patch_shape=patch_shape,
33+
source_checkpoint=source_checkpoint,
34+
batch_size=batch_size,
35+
n_iterations=int(2.5e4),
36+
n_samples_train=1000,
37+
n_samples_val=80,
38+
)
39+
40+
41+
def export_model(name, export_path):
42+
model = load_model(os.path.join("checkpoints", name), state_key="teacher")
43+
torch.save(model, export_path)
44+
45+
46+
def main():
47+
import argparse
48+
49+
parser = argparse.ArgumentParser()
50+
parser.add_argument("--export_path")
51+
args = parser.parse_args()
52+
name = "ihc-adapted-model_M215R_CR"
53+
if args.export_path is None:
54+
run_training(name)
55+
else:
56+
export_model(name, args.export_path)
57+
58+
59+
if __name__ == "__main__":
60+
main()

0 commit comments

Comments
 (0)