Skip to content

Commit c4482fa

Browse files
committed
create Axis objects also from AxisDescr like objects
1 parent 0142135 commit c4482fa

File tree

2 files changed

+34
-17
lines changed

2 files changed

+34
-17
lines changed

src/bioimageio/core/axis.py

Lines changed: 27 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,8 @@
33
from dataclasses import dataclass
44
from typing import Literal, Mapping, Optional, TypeVar, Union
55

6-
from typing_extensions import assert_never
7-
86
from bioimageio.spec.model import v0_5
7+
from typing_extensions import Protocol, assert_never, runtime_checkable
98

109

1110
def _guess_axis_type(a: str):
@@ -42,7 +41,16 @@ def _guess_axis_type(a: str):
4241
BatchSize = int
4342

4443
AxisLetter = Literal["b", "i", "t", "c", "z", "y", "x"]
45-
AxisLike = Union[AxisId, AxisLetter, v0_5.AnyAxis, "Axis"]
44+
_AxisLikePlain = Union[str, AxisId, AxisLetter]
45+
46+
47+
@runtime_checkable
48+
class AxisDescrLike(Protocol):
49+
id: _AxisLikePlain
50+
type: Literal["batch", "channel", "index", "space", "time"]
51+
52+
53+
AxisLike = Union[_AxisLikePlain, "Axis", v0_5.AnyAxis, AxisDescrLike]
4654

4755

4856
@dataclass
@@ -60,14 +68,22 @@ def __post_init__(self):
6068
def create(cls, axis: AxisLike) -> Axis:
6169
if isinstance(axis, cls):
6270
return axis
63-
elif isinstance(axis, Axis):
64-
return Axis(id=axis.id, type=axis.type)
65-
elif isinstance(axis, v0_5.AxisBase):
66-
return Axis(id=AxisId(axis.id), type=axis.type)
67-
elif isinstance(axis, str):
68-
return Axis(id=AxisId(axis), type=_guess_axis_type(axis))
71+
72+
if isinstance(axis, (AxisId, str)):
73+
axis_id = axis
74+
axis_type = _guess_axis_type(str(axis))
6975
else:
70-
assert_never(axis)
76+
if hasattr(axis, "type"):
77+
axis_type = axis.type
78+
else:
79+
axis_type = _guess_axis_type(str(axis))
80+
81+
if hasattr(axis, "id"):
82+
axis_id = axis.id
83+
else:
84+
axis_id = axis
85+
86+
return Axis(id=AxisId(axis_id), type=axis_type)
7187

7288

7389
@dataclass
@@ -81,7 +97,7 @@ def create(cls, axis: AxisLike, maybe_singleton: Optional[bool] = None) -> AxisI
8197

8298
axis_base = super().create(axis)
8399
if maybe_singleton is None:
84-
if isinstance(axis, (Axis, str)):
100+
if not isinstance(axis, v0_5.AxisBase):
85101
maybe_singleton = True
86102
else:
87103
if axis.size is None:

src/bioimageio/core/tensor.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,14 +19,13 @@
1919

2020
import numpy as np
2121
import xarray as xr
22+
from bioimageio.spec.model import v0_5
2223
from loguru import logger
2324
from numpy.typing import DTypeLike, NDArray
2425
from typing_extensions import Self, assert_never
2526

26-
from bioimageio.spec.model import v0_5
27-
2827
from ._magic_tensor_ops import MagicTensorOpsMixin
29-
from .axis import Axis, AxisId, AxisInfo, AxisLike, PerAxis
28+
from .axis import Axis, AxisDescrLike, AxisId, AxisInfo, AxisLike, PerAxis
3029
from .common import (
3130
CropWhere,
3231
DTypeStr,
@@ -187,10 +186,12 @@ def from_numpy(
187186

188187
if dims is None:
189188
return cls._interprete_array_wo_known_axes(array)
190-
elif isinstance(dims, (str, Axis, v0_5.AxisBase)):
191-
dims = [dims]
189+
elif isinstance(dims, (AxisId, AxisDescrLike)):
190+
dim_seq = [dims]
191+
else:
192+
dim_seq = list(dims)
192193

193-
axis_infos = [AxisInfo.create(a) for a in dims]
194+
axis_infos = [AxisInfo.create(a) for a in dim_seq]
194195
original_shape = tuple(array.shape)
195196

196197
successful_view = _get_array_view(array, axis_infos)

0 commit comments

Comments
 (0)