Skip to content

Commit 28c09e4

Browse files
Merge pull request #117 from bioimage-io/fix-build-spec-output-shape
Update build_model for changed output shape spec; add test
2 parents e1e6641 + 5eb07e2 commit 28c09e4

File tree

2 files changed

+13
-4
lines changed

2 files changed

+13
-4
lines changed

bioimageio/core/build_spec/build_model.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -179,16 +179,16 @@ def _get_input_tensor(test_in, name, step, min_shape, data_range, axes, preproce
179179
return inputs
180180

181181

182-
def _get_output_tensor(test_out, name, reference_input, scale, offset, axes, data_range, postprocessing, halo):
182+
def _get_output_tensor(test_out, name, reference_tensor, scale, offset, axes, data_range, postprocessing, halo):
183183
shape = test_out.shape
184-
if reference_input is None:
184+
if reference_tensor is None:
185185
assert scale is None
186186
assert offset is None
187187
shape_description = shape
188188
else:
189189
assert scale is not None
190190
assert offset is not None
191-
shape_description = {"reference_input": reference_input, "scale": scale, "offset": offset}
191+
shape_description = {"reference_tensor": reference_tensor, "scale": scale, "offset": offset}
192192

193193
axes = _get_axes(axes, test_out.ndim)
194194
data_range = _get_data_range(data_range, test_out.dtype)

tests/build_spec/test_build_spec.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from bioimageio.core.resource_io.io_ import load_raw_resource_description
77

88

9-
def _test_build_spec(path, weight_type, tensorflow_version=None):
9+
def _test_build_spec(path, weight_type, tensorflow_version=None, use_implicit_output_shape=False):
1010
from bioimageio.core.build_spec import build_model
1111

1212
model_spec = load_raw_resource_description(path)
@@ -50,6 +50,11 @@ def _test_build_spec(path, weight_type, tensorflow_version=None):
5050
)
5151
if tensorflow_version is not None:
5252
kwargs["tensorflow_version"] = tensorflow_version
53+
if use_implicit_output_shape:
54+
kwargs["input_name"] = "input"
55+
kwargs["output_reference"] = "input"
56+
kwargs["output_scale"] = [1.0, 1.0, 1.0, 1.0]
57+
kwargs["output_offset"] = [0.0, 0.0, 0.0, 0.0]
5358
raw_model = build_model(**kwargs)
5459
spec.model.schema.Model().dump(raw_model)
5560

@@ -58,6 +63,10 @@ def test_build_spec_pytorch(any_torch_model):
5863
_test_build_spec(any_torch_model, "pytorch_state_dict")
5964

6065

66+
def test_build_spec_implicit_output_shape(unet2d_nuclei_broad_model):
67+
_test_build_spec(unet2d_nuclei_broad_model, "pytorch_state_dict", use_implicit_output_shape=True)
68+
69+
6170
def test_build_spec_torchscript(any_torchscript_model):
6271
_test_build_spec(any_torchscript_model, "pytorch_script")
6372

0 commit comments

Comments
 (0)