Skip to content

Commit 8a8cb00

Browse files
committed
add test_commands
1 parent 815d87f commit 8a8cb00

File tree

2 files changed

+47
-52
lines changed

2 files changed

+47
-52
lines changed

bioimageio/core/commands.py

Lines changed: 3 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
1-
"""These functions implement the logic of the bioimageio command line interface
2-
defined in the `cli` module."""
1+
"""deprecated,
2+
use the CLI object `bioimageio.core.cli.Bioimageio` programmatically instead.
3+
"""
34

45
import sys
56
from pathlib import Path
@@ -97,53 +98,3 @@ def package(
9798
output_path=path,
9899
weights_priority_order=weights_priority_order,
99100
)
100-
101-
102-
# TODO: add convert command(s)
103-
# if torch_converter is not None:
104-
105-
# @app.command()
106-
# def convert_torch_weights_to_onnx(
107-
# model_rdf: Path = typer.Argument(
108-
# ..., help="Path to the model resource description file (rdf.yaml) or zipped model."
109-
# ),
110-
# output_path: Path = typer.Argument(..., help="Where to save the onnx weights."),
111-
# opset_version: Optional[int] = typer.Argument(12, help="Onnx opset version."),
112-
# use_tracing: bool = typer.Option(True, help="Whether to use torch.jit tracing or scripting."),
113-
# verbose: bool = typer.Option(True, help="Verbosity"),
114-
# ):
115-
# ret_code = torch_converter.convert_weights_to_onnx(model_rdf, output_path, opset_version, use_tracing, verbose)
116-
# sys.exit(ret_code)
117-
118-
# convert_torch_weights_to_onnx.__doc__ = torch_converter.convert_weights_to_onnx.__doc__
119-
120-
# @app.command()
121-
# def convert_torch_weights_to_torchscript(
122-
# model_rdf: Path = typer.Argument(
123-
# ..., help="Path to the model resource description file (rdf.yaml) or zipped model."
124-
# ),
125-
# output_path: Path = typer.Argument(..., help="Where to save the torchscript weights."),
126-
# use_tracing: bool = typer.Option(True, help="Whether to use torch.jit tracing or scripting."),
127-
# ):
128-
# torch_converter.convert_weights_to_torchscript(model_rdf, output_path, use_tracing)
129-
# sys.exit(0)
130-
131-
# convert_torch_weights_to_torchscript.__doc__ = torch_converter.convert_weights_to_torchscript.__doc__
132-
133-
134-
# if keras_converter is not None:
135-
136-
# @app.command()
137-
# def convert_keras_weights_to_tensorflow(
138-
# model_rdf: Annotated[
139-
# Path, typer.Argument(help="Path to the model resource description file (rdf.yaml) or zipped model.")
140-
# ],
141-
# output_path: Annotated[Path, typer.Argument(help="Where to save the tensorflow weights.")],
142-
# ):
143-
# rd = load_description(model_rdf)
144-
# ret_code = keras_converter.convert_weights_to_tensorflow_saved_model_bundle(rd, output_path)
145-
# sys.exit(ret_code)
146-
147-
# convert_keras_weights_to_tensorflow.__doc__ = (
148-
# keras_converter.convert_weights_to_tensorflow_saved_model_bundle.__doc__
149-
# )

tests/test_commands.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
from pathlib import Path
2+
from typing import Literal, Optional
3+
4+
import pytest
5+
6+
from bioimageio.core import load_model
7+
from bioimageio.core.commands import package, validate_format
8+
from bioimageio.core.commands import test as command_tst
9+
from bioimageio.spec.model import ModelDescr
10+
11+
12+
@pytest.mark.fixture(scope="module")
13+
def model(unet2d_nuclei_broad_model: str):
14+
return load_model(unet2d_nuclei_broad_model, perform_io_checks=False)
15+
16+
17+
@pytest.mark.parametrize(
18+
"weight_format",
19+
[
20+
"all",
21+
"pytorch_state_dict",
22+
],
23+
)
24+
def test_package(
25+
weight_format: Literal["all", "pytorch_state_dict"],
26+
model: ModelDescr,
27+
tmp_path: Path,
28+
):
29+
_ = package(model, weight_format=weight_format, path=tmp_path / "out.zip")
30+
31+
32+
def test_validate_format(model: ModelDescr):
33+
_ = validate_format(model)
34+
35+
36+
@pytest.mark.parametrize(
37+
"weight_format,devices", [("all", None), ("pytorch_state_dict", "cpu")]
38+
)
39+
def test_test(
40+
weight_format: Literal["all", "pytorch_state_dict"],
41+
devices: Optional[str],
42+
model: ModelDescr,
43+
):
44+
_ = command_tst(model, weight_format=weight_format, devices=devices)

0 commit comments

Comments
 (0)