Skip to content

Commit 6d1ae85

Browse files
Merge pull request #179 from bioimage-io/covers
Add cover generation and default license to build_model, make input/o…
2 parents 1eb7b86 + f2a4da5 commit 6d1ae85

File tree

5 files changed

+246
-63
lines changed

5 files changed

+246
-63
lines changed
Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1,2 @@
1-
from .build_model import add_weights, build_model
1+
from .add_weights import add_weights
2+
from .build_model import build_model
Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
import os
2+
from pathlib import Path
3+
from shutil import copyfile
4+
from typing import Dict, Optional, Union
5+
6+
from bioimageio.core import export_resource_package, load_raw_resource_description
7+
from bioimageio.spec.shared.raw_nodes import ResourceDescription as RawResourceDescription
8+
from .build_model import _get_weights
9+
10+
11+
def add_weights(
12+
model: Union[RawResourceDescription, os.PathLike, str],
13+
weight_uri: Union[str, Path],
14+
output_path: Union[str, Path],
15+
*,
16+
weight_type: Optional[str] = None,
17+
architecture: Optional[str] = None,
18+
model_kwargs: Optional[Dict[str, Union[int, float, str]]] = None,
19+
tensorflow_version: Optional[str] = None,
20+
opset_version: Optional[str] = None,
21+
**weight_kwargs,
22+
):
23+
"""Add weight entry to bioimage.io model.
24+
25+
Args:
26+
model: the resource description of the model to which the weight format is added
27+
weight_uri: the weight file to be added
28+
output_path: where to serialize the new model with additional weight format
29+
weight_type: the format of the weights to be added
30+
architecture: the file with the source code for the model architecture and the corresponding class.
31+
Only required for models with pytorch_state_dict weight format.
32+
model_kwargs: the keyword arguments for the model class.
33+
Only required for models with pytorch_state_dict weight format.
34+
tensorflow_version: the tensorflow version used for training the model.
35+
Only requred for models with tensorflow or keras weight format.
36+
opset_version: the opset version used in this model.
37+
Only requred for models with onnx weight format.
38+
weight_kwargs: additional keyword arguments for the weight.
39+
"""
40+
model = load_raw_resource_description(model)
41+
42+
# copy the weight path to the input model's root, otherwise it will
43+
# not be found when packaging the new model
44+
weight_out = os.path.join(model.root_path, Path(weight_uri).name)
45+
if Path(weight_out) != Path(weight_uri):
46+
copyfile(weight_uri, weight_out)
47+
48+
new_weights, tmp_arch = _get_weights(
49+
weight_out,
50+
weight_type,
51+
root=Path("."),
52+
architecture=architecture,
53+
model_kwargs=model_kwargs,
54+
tensorflow_version=tensorflow_version,
55+
opset_version=opset_version,
56+
**weight_kwargs,
57+
)
58+
model.weights.update(new_weights)
59+
60+
try:
61+
model_package = export_resource_package(model, output_path=output_path)
62+
model = load_raw_resource_description(model_package)
63+
except Exception as e:
64+
raise e
65+
finally:
66+
if tmp_arch is not None:
67+
os.remove(tmp_arch)
68+
return model

0 commit comments

Comments
 (0)