Skip to content

Commit b6f84f9

Browse files
committed
annotate relative and absolute tolerance
1 parent 3009f5b commit b6f84f9

File tree

2 files changed

+31
-18
lines changed

2 files changed

+31
-18
lines changed

bioimageio/core/_resource_tests.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
from bioimageio.spec._internal.common_nodes import ResourceDescrBase
3838
from bioimageio.spec._internal.io import is_yaml_value
3939
from bioimageio.spec._internal.io_utils import read_yaml, write_yaml
40+
from bioimageio.spec._internal.types import AbsoluteTolerance, RelativeTolerance
4041
from bioimageio.spec._internal.validation_context import validation_context_var
4142
from bioimageio.spec.common import BioimageioYamlContent, PermissiveFileSource, Sha256
4243
from bioimageio.spec.model import v0_4, v0_5
@@ -120,8 +121,8 @@ def test_model(
120121
source: Union[v0_4.ModelDescr, v0_5.ModelDescr, PermissiveFileSource],
121122
weight_format: Optional[WeightsFormat] = None,
122123
devices: Optional[List[str]] = None,
123-
absolute_tolerance: float = 1.5e-4,
124-
relative_tolerance: float = 1e-4,
124+
absolute_tolerance: AbsoluteTolerance = 1.5e-4,
125+
relative_tolerance: RelativeTolerance = 1e-4,
125126
decimal: Optional[int] = None,
126127
*,
127128
determinism: Literal["seed_only", "full"] = "seed_only",
@@ -152,8 +153,8 @@ def test_description(
152153
format_version: Union[Literal["discover", "latest"], str] = "discover",
153154
weight_format: Optional[WeightsFormat] = None,
154155
devices: Optional[Sequence[str]] = None,
155-
absolute_tolerance: float = 1.5e-4,
156-
relative_tolerance: float = 1e-4,
156+
absolute_tolerance: AbsoluteTolerance = 1.5e-4,
157+
relative_tolerance: RelativeTolerance = 1e-4,
157158
decimal: Optional[int] = None,
158159
determinism: Literal["seed_only", "full"] = "seed_only",
159160
expected_type: Optional[str] = None,
@@ -236,8 +237,8 @@ def _test_in_env(
236237
weight_format: Optional[WeightsFormat],
237238
conda_env: Optional[BioimageioCondaEnv],
238239
devices: Optional[Sequence[str]],
239-
absolute_tolerance: float,
240-
relative_tolerance: float,
240+
absolute_tolerance: AbsoluteTolerance,
241+
relative_tolerance: RelativeTolerance,
241242
determinism: Literal["seed_only", "full"],
242243
run_command: Callable[[Sequence[str]], None],
243244
) -> ValidationSummary:
@@ -354,8 +355,8 @@ def load_description_and_test(
354355
format_version: Union[Literal["discover", "latest"], str] = "discover",
355356
weight_format: Optional[WeightsFormat] = None,
356357
devices: Optional[Sequence[str]] = None,
357-
absolute_tolerance: float = 1.5e-4,
358-
relative_tolerance: float = 1e-4,
358+
absolute_tolerance: AbsoluteTolerance = 1.5e-4,
359+
relative_tolerance: RelativeTolerance = 1e-4,
359360
decimal: Optional[int] = None,
360361
determinism: Literal["seed_only", "full"] = "seed_only",
361362
expected_type: Optional[str] = None,

bioimageio/core/weight_converters/_add_weights.py

Lines changed: 22 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,35 @@
1-
from abc import ABC
2-
from typing import Optional, Sequence, Union, assert_never, final
1+
from copy import deepcopy
2+
from pathlib import Path
3+
from typing import List, Optional, Sequence, Union
34

45
from bioimageio.spec.model import v0_4, v0_5
56

67

78
def increase_available_weight_formats(
8-
model_description: Union[v0_4.ModelDescr, v0_5.ModelDescr],
9-
source_format: v0_5.WeightsFormat,
10-
target_format: v0_5.WeightsFormat,
9+
model_descr: Union[v0_4.ModelDescr, v0_5.ModelDescr],
1110
*,
11+
source_format: Optional[v0_5.WeightsFormat] = None,
12+
target_format: Optional[v0_5.WeightsFormat] = None,
13+
output_path: Path,
1214
devices: Optional[Sequence[str]] = None,
13-
):
14-
if not isinstance(model_description, (v0_4.ModelDescr, v0_5.ModelDescr)):
15+
) -> Union[v0_4.ModelDescr, v0_5.ModelDescr]:
16+
"""Convert neural network weights to other formats and add them to the model description"""
17+
if not isinstance(model_descr, (v0_4.ModelDescr, v0_5.ModelDescr)):
1518
raise TypeError(
16-
f"expected v0_4.ModelDescr or v0_5.ModelDescr, but got {type(model_description)}"
19+
f"expected v0_4.ModelDescr or v0_5.ModelDescr, but got {type(model_descr)}"
1720
)
1821

19-
if (source_format, target_format) == ("pytorch_state_dict", "onnx"):
20-
from .pytorch_to_onnx import convert_pytorch_to_onnx
22+
if source_format is None:
23+
available = [wf for wf, w in model_descr.weights if w is not None]
24+
missing = [wf for wf, w in model_descr.weights if w is None]
25+
else:
26+
available = [source_format]
27+
missing = [target_format]
28+
29+
if "pytorch_state_dict" in available and "onnx" in missing:
30+
from .pytorch_to_onnx import convert
31+
32+
onnx = convert(model_descr)
2133

2234
else:
2335
raise NotImplementedError(

0 commit comments

Comments
 (0)