88from bioimageio .core .backends .pytorch_backend import load_torch_model
99from bioimageio .core .digest_spec import get_member_id , get_test_inputs
1010from bioimageio .core .proc_setup import get_pre_and_postprocessing
11+ from bioimageio .spec ._internal .types import AbsoluteTolerance , RelativeTolerance
1112from bioimageio .spec .model import v0_4 , v0_5
1213
1314
@@ -16,10 +17,11 @@ def convert(
1617 * ,
1718 output_path : Path ,
1819 use_tracing : bool = True ,
19- relative_tolerance : float = 1e-07 ,
20- absolute_tolerance : float = 0 ,
2120 verbose : bool = False ,
2221 opset_version : int = 15 ,
22+ check_reproducibility : bool = True ,
23+ relative_tolerance : RelativeTolerance = 1e-07 ,
24+ absolute_tolerance : AbsoluteTolerance = 0 ,
2325) -> v0_5 .OnnxWeightsDescr :
2426 """
2527 Convert model weights from the PyTorch state_dict format to the ONNX format.
@@ -72,7 +74,6 @@ def convert(
7274 outputs_original : List [np .ndarray [Any , Any ]] = [
7375 out .numpy () for out in outputs_original_torch
7476 ]
75-
7677 if use_tracing :
7778 _ = torch .onnx .export (
7879 model ,
@@ -84,35 +85,40 @@ def convert(
8485 else :
8586 raise NotImplementedError
8687
87- try :
88- import onnxruntime as rt # pyright: ignore [reportMissingTypeStubs]
89- except ImportError :
90- raise ImportError (
91- "The onnx weights were exported, but onnx rt is not available and weights cannot be checked."
92- )
88+ if check_reproducibility :
89+ try :
90+ import onnxruntime as rt # pyright: ignore [reportMissingTypeStubs]
91+ except ImportError as e :
92+ raise ImportError (
93+ "The onnx weights were exported, but onnx rt is not available"
94+ + " and weights cannot be checked."
95+ ) from e
9396
94- # check the onnx model
95- sess = rt .InferenceSession (str (output_path ))
96- onnx_input_node_args = cast (
97- List [Any ], sess .get_inputs ()
98- ) # FIXME: remove cast, try using rt.NodeArg instead of Any
99- inputs_onnx = {
100- input_name .name : inp
101- for input_name , inp in zip (onnx_input_node_args , inputs_numpy )
102- }
103- outputs_onnx = cast (
104- Sequence [np .ndarray [Any , Any ]], sess .run (None , inputs_onnx )
105- ) # FIXME: remove cast
97+ # check the onnx model
98+ sess = rt .InferenceSession (str (output_path ))
99+ onnx_input_node_args = cast (
100+ List [Any ], sess .get_inputs ()
101+ ) # FIXME: remove cast, try using rt.NodeArg instead of Any
102+ inputs_onnx = {
103+ input_name .name : inp
104+ for input_name , inp in zip (onnx_input_node_args , inputs_numpy )
105+ }
106+ outputs_onnx = cast (
107+ Sequence [np .ndarray [Any , Any ]], sess .run (None , inputs_onnx )
108+ ) # FIXME: remove cast
106109
107- try :
108- for out_original , out_onnx in zip (outputs_original , outputs_onnx ):
109- assert_allclose (
110- out_original , out_onnx , rtol = relative_tolerance , atol = absolute_tolerance
111- )
112- except AssertionError as e :
113- raise AssertionError (
114- "Inference results of using original and converted weights do not match"
115- ) from e
110+ try :
111+ for out_original , out_onnx in zip (outputs_original , outputs_onnx ):
112+ assert_allclose (
113+ out_original ,
114+ out_onnx ,
115+ rtol = relative_tolerance ,
116+ atol = absolute_tolerance ,
117+ )
118+ except AssertionError as e :
119+ raise AssertionError (
120+ "Inference results of original and converted weights do not match."
121+ ) from e
116122
117123 return v0_5 .OnnxWeightsDescr (
118124 source = output_path , parent = "pytorch_state_dict" , opset_version = opset_version
0 commit comments