33from dataclasses import dataclass
44from typing import Literal , Mapping , Optional , TypeVar , Union
55
6- from typing_extensions import assert_never
7-
86from bioimageio .spec .model import v0_5
7+ from typing_extensions import Protocol , assert_never , runtime_checkable
98
109
1110def _guess_axis_type (a : str ):
@@ -42,7 +41,16 @@ def _guess_axis_type(a: str):
4241BatchSize = int
4342
4443AxisLetter = Literal ["b" , "i" , "t" , "c" , "z" , "y" , "x" ]
45- AxisLike = Union [AxisId , AxisLetter , v0_5 .AnyAxis , "Axis" ]
44+ _AxisLikePlain = Union [str , AxisId , AxisLetter ]
45+
46+
47+ @runtime_checkable
48+ class AxisDescrLike (Protocol ):
49+ id : _AxisLikePlain
50+ type : Literal ["batch" , "channel" , "index" , "space" , "time" ]
51+
52+
53+ AxisLike = Union [_AxisLikePlain , "Axis" , v0_5 .AnyAxis , AxisDescrLike ]
4654
4755
4856@dataclass
@@ -60,14 +68,22 @@ def __post_init__(self):
6068 def create (cls , axis : AxisLike ) -> Axis :
6169 if isinstance (axis , cls ):
6270 return axis
63- elif isinstance (axis , Axis ):
64- return Axis (id = axis .id , type = axis .type )
65- elif isinstance (axis , v0_5 .AxisBase ):
66- return Axis (id = AxisId (axis .id ), type = axis .type )
67- elif isinstance (axis , str ):
68- return Axis (id = AxisId (axis ), type = _guess_axis_type (axis ))
71+
72+ if isinstance (axis , (AxisId , str )):
73+ axis_id = axis
74+ axis_type = _guess_axis_type (str (axis ))
6975 else :
70- assert_never (axis )
76+ if hasattr (axis , "type" ):
77+ axis_type = axis .type
78+ else :
79+ axis_type = _guess_axis_type (str (axis ))
80+
81+ if hasattr (axis , "id" ):
82+ axis_id = axis .id
83+ else :
84+ axis_id = axis
85+
86+ return Axis (id = AxisId (axis_id ), type = axis_type )
7187
7288
7389@dataclass
@@ -81,7 +97,7 @@ def create(cls, axis: AxisLike, maybe_singleton: Optional[bool] = None) -> AxisI
8197
8298 axis_base = super ().create (axis )
8399 if maybe_singleton is None :
84- if isinstance (axis , ( Axis , str ) ):
100+ if not isinstance (axis , v0_5 . AxisBase ):
85101 maybe_singleton = True
86102 else :
87103 if axis .size is None :
0 commit comments