|
14 | 14 |
|
15 | 15 |
|
16 | 16 | def transform_input_image(image: np.ndarray, tensor_axes: str, image_axes: Optional[str] = None): |
17 | | - """Transform input image to adhere to the axes spec defined by a bioimage.io model. |
| 17 | + """Transform input image into output tensor with desired axes. |
18 | 18 |
|
19 | 19 | Args: |
20 | 20 | image: the input image |
@@ -51,22 +51,21 @@ def _drop_axis_default(axis_name, axis_len): |
51 | 51 | return axis_len // 2 if axis_name in "zyx" else 0 |
52 | 52 |
|
53 | 53 |
|
54 | | -def transform_output_image(tensor: np.ndarray, spec, output_axes: str, drop_function=_drop_axis_default): |
55 | | - """Transform output tensor to image with the desired axes. |
| 54 | +def transform_output_tensor(tensor: np.ndarray, tensor_axes: str, output_axes: str, drop_function=_drop_axis_default): |
| 55 | + """Transform output tensor into image with desired axes. |
56 | 56 |
|
57 | 57 | Args: |
58 | | - tensor the output tensor |
59 | | - spec: bioimageio model spec |
| 58 | + tensor: the output tensor |
| 59 | + tensor_axes: bioimageio model spec |
60 | 60 | output_axes: the desired output axes |
61 | 61 | drop_function: function that determines how to drop unwanted axes |
62 | 62 | """ |
63 | | - axes = spec["axes"] |
64 | | - shape = {ax_name: sh for ax_name, sh in zip(axes, tensor.shape)} |
65 | | - if len(axes) != tensor.ndim: |
66 | | - raise ValueError(f"Number of axes {len(axes)} and dimension of tensor {tensor.ndim} don't match") |
67 | | - output = DataArray(tensor, dims=tuple(axes)) |
| 63 | + if len(tensor_axes) != tensor.ndim: |
| 64 | + raise ValueError(f"Number of axes {len(tensor_axes)} and dimension of tensor {tensor.ndim} don't match") |
| 65 | + shape = {ax_name: sh for ax_name, sh in zip(tensor_axes, tensor.shape)} |
| 66 | + output = DataArray(tensor, dims=tuple(tensor_axes)) |
68 | 67 | # drop unwanted axes |
69 | | - drop_axis_names = tuple(set(axes) - set(output_axes)) |
| 68 | + drop_axis_names = tuple(set(tensor_axes) - set(output_axes)) |
70 | 69 | drop_axes = {ax_name: drop_function(ax_name, shape[ax_name]) for ax_name in drop_axis_names} |
71 | 70 | output = output[drop_axes] |
72 | 71 | # transpose to the desired axis order |
|
0 commit comments