File tree Expand file tree Collapse file tree 3 files changed +29
-2
lines changed
bioimageio/core/weight_converters Expand file tree Collapse file tree 3 files changed +29
-2
lines changed Original file line number Diff line number Diff line change 1+ from collections import defaultdict
2+ from itertools import chain
3+ from typing import DefaultDict , Dict
4+
5+ from bioimageio .spec .model .v0_5 import ModelDescr
6+
7+
8+ def get_dynamic_axes (model_descr : ModelDescr ):
9+ dynamic_axes : DefaultDict [str , Dict [int , str ]] = defaultdict (dict )
10+ for d in chain (model_descr .inputs , model_descr .outputs ):
11+ for i , ax in enumerate (d .axes ):
12+ if not isinstance (ax .size , int ):
13+ dynamic_axes [str (d .id )][i ] = str (ax .id )
14+
15+ return dynamic_axes
Original file line number Diff line number Diff line change 88from ..backends .pytorch_backend import load_torch_model
99from ..digest_spec import get_member_id , get_test_inputs
1010from ..proc_setup import get_pre_and_postprocessing
11+ from ._utils_onnx import get_dynamic_axes
1112
1213
1314def convert (
1415 model_descr : ModelDescr ,
1516 output_path : Path ,
1617 * ,
1718 verbose : bool = False ,
18- opset_version : int = 20 ,
19+ opset_version : int = 15 ,
1920) -> OnnxWeightsDescr :
2021 """
2122 Convert model weights from the Torchscript state_dict format to the ONNX format.
@@ -63,6 +64,9 @@ def convert(
6364 model ,
6465 tuple (inputs_torch ),
6566 str (output_path ),
67+ input_names = [str (d .id ) for d in model_descr .inputs ],
68+ output_names = [str (d .id ) for d in model_descr .outputs ],
69+ dynamic_axes = get_dynamic_axes (model_descr ),
6670 verbose = verbose ,
6771 opset_version = opset_version ,
6872 )
Original file line number Diff line number Diff line change 55from bioimageio .spec .model .v0_5 import ModelDescr , OnnxWeightsDescr
66from bioimageio .spec .utils import download
77
8+ from .. import __version__
89from ..digest_spec import get_member_id , get_test_inputs
910from ..proc_setup import get_pre_and_postprocessing
11+ from ._utils_onnx import get_dynamic_axes
1012
1113
1214def convert (
@@ -67,10 +69,16 @@ def convert(
6769 model , # type: ignore
6870 tuple (inputs_torch ),
6971 str (output_path ),
72+ input_names = [str (d .id ) for d in model_descr .inputs ],
73+ output_names = [str (d .id ) for d in model_descr .outputs ],
74+ dynamic_axes = get_dynamic_axes (model_descr ),
7075 verbose = verbose ,
7176 opset_version = opset_version ,
7277 )
7378
7479 return OnnxWeightsDescr (
75- source = output_path , parent = "pytorch_state_dict" , opset_version = opset_version
80+ source = output_path ,
81+ parent = "torchscript" ,
82+ opset_version = opset_version ,
83+ comment = f"Converted with bioimageio.core { __version__ } ." ,
7684 )
You can’t perform that action at this time.
0 commit comments