Skip to content

Commit c4c0642

Browse files
committed
prefer state dict -> onnx over torchscript -> onnx
1 parent 19edd78 commit c4c0642

File tree

1 file changed

+10
-10
lines changed

1 file changed

+10
-10
lines changed

src/bioimageio/core/weight_converters/_add_weights.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,14 @@
11
import traceback
22
from typing import Optional, Union
33

4-
from loguru import logger
5-
from pydantic import DirectoryPath
6-
74
from bioimageio.spec import (
85
InvalidDescr,
96
load_model_description,
107
save_bioimageio_package_as_folder,
118
)
129
from bioimageio.spec.model.v0_5 import ModelDescr, WeightsFormat
10+
from loguru import logger
11+
from pydantic import DirectoryPath
1312

1413
from .._resource_tests import load_description_and_test
1514

@@ -113,15 +112,17 @@ def add_weights(
113112
available.add("torchscript")
114113
missing.discard("torchscript")
115114

116-
if "torchscript" in available and "onnx" in missing:
117-
logger.info("Attempting to convert 'torchscript' weights to 'onnx'.")
118-
from .torchscript_to_onnx import convert
115+
if "pytorch_state_dict" in available and "onnx" in missing:
116+
logger.info("Attempting to convert 'pytorch_state_dict' weights to 'onnx'.")
117+
from .pytorch_to_onnx import convert
119118

120119
try:
121120
onnx_weights_path = output_path / "weights.onnx"
121+
122122
model_descr.weights.onnx = convert(
123123
model_descr,
124124
output_path=onnx_weights_path,
125+
verbose=verbose,
125126
)
126127
except Exception as e:
127128
if verbose:
@@ -132,13 +133,12 @@ def add_weights(
132133
available.add("onnx")
133134
missing.discard("onnx")
134135

135-
if "pytorch_state_dict" in available and "onnx" in missing:
136-
logger.info("Attempting to convert 'pytorch_state_dict' weights to 'onnx'.")
137-
from .pytorch_to_onnx import convert
136+
if "torchscript" in available and "onnx" in missing:
137+
logger.info("Attempting to convert 'torchscript' weights to 'onnx'.")
138+
from .torchscript_to_onnx import convert
138139

139140
try:
140141
onnx_weights_path = output_path / "weights.onnx"
141-
142142
model_descr.weights.onnx = convert(
143143
model_descr,
144144
output_path=onnx_weights_path,

0 commit comments

Comments
 (0)