Skip to content

Commit 6e7bab5

Browse files
committed
fix io names
1 parent e90c621 commit 6e7bab5

File tree

1 file changed

+16
-8
lines changed

1 file changed

+16
-8
lines changed

bioimageio/core/build_spec/build_model.py

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -445,12 +445,12 @@ def build_model(
445445
sample_inputs: Optional[List[str]] = None,
446446
sample_outputs: Optional[List[str]] = None,
447447
# tensor specific
448-
input_name: Optional[List[str]] = None,
448+
input_names: Optional[List[str]] = None,
449449
input_step: Optional[List[List[int]]] = None,
450450
input_min_shape: Optional[List[List[int]]] = None,
451451
input_axes: Optional[List[str]] = None,
452452
input_data_range: Optional[List[List[Union[int, str]]]] = None,
453-
output_name: Optional[List[str]] = None,
453+
output_names: Optional[List[str]] = None,
454454
output_reference: Optional[List[str]] = None,
455455
output_scale: Optional[List[List[int]]] = None,
456456
output_offset: Optional[List[List[int]]] = None,
@@ -514,12 +514,12 @@ def build_model(
514514
weight_type: the type of the weights.
515515
sample_inputs: list of sample inputs to demonstrate the model performance.
516516
sample_outputs: list of sample outputs corresponding to sample_inputs.
517-
input_name: name of the input tensor.
517+
input_names: names of the input tensors.
518518
input_step: minimal valid increase of the input tensor shape.
519519
input_min_shape: minimal input tensor shape.
520520
input_axes: axes names for the input tensor.
521521
input_data_range: valid data range for the input tensor.
522-
output_name: name of the output tensor.
522+
output_names: names of the output tensors.
523523
output_reference: name of the input reference tensor used to cimpute the output tensor shape.
524524
output_scale: multiplicative factor to compute the output tensor shape.
525525
output_offset: additive term to compute the output tensor shape.
@@ -555,7 +555,11 @@ def build_model(
555555
test_outputs = _ensure_local_or_url(test_outputs, root)
556556

557557
n_inputs = len(test_inputs)
558-
input_name = n_inputs * [None] if input_name is None else input_name
558+
if input_names is None:
559+
input_names = [f"input{i}" for i in range(n_inputs)]
560+
else:
561+
assert len(input_names) == len(test_inputs)
562+
559563
input_step = n_inputs * [None] if input_step is None else input_step
560564
input_min_shape = n_inputs * [None] if input_min_shape is None else input_min_shape
561565
input_axes = n_inputs * [None] if input_axes is None else input_axes
@@ -565,12 +569,16 @@ def build_model(
565569
inputs = [
566570
_get_input_tensor(root / test_in, name, step, min_shape, data_range, axes, preproc)
567571
for test_in, name, step, min_shape, axes, data_range, preproc in zip(
568-
test_inputs, input_name, input_step, input_min_shape, input_axes, input_data_range, preprocessing
572+
test_inputs, input_names, input_step, input_min_shape, input_axes, input_data_range, preprocessing
569573
)
570574
]
571575

572576
n_outputs = len(test_outputs)
573-
output_name = n_outputs * [None] if output_name is None else output_name
577+
if output_names is None:
578+
output_names = [f"output{i}" for i in range(n_outputs)]
579+
else:
580+
assert len(output_names) == len(test_outputs)
581+
574582
output_reference = n_outputs * [None] if output_reference is None else output_reference
575583
output_scale = n_outputs * [None] if output_scale is None else output_scale
576584
output_offset = n_outputs * [None] if output_offset is None else output_offset
@@ -583,7 +591,7 @@ def build_model(
583591
_get_output_tensor(root / test_out, name, reference, scale, offset, axes, data_range, postproc, hal)
584592
for test_out, name, reference, scale, offset, axes, data_range, postproc, hal in zip(
585593
test_outputs,
586-
output_name,
594+
output_names,
587595
output_reference,
588596
output_scale,
589597
output_offset,

0 commit comments

Comments
 (0)