Skip to content

Commit df08df7

Browse files
Rename and expose shape check functions
1 parent 605a568 commit df08df7

File tree

2 files changed

+6
-5
lines changed

2 files changed

+6
-5
lines changed

bioimageio/core/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,4 +11,5 @@
1111
serialize_raw_resource_description,
1212
)
1313
from .prediction_pipeline import create_prediction_pipeline
14-
from .prediction import predict_image, predict_images
14+
from .prediction import predict_image, predict_images, predict_with_padding, predict_with_tiling
15+
from .resource_tests import check_input_shape, check_output_shape, test_resource

bioimageio/core/resource_tests.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ def test_model(
6060
)
6161

6262

63-
def _validate_input_shape(shape: Tuple[int, ...], shape_spec) -> bool:
63+
def check_input_shape(shape: Tuple[int, ...], shape_spec) -> bool:
6464
if isinstance(shape_spec, list):
6565
if shape != tuple(shape_spec):
6666
return False
@@ -81,7 +81,7 @@ def _validate_input_shape(shape: Tuple[int, ...], shape_spec) -> bool:
8181
return True
8282

8383

84-
def _validate_output_shape(shape: Tuple[int, ...], shape_spec, input_shapes) -> bool:
84+
def check_output_shape(shape: Tuple[int, ...], shape_spec, input_shapes) -> bool:
8585
if isinstance(shape_spec, list):
8686
return shape == tuple(shape_spec)
8787
elif isinstance(shape_spec, ImplicitOutputShape):
@@ -129,7 +129,7 @@ def test_resource(
129129
assert len(inputs) == len(model.inputs) # should be checked by validation
130130
input_shapes = {}
131131
for idx, (ipt, ipt_spec) in enumerate(zip(inputs, model.inputs)):
132-
if not _validate_input_shape(tuple(ipt.shape), ipt_spec.shape):
132+
if not check_input_shape(tuple(ipt.shape), ipt_spec.shape):
133133
raise ValidationError(
134134
f"Shape {tuple(ipt.shape)} of test input {idx} '{ipt_spec.name}' does not match "
135135
f"input shape description: {ipt_spec.shape}."
@@ -138,7 +138,7 @@ def test_resource(
138138

139139
assert len(expected) == len(model.outputs) # should be checked by validation
140140
for idx, (out, out_spec) in enumerate(zip(expected, model.outputs)):
141-
if not _validate_output_shape(tuple(out.shape), out_spec.shape, input_shapes):
141+
if not check_output_shape(tuple(out.shape), out_spec.shape, input_shapes):
142142
error = (error or "") + (
143143
f"Shape {tuple(out.shape)} of test output {idx} '{out_spec.name}' does not match "
144144
f"output shape description: {out_spec.shape}."

0 commit comments

Comments
 (0)