Skip to content

Commit c3b5a26

Browse files
Add model export script
1 parent 21f0138 commit c3b5a26

File tree

1 file changed

+44
-0
lines changed

1 file changed

+44
-0
lines changed

development/export_models.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
import torch
2+
from torch_em.util import load_model
3+
4+
5+
def export_sgn():
6+
path = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/trained_models/SGN/v2_cochlea_distance_unet_SGN_supervised_2025-05-27" # noqa
7+
model = load_model(path, device="cpu")
8+
torch.save(model, "SGN.pt")
9+
10+
11+
def export_ihc():
12+
path = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/trained_models/IHC/v4_cochlea_distance_unet_IHC_supervised_2025-07-14" # noqa
13+
model = load_model(path, device="cpu")
14+
torch.save(model, "IHC.pt")
15+
16+
17+
def export_synapses():
18+
path = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/trained_models/Synapses/synapse_detection_model_v3.pt" # noqa
19+
model = torch.load(path, map_location="cpu", weights_only=False)
20+
torch.save(model, "Synapses.pt")
21+
22+
23+
def export_sgn_lowres():
24+
path = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/trained_models/SGN/cochlea_distance_unet_sgn-low-res-v4" # noqa
25+
model = load_model(path, device="cpu")
26+
torch.save(model, "SGN-lowres.pt")
27+
28+
29+
def export_ihc_lowres():
30+
path = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/trained_models/IHC/cochlea_distance_unet_ihc-lowres-v3" # noqa
31+
model = load_model(path, device="cpu")
32+
torch.save(model, "IHC-lowres.pt")
33+
34+
35+
def main():
36+
# export_sgn()
37+
# export_ihc()
38+
# export_synapses()
39+
export_sgn_lowres()
40+
# export_ihc_lowres()
41+
42+
43+
if __name__ == "__main__":
44+
main()

0 commit comments

Comments
 (0)