Skip to content

Commit 6a6d6a2

Browse files
committed
add logging to increase_available_weight_formats
1 parent cb47733 commit 6a6d6a2

File tree

2 files changed

+15
-3
lines changed

2 files changed

+15
-3
lines changed

bioimageio/core/weight_converters/_add_weights.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ def increase_available_weight_formats(
1414
output_path: DirectoryPath,
1515
source_format: Optional[WeightsFormat] = None,
1616
target_format: Optional[WeightsFormat] = None,
17-
) -> ModelDescr:
17+
) -> Optional[ModelDescr]:
1818
"""Convert model weights to other formats and add them to the model description
1919
2020
Args:
@@ -24,6 +24,10 @@ def increase_available_weight_formats(
2424
target_format: convert to a specific weights format.
2525
Default: attempt to convert to any missing format.
2626
devices: Devices that may be used during conversion.
27+
28+
Returns:
29+
- An updated model description if any converted weights were added.
30+
- `None` if no conversion was possible.
2731
"""
2832
if not isinstance(model_descr, ModelDescr):
2933
raise TypeError(type(model_descr))
@@ -48,6 +52,8 @@ def increase_available_weight_formats(
4852
else:
4953
missing = {target_format}
5054

55+
originally_missing = set(missing)
56+
5157
if "pytorch_state_dict" in available and "onnx" in missing:
5258
from .pytorch_to_onnx import convert
5359

@@ -86,5 +92,10 @@ def increase_available_weight_formats(
8692
+ " if you would like bioimageio.core to support a particular conversion."
8793
)
8894

89-
test_model(model_descr).display()
90-
return model_descr
95+
if originally_missing == missing:
96+
logger.warning("failed to add any converted weights")
97+
return None
98+
else:
99+
logger.info(f"added weights formats {originally_missing - missing}")
100+
test_model(model_descr).display()
101+
return model_descr

tests/test_cli.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ def run_subprocess(
3737
["test", "unet2d_nuclei_broad_model"],
3838
["predict", "--example", "unet2d_nuclei_broad_model"],
3939
["update-format", "unet2d_path_old_version"],
40+
["increase-weight-formats", "unet2d_nuclei_broad_model"],
4041
],
4142
)
4243
def test_cli(

0 commit comments

Comments
 (0)