Skip to content

Commit 4b9c80f

Browse files
Add training_data to build_model
1 parent aa17b2a commit 4b9c80f

File tree

2 files changed

+24
-0
lines changed

2 files changed

+24
-0
lines changed

bioimageio/core/build_spec/build_model.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -636,6 +636,7 @@ def build_model(
636636
config: Optional[Dict[str, Any]] = None,
637637
dependencies: Optional[Union[Path, str]] = None,
638638
links: Optional[List[str]] = None,
639+
training_data: Optional[Dict[str, str]] = None,
639640
root: Optional[Union[Path, str]] = None,
640641
add_deepimagej_config: bool = False,
641642
tensorflow_version: Optional[str] = None,
@@ -711,6 +712,7 @@ def build_model(
711712
parent: id of the parent model from which this model is derived and sha256 of the corresponding weight file.
712713
config: custom configuration for this model.
713714
dependencies: relative path to file with dependencies for this model.
715+
training_data: the training data for this model, either id for a bioimageio dataset or a dataset spec.
714716
root: optional root path for relative paths. This can be helpful when building a spec from another model spec.
715717
add_deepimagej_config: add the deepimagej config to the model.
716718
tensorflow_version: the tensorflow version for this model. Only for tensorflow or keras weights.
@@ -887,10 +889,14 @@ def build_model(
887889

888890
if maintainers is not None:
889891
kwargs["maintainers"] = [model_spec.raw_nodes.Maintainer(**m) for m in maintainers]
892+
890893
if parent is not None:
891894
assert len(parent) == 2
892895
kwargs["parent"] = {"uri": parent[0], "sha256": parent[1]}
893896

897+
if training_data is not None:
898+
kwargs["training_data"] = training_data
899+
894900
try:
895901
model = model_spec.raw_nodes.Model(
896902
authors=authors,

tests/build_spec/test_build_spec.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ def _test_build_spec(
1717
add_deepimagej_config=False,
1818
use_original_covers=False,
1919
use_absoloute_arch_path=False,
20+
training_data=None,
2021
):
2122
from bioimageio.core.build_spec import build_model
2223

@@ -115,6 +116,8 @@ def _test_build_spec(
115116
kwargs["pixel_sizes"] = [{"x": 5.0, "y": 5.0}]
116117
if use_original_covers:
117118
kwargs["covers"] = resolve_source(model_spec.covers, root)
119+
if training_data is not None:
120+
kwargs["training_data"] = training_data
118121

119122
build_model(**kwargs)
120123
assert out_path.exists()
@@ -193,6 +196,21 @@ def test_build_spec_deepimagej(unet2d_nuclei_broad_model, tmp_path):
193196
_test_build_spec(unet2d_nuclei_broad_model, tmp_path / "model.zip", "torchscript", add_deepimagej_config=True)
194197

195198

199+
def test_build_spec_training_data1(unet2d_nuclei_broad_model, tmp_path):
200+
training_data = {"id": "ilastik/stradist_dsb_training_data"}
201+
_test_build_spec(unet2d_nuclei_broad_model, tmp_path / "model.zip", "torchscript", training_data=training_data)
202+
203+
204+
def test_build_spec_training_data2(unet2d_nuclei_broad_model, tmp_path):
205+
training_data = {
206+
"type": "dataset",
207+
"name": "nucleus-training-data",
208+
"description": "stardist nucleus training data",
209+
"source": "https://github.com/stardist/stardist/releases/download/0.1.0/dsb2018.zip",
210+
}
211+
_test_build_spec(unet2d_nuclei_broad_model, tmp_path / "model.zip", "torchscript", training_data=training_data)
212+
213+
196214
def test_build_spec_deepimagej_keras(unet2d_keras, tmp_path):
197215
_test_build_spec(
198216
unet2d_keras, tmp_path / "model.zip", "keras_hdf5", add_deepimagej_config=True, tensorflow_version="1.12"

0 commit comments

Comments
 (0)