Skip to content

Commit 65a1aad

Browse files
authored
Merge pull request #147 from bioimage-io/cli_update2
Fix specifying weight format in CLI (#144)
2 parents b423856 + 28053a8 commit 65a1aad

File tree

2 files changed

+34
-6
lines changed

2 files changed

+34
-6
lines changed

bioimageio/core/__main__.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,12 @@ def test_model(
7777
# this is a weird typer bug: default devices are empty tuple although they should be None
7878
if len(devices) == 0:
7979
devices = None
80-
summary = resource_tests.test_model(model_rdf, weight_format=weight_format, devices=devices, decimal=decimal)
80+
summary = resource_tests.test_model(
81+
model_rdf,
82+
weight_format=None if weight_format is None else weight_format.value,
83+
devices=devices,
84+
decimal=decimal,
85+
)
8186
if summary["error"] is None:
8287
print(f"Model test for {model_rdf} has passed.")
8388
return 0
@@ -102,7 +107,9 @@ def test_resource(
102107
# this is a weird typer bug: default devices are empty tuple although they should be None
103108
if len(devices) == 0:
104109
devices = None
105-
summary = resource_tests.test_resource(rdf, weight_format=weight_format, devices=devices, decimal=decimal)
110+
summary = resource_tests.test_resource(
111+
rdf, weight_format=None if weight_format is None else weight_format.value, devices=devices, decimal=decimal
112+
)
106113
if summary["error"] is None:
107114
print(f"Resource test for {rdf} has passed.")
108115
return 0
@@ -131,7 +138,7 @@ def predict_image(
131138
# ),
132139
padding: Optional[bool] = typer.Option(None, help="Whether to pad the image to a size suited for the model."),
133140
tiling: Optional[bool] = typer.Option(None, help="Whether to run prediction in tiling mode."),
134-
weight_format: Optional[str] = typer.Option(None, help="The weight format to use."),
141+
weight_format: Optional[WeightFormatEnum] = typer.Option(None, help="The weight format to use."),
135142
devices: Optional[List[str]] = typer.Option(None, help="Devices for running the model."),
136143
) -> int:
137144

@@ -145,7 +152,9 @@ def predict_image(
145152
# this is a weird typer bug: default devices are empty tuple although they should be None
146153
if len(devices) == 0:
147154
devices = None
148-
prediction.predict_image(model_rdf, inputs, outputs, padding, tiling, weight_format, devices)
155+
prediction.predict_image(
156+
model_rdf, inputs, outputs, padding, tiling, None if weight_format is None else weight_format.value, devices
157+
)
149158
return 0
150159

151160

@@ -169,7 +178,7 @@ def predict_images(
169178
# ),
170179
padding: Optional[bool] = typer.Option(None, help="Whether to pad the image to a size suited for the model."),
171180
tiling: Optional[bool] = typer.Option(None, help="Whether to run prediction in tiling mode."),
172-
weight_format: Optional[str] = typer.Option(None, help="The weight format to use."),
181+
weight_format: Optional[WeightFormatEnum] = typer.Option(None, help="The weight format to use."),
173182
devices: Optional[List[str]] = typer.Option(None, help="Devices for running the model."),
174183
) -> int:
175184
input_files = glob(input_pattern)
@@ -194,7 +203,7 @@ def predict_images(
194203
output_files,
195204
padding=padding,
196205
tiling=tiling,
197-
weight_format=weight_format,
206+
weight_format=None if weight_format is None else weight_format.value,
198207
devices=devices,
199208
verbose=True,
200209
)

tests/test_cli.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,25 @@ def test_cli_test_model(unet2d_nuclei_broad_model):
1616
assert ret.returncode == 0
1717

1818

19+
def test_cli_test_model_with_specific_weight_format(unet2d_nuclei_broad_model):
20+
ret = subprocess.run(
21+
["bioimageio", "test-model", unet2d_nuclei_broad_model, "--weight-format", "pytorch_state_dict"]
22+
)
23+
assert ret.returncode == 0
24+
25+
26+
def test_cli_test_resource(unet2d_nuclei_broad_model):
27+
ret = subprocess.run(["bioimageio", "test-model", unet2d_nuclei_broad_model])
28+
assert ret.returncode == 0
29+
30+
31+
def test_cli_test_resource_with_specific_weight_format(unet2d_nuclei_broad_model):
32+
ret = subprocess.run(
33+
["bioimageio", "test-model", unet2d_nuclei_broad_model, "--weight-format", "pytorch_state_dict"]
34+
)
35+
assert ret.returncode == 0
36+
37+
1938
def test_cli_predict_image(unet2d_nuclei_broad_model, tmp_path):
2039
spec = load_resource_description(unet2d_nuclei_broad_model)
2140
in_path = spec.test_inputs[0]

0 commit comments

Comments
 (0)