Skip to content

Commit 16b308f

Browse files
Update transform_output_tensor function to be more similar to transform_input_image
1 parent 1dd6975 commit 16b308f

File tree

1 file changed

+10
-11
lines changed

1 file changed

+10
-11
lines changed

bioimageio/core/image_helper.py

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515

1616
def transform_input_image(image: np.ndarray, tensor_axes: str, image_axes: Optional[str] = None):
17-
"""Transform input image to adhere to the axes spec defined by a bioimage.io model.
17+
"""Transform input image into output tensor with desired axes.
1818
1919
Args:
2020
image: the input image
@@ -51,22 +51,21 @@ def _drop_axis_default(axis_name, axis_len):
5151
return axis_len // 2 if axis_name in "zyx" else 0
5252

5353

54-
def transform_output_image(tensor: np.ndarray, spec, output_axes: str, drop_function=_drop_axis_default):
55-
"""Transform output tensor to image with the desired axes.
54+
def transform_output_tensor(tensor: np.ndarray, tensor_axes: str, output_axes: str, drop_function=_drop_axis_default):
55+
"""Transform output tensor into image with desired axes.
5656
5757
Args:
58-
tensor the output tensor
59-
spec: bioimageio model spec
58+
tensor: the output tensor
59+
tensor_axes: bioimageio model spec
6060
output_axes: the desired output axes
6161
drop_function: function that determines how to drop unwanted axes
6262
"""
63-
axes = spec["axes"]
64-
shape = {ax_name: sh for ax_name, sh in zip(axes, tensor.shape)}
65-
if len(axes) != tensor.ndim:
66-
raise ValueError(f"Number of axes {len(axes)} and dimension of tensor {tensor.ndim} don't match")
67-
output = DataArray(tensor, dims=tuple(axes))
63+
if len(tensor_axes) != tensor.ndim:
64+
raise ValueError(f"Number of axes {len(tensor_axes)} and dimension of tensor {tensor.ndim} don't match")
65+
shape = {ax_name: sh for ax_name, sh in zip(tensor_axes, tensor.shape)}
66+
output = DataArray(tensor, dims=tuple(tensor_axes))
6867
# drop unwanted axes
69-
drop_axis_names = tuple(set(axes) - set(output_axes))
68+
drop_axis_names = tuple(set(tensor_axes) - set(output_axes))
7069
drop_axes = {ax_name: drop_function(ax_name, shape[ax_name]) for ax_name in drop_axis_names}
7170
output = output[drop_axes]
7271
# transpose to the desired axis order

0 commit comments

Comments
 (0)