Skip to content

Commit f2a4da5

Browse files
Update add_weights functionality and add tests
1 parent ebfa420 commit f2a4da5

File tree

4 files changed

+113
-36
lines changed

4 files changed

+113
-36
lines changed
Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1,2 @@
1-
from .build_model import add_weights, build_model
1+
from .add_weights import add_weights
2+
from .build_model import build_model
Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
import os
2+
from pathlib import Path
3+
from shutil import copyfile
4+
from typing import Dict, Optional, Union
5+
6+
from bioimageio.core import export_resource_package, load_raw_resource_description
7+
from bioimageio.spec.shared.raw_nodes import ResourceDescription as RawResourceDescription
8+
from .build_model import _get_weights
9+
10+
11+
def add_weights(
12+
model: Union[RawResourceDescription, os.PathLike, str],
13+
weight_uri: Union[str, Path],
14+
output_path: Union[str, Path],
15+
*,
16+
weight_type: Optional[str] = None,
17+
architecture: Optional[str] = None,
18+
model_kwargs: Optional[Dict[str, Union[int, float, str]]] = None,
19+
tensorflow_version: Optional[str] = None,
20+
opset_version: Optional[str] = None,
21+
**weight_kwargs,
22+
):
23+
"""Add weight entry to bioimage.io model.
24+
25+
Args:
26+
model: the resource description of the model to which the weight format is added
27+
weight_uri: the weight file to be added
28+
output_path: where to serialize the new model with additional weight format
29+
weight_type: the format of the weights to be added
30+
architecture: the file with the source code for the model architecture and the corresponding class.
31+
Only required for models with pytorch_state_dict weight format.
32+
model_kwargs: the keyword arguments for the model class.
33+
Only required for models with pytorch_state_dict weight format.
34+
tensorflow_version: the tensorflow version used for training the model.
35+
Only requred for models with tensorflow or keras weight format.
36+
opset_version: the opset version used in this model.
37+
Only requred for models with onnx weight format.
38+
weight_kwargs: additional keyword arguments for the weight.
39+
"""
40+
model = load_raw_resource_description(model)
41+
42+
# copy the weight path to the input model's root, otherwise it will
43+
# not be found when packaging the new model
44+
weight_out = os.path.join(model.root_path, Path(weight_uri).name)
45+
if Path(weight_out) != Path(weight_uri):
46+
copyfile(weight_uri, weight_out)
47+
48+
new_weights, tmp_arch = _get_weights(
49+
weight_out,
50+
weight_type,
51+
root=Path("."),
52+
architecture=architecture,
53+
model_kwargs=model_kwargs,
54+
tensorflow_version=tensorflow_version,
55+
opset_version=opset_version,
56+
**weight_kwargs,
57+
)
58+
model.weights.update(new_weights)
59+
60+
try:
61+
model_package = export_resource_package(model, output_path=output_path)
62+
model = load_raw_resource_description(model_package)
63+
except Exception as e:
64+
raise e
65+
finally:
66+
if tmp_arch is not None:
67+
os.remove(tmp_arch)
68+
return model

bioimageio/core/build_spec/build_model.py

Lines changed: 6 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -616,8 +616,10 @@ def build_model(
616616
documentation: relative file path to markdown documentation for this model.
617617
cite: citations for this model.
618618
output_path: where to save the zipped model package.
619-
source: the file with the source code for the model architecture and the corresponding class.
619+
architecture: the file with the source code for the model architecture and the corresponding class.
620+
Only required for models with pytorch_state_dict weight format.
620621
model_kwargs: the keyword arguments for the model class.
622+
Only required for models with pytorch_state_dict weight format.
621623
weight_type: the type of the weights.
622624
sample_inputs: list of sample inputs to demonstrate the model performance.
623625
sample_outputs: list of sample outputs corresponding to sample_inputs.
@@ -648,8 +650,9 @@ def build_model(
648650
root: optional root path for relative paths. This can be helpful when building a spec from another model spec.
649651
add_deepimagej_config: add the deepimagej config to the model.
650652
tensorflow_version: the tensorflow version used for training the model.
651-
Needs to be passed for tensorflow or keras models.
652-
opset_version: the opset version used in this model. Needs to be passed for onnx models.
653+
Only requred for models with tensorflow or keras weight format.
654+
opset_version: the opset version used in this model.
655+
Only requred for models with onnx weight format.
653656
weight_kwargs: additional keyword arguments for this weight type.
654657
"""
655658
if root is None:
@@ -848,35 +851,3 @@ def build_model(
848851

849852
model = load_raw_resource_description(model_package)
850853
return model
851-
852-
853-
def add_weights(
854-
model,
855-
weight_uri: Union[str, Path],
856-
weight_type: Optional[str] = None,
857-
output_path: Optional[Union[str, Path]] = None,
858-
architecture: Optional[str] = None,
859-
model_kwargs: Optional[Dict[str, Union[int, float, str]]] = None,
860-
tensorflow_version: Optional[str] = None,
861-
opset_version: Optional[str] = None,
862-
**weight_kwargs,
863-
):
864-
"""Add weight entry to bioimage.io model."""
865-
# we need to pass the weight path as abs path to avoid confusion with different root directories
866-
new_weights, tmp_arch = _get_weights(
867-
Path(weight_uri).absolute(),
868-
weight_type,
869-
root=Path("."),
870-
architecture=architecture,
871-
model_kwargs=model_kwargs,
872-
tensorflow_version=tensorflow_version,
873-
opset_version=opset_version,
874-
**weight_kwargs,
875-
)
876-
model.weights.update(new_weights)
877-
if output_path is not None:
878-
model_package = export_resource_package(model, output_path=output_path)
879-
model = load_raw_resource_description(model_package)
880-
if tmp_arch is not None:
881-
os.remove(tmp_arch)
882-
return model
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
from bioimageio.core import export_resource_package, load_raw_resource_description, load_resource_description
2+
3+
4+
def _test_add_weights(model, tmp_path, base_weights, added_weights, **kwargs):
5+
from bioimageio.core.build_spec import add_weights
6+
7+
rdf = load_raw_resource_description(model)
8+
assert base_weights in rdf.weights
9+
assert added_weights in rdf.weights
10+
11+
weight_path = load_resource_description(model).weights[added_weights].source
12+
assert weight_path.exists()
13+
14+
drop_weights = set(rdf.weights.keys()) - {base_weights}
15+
for drop in drop_weights:
16+
rdf.weights.pop(drop)
17+
assert tuple(rdf.weights.keys()) == (base_weights,)
18+
19+
in_path = tmp_path / "model1.zip"
20+
export_resource_package(rdf, output_path=in_path)
21+
22+
out_path = tmp_path / "model2.zip"
23+
add_weights(in_path, weight_path, weight_type=added_weights, output_path=out_path, **kwargs)
24+
25+
assert out_path.exists()
26+
new_rdf = load_resource_description(out_path)
27+
assert set(new_rdf.weights.keys()) == {base_weights, added_weights}
28+
for weight in new_rdf.weights.values():
29+
assert weight.source.exists()
30+
31+
32+
def test_add_torchscript(unet2d_nuclei_broad_model, tmp_path):
33+
_test_add_weights(unet2d_nuclei_broad_model, tmp_path, "pytorch_state_dict", "torchscript")
34+
35+
36+
def test_add_onnx(unet2d_nuclei_broad_model, tmp_path):
37+
_test_add_weights(unet2d_nuclei_broad_model, tmp_path, "pytorch_state_dict", "onnx", opset_version=12)

0 commit comments

Comments
 (0)