-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathexport_model.py
More file actions
29 lines (24 loc) · 898 Bytes
/
export_model.py
File metadata and controls
29 lines (24 loc) · 898 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
import argparse
from os import makedirs
from os.path import join, dirname, basename, isfile

import celldetection as cd
import torch
def build_parser():
    """Build the command-line parser for the export script."""
    parser = argparse.ArgumentParser('Export model', add_help=False)
    # Both arguments are required: without them the script used to fail
    # later with an opaque TypeError on a ``None`` path.
    parser.add_argument('-i', '--input', type=str, required=True,
                        help='Input (filename).')
    parser.add_argument('-o', '--output', type=str, required=True,
                        help='Output (directory).')
    return parser


def resolve_config_path(model_name):
    """Return the config file path that sits next to ``model_name``.

    Prefers ``config.json``. If that file does not exist, returns the path of
    ``config_r0.json`` in the same directory. The fallback path is returned
    whether or not it exists; loading it is left to the caller.
    """
    model_dir = dirname(model_name)
    config_name = join(model_dir, 'config.json')
    if not isfile(config_name):
        config_name = join(model_dir, 'config_r0.json')
    return config_name


def build_export_name(model_name):
    """Return ``<parent-directory-name>_<file-name>`` for ``model_name``.

    The parent-directory prefix stops same-named files from different
    directories from colliding in the single output directory.
    """
    return f'{basename(dirname(model_name))}_{basename(model_name)}'


def main():
    """Re-save a pickled model as a ``dict(state_dict=..., config=...)`` file.

    Reads the model from ``--input`` and the config file found next to it.
    Writes ``<output>/<parent-dir>_<file-name>`` and prints the destination
    path. Does nothing if the destination file already exists.
    """
    args = build_parser().parse_args()
    model_name = args.input
    dst = join(args.output, build_export_name(model_name))
    # Skip existing exports before doing any expensive loading.
    if isfile(dst):
        return

    # NOTE: torch.load unpickles arbitrary objects; run this only on trusted
    # files. The file holds a whole model object (``.state_dict()`` is called
    # on it below). NOTE(review): newer PyTorch releases default torch.load to
    # ``weights_only=True``, which may refuse such a file; confirm against the
    # installed version.
    model = torch.load(model_name, map_location='cpu')
    # Load whichever config file resolve_config_path selected, so the
    # config_r0.json fallback actually takes effect.
    conf = cd.Config.from_json(resolve_config_path(model_name))

    makedirs(args.output, exist_ok=True)
    print(dst)
    torch.save(dict(
        state_dict=model.state_dict(),
        config=conf,
    ), dst)


if __name__ == '__main__':
    main()