Skip to content

Commit 57576c7

Browse files
committed
add test_resource cli
1 parent a721088 commit 57576c7

File tree

1 file changed

+35
-0
lines changed

1 file changed

+35
-0
lines changed

bioimageio/core/__main__.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import enum
12
import json
23
import os
34
from glob import glob
@@ -9,6 +10,12 @@
910

1011
from bioimageio.core import __version__, prediction, commands, resource_tests
1112
from bioimageio.spec.__main__ import app
13+
from bioimageio.spec.model.raw_nodes import WeightsFormat
14+
15+
try:
16+
from typing import get_args
17+
except ImportError:
18+
from typing_extensions import get_args # type: ignore
1219

1320
try:
1421
from bioimageio.core.weight_converter import torch as torch_converter
@@ -40,6 +47,9 @@ def package(
4047

4148
# if we want to use something like "choice" for the weight formats, we need to use an enum, see:
4249
# https://github.com/tiangolo/typer/issues/182
50+
WeightFormatEnum = enum.Enum("WeightFormatEnum", get_args(WeightsFormat))
51+
52+
4353
@app.command()
4454
def test_model(
4555
model_rdf: str = typer.Argument(
@@ -65,6 +75,31 @@ def test_model(
6575
test_model.__doc__ = resource_tests.test_model.__doc__
6676

6777

78+
@app.command()
79+
def test_resource(
80+
rdf: str = typer.Argument(
81+
..., help="Path or URL to the resource description file (rdf.yaml) or zipped resource package."
82+
),
83+
weight_format: Optional[str] = typer.Argument(None, help="(for model only) The weight format to use."),
84+
devices: Optional[List[str]] = typer.Argument(None, help="(for model only) Devices for running the model."),
85+
decimal: int = typer.Argument(4, help="(for model only) The test precision."),
86+
) -> int:
87+
# this is a weird typer bug: default devices are empty tuple although they should be None
88+
if len(devices) == 0:
89+
devices = None
90+
summary = resource_tests.test_resource(rdf, weight_format=weight_format, devices=devices, decimal=decimal)
91+
if summary["error"] is None:
92+
print(f"Resource test for {rdf} has passed.")
93+
return 0
94+
else:
95+
print(f"Resource test for {rdf} has FAILED!")
96+
print(summary)
97+
return 1
98+
99+
100+
test_resource.__doc__ = resource_tests.test_resource.__doc__
101+
102+
68103
@app.command()
69104
def predict_image(
70105
model_rdf: Path = typer.Argument(

0 commit comments

Comments
 (0)