Commit b09d389

add check_reproducibility
1 parent b6f84f9 commit b09d389

File tree

1 file changed: +36 -30 lines changed

bioimageio/core/weight_converters/pytorch_to_onnx.py

Lines changed: 36 additions & 30 deletions
@@ -8,6 +8,7 @@
 from bioimageio.core.backends.pytorch_backend import load_torch_model
 from bioimageio.core.digest_spec import get_member_id, get_test_inputs
 from bioimageio.core.proc_setup import get_pre_and_postprocessing
+from bioimageio.spec._internal.types import AbsoluteTolerance, RelativeTolerance
 from bioimageio.spec.model import v0_4, v0_5
 
 
@@ -16,10 +17,11 @@ def convert(
     *,
     output_path: Path,
     use_tracing: bool = True,
-    relative_tolerance: float = 1e-07,
-    absolute_tolerance: float = 0,
     verbose: bool = False,
     opset_version: int = 15,
+    check_reproducibility: bool = True,
+    relative_tolerance: RelativeTolerance = 1e-07,
+    absolute_tolerance: AbsoluteTolerance = 0,
 ) -> v0_5.OnnxWeightsDescr:
     """
     Convert model weights from the PyTorch state_dict format to the ONNX format.
@@ -72,7 +74,6 @@ def convert(
     outputs_original: List[np.ndarray[Any, Any]] = [
         out.numpy() for out in outputs_original_torch
     ]
-
     if use_tracing:
         _ = torch.onnx.export(
             model,
@@ -84,35 +85,40 @@ def convert(
     else:
         raise NotImplementedError
 
-    try:
-        import onnxruntime as rt  # pyright: ignore [reportMissingTypeStubs]
-    except ImportError:
-        raise ImportError(
-            "The onnx weights were exported, but onnx rt is not available and weights cannot be checked."
-        )
+    if check_reproducibility:
+        try:
+            import onnxruntime as rt  # pyright: ignore [reportMissingTypeStubs]
+        except ImportError as e:
+            raise ImportError(
+                "The onnx weights were exported, but onnx rt is not available"
+                + " and weights cannot be checked."
+            ) from e
 
-    # check the onnx model
-    sess = rt.InferenceSession(str(output_path))
-    onnx_input_node_args = cast(
-        List[Any], sess.get_inputs()
-    )  # FIXME: remove cast, try using rt.NodeArg instead of Any
-    inputs_onnx = {
-        input_name.name: inp
-        for input_name, inp in zip(onnx_input_node_args, inputs_numpy)
-    }
-    outputs_onnx = cast(
-        Sequence[np.ndarray[Any, Any]], sess.run(None, inputs_onnx)
-    )  # FIXME: remove cast
+        # check the onnx model
+        sess = rt.InferenceSession(str(output_path))
+        onnx_input_node_args = cast(
+            List[Any], sess.get_inputs()
+        )  # FIXME: remove cast, try using rt.NodeArg instead of Any
+        inputs_onnx = {
+            input_name.name: inp
+            for input_name, inp in zip(onnx_input_node_args, inputs_numpy)
+        }
+        outputs_onnx = cast(
+            Sequence[np.ndarray[Any, Any]], sess.run(None, inputs_onnx)
+        )  # FIXME: remove cast
 
-    try:
-        for out_original, out_onnx in zip(outputs_original, outputs_onnx):
-            assert_allclose(
-                out_original, out_onnx, rtol=relative_tolerance, atol=absolute_tolerance
-            )
-    except AssertionError as e:
-        raise AssertionError(
-            "Inference results of using original and converted weights do not match"
-        ) from e
+        try:
+            for out_original, out_onnx in zip(outputs_original, outputs_onnx):
+                assert_allclose(
+                    out_original,
+                    out_onnx,
+                    rtol=relative_tolerance,
+                    atol=absolute_tolerance,
+                )
+        except AssertionError as e:
+            raise AssertionError(
+                "Inference results of original and converted weights do not match."
+            ) from e
 
     return v0_5.OnnxWeightsDescr(
         source=output_path, parent="pytorch_state_dict", opset_version=opset_version
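
A minimal usage sketch of the new check_reproducibility flag (illustrative, not part of this commit). It assumes convert() takes a loaded model description as its first positional argument and that bioimageio.spec exposes a load_model_description helper; those names and the paths below are assumptions, not confirmed by this diff.

from pathlib import Path

from bioimageio.core.weight_converters.pytorch_to_onnx import convert
from bioimageio.spec import load_model_description  # assumed loader API

# Load a model description that ships pytorch_state_dict weights (placeholder path).
model_descr = load_model_description("path/to/rdf.yaml")

# Convert and verify that onnxruntime reproduces the PyTorch outputs within the given
# tolerances; a mismatch raises AssertionError, a missing onnxruntime raises ImportError.
onnx_descr = convert(
    model_descr,
    output_path=Path("weights.onnx"),
    opset_version=15,
    check_reproducibility=True,
    relative_tolerance=1e-07,
    absolute_tolerance=0,
)

# Skip the check, e.g. in environments where onnxruntime is not installed.
onnx_descr_unchecked = convert(
    model_descr,
    output_path=Path("weights_unchecked.onnx"),
    check_reproducibility=False,
)

Keeping check_reproducibility=True as the default preserves the previous behaviour of verifying every export, while the opt-out lets the export itself succeed on machines without onnxruntime.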
