Skip to content

Commit f798aa1

Browse files
committed
add allow_tracing flag
1 parent 1350721 commit f798aa1

File tree

2 files changed

+12
-2
lines changed

2 files changed

+12
-2
lines changed

bioimageio/core/cli.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -754,6 +754,10 @@ class AddWeightsCmd(CmdBase, WithSource, WithSummaryLogging):
754754
verbose: bool = False
755755
"""Log more (error) output."""
756756

757+
tracing: bool = True
758+
"""Allow tracing when converting pytorch_state_dict to torchscript
759+
(still uses scripting if possible)."""
760+
757761
def run(self):
758762
model_descr = ensure_description_is_model(self.descr)
759763
if isinstance(model_descr, v0_4.ModelDescr):
@@ -767,6 +771,7 @@ def run(self):
767771
source_format=self.source_format,
768772
target_format=self.target_format,
769773
verbose=self.verbose,
774+
allow_tracing=self.tracing,
770775
)
771776
if updated_model_descr is None:
772777
return

bioimageio/core/weight_converters/_add_weights.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ def add_weights(
2121
source_format: Optional[WeightsFormat] = None,
2222
target_format: Optional[WeightsFormat] = None,
2323
verbose: bool = False,
24+
allow_tracing: bool = True,
2425
) -> Optional[ModelDescr]:
2526
"""Convert model weights to other formats and add them to the model description
2627
@@ -90,7 +91,7 @@ def add_weights(
9091
available.add("torchscript")
9192
missing.discard("torchscript")
9293

93-
if "pytorch_state_dict" in available and "torchscript" in missing:
94+
if allow_tracing and "pytorch_state_dict" in available and "torchscript" in missing:
9495
logger.info(
9596
"Attempting to convert 'pytorch_state_dict' weights to 'torchscript' by tracing."
9697
)
@@ -169,5 +170,9 @@ def add_weights(
169170
# resave model with updated rdf.yaml
170171
_ = save_bioimageio_package_as_folder(model_descr, output_path=output_path)
171172
tested_model_descr = load_description_and_test(model_descr)
172-
assert isinstance(tested_model_descr, ModelDescr)
173+
if not isinstance(tested_model_descr, ModelDescr):
174+
raise RuntimeError(
175+
f"The updated model description at {output_path} did not pass testing."
176+
)
177+
173178
return tested_model_descr

0 commit comments

Comments
 (0)