Skip to content

Commit b82ecac

Browse files
committed
fix ChannelAxis defaults
1 parent 4792992 commit b82ecac

File tree

1 file changed

+16
-13
lines changed

1 file changed

+16
-13
lines changed

bioimageio/spec/model/v0_5.py

Lines changed: 16 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
from typing import Any, ClassVar, Dict, FrozenSet, List, Literal, NewType, Optional, Set, Tuple, Union
1+
import collections.abc
2+
from typing import Any, ClassVar, Dict, FrozenSet, List, Literal, NewType, Optional, Sequence, Set, Tuple, Union
23

34
from annotated_types import Ge, Gt, Interval, MaxLen, MinLen, Predicate
45
from pydantic import (
@@ -204,7 +205,6 @@ class BatchAxis(AxisBase, frozen=True):
204205
otherwise (the default) it may be chosen arbitrarily depending on available memory"""
205206

206207

207-
CHANNEL_NAMES_PLACEHOLDER = ("channel1", "channel2", "etc")
208208
ChannelName = Annotated[IdentifierStr, StringConstraints(min_length=3, max_length=16, pattern=r"^.*\{i\}.*$")]
209209

210210

@@ -214,19 +214,22 @@ class ChannelAxis(AxisBase, frozen=True):
214214
channel_names: Union[Tuple[ChannelName, ...], ChannelName] = "channel{i}"
215215
size: Union[Annotated[int, Gt(0)], SizeReference, Literal["#channel_names"]] = "#channel_names"
216216

217-
def model_post_init(self, __context: Any):
218-
self.model_config["frozen"] = False
219-
if self.size == "#channel_names":
220-
self.size = len(self.channel_names) # type: ignore
221-
self.__pydantic_fields_set__.remove("size")
217+
@model_validator(mode="before")
218+
@classmethod
219+
def set_size_or_channel_names(cls, data: Dict[str, Any]):
220+
channel_names: Union[Any, Sequence[Any]] = data.get("channel_names", "channel{i}")
221+
size = data.get("size", "#channel_names")
222+
if (
223+
size == "#channel_names"
224+
and not isinstance(channel_names, str)
225+
and isinstance(channel_names, collections.abc.Sequence)
226+
):
227+
data["size"] = len(channel_names)
222228

223-
if self.channel_names == CHANNEL_NAMES_PLACEHOLDER:
224-
assert isinstance(self.size, int)
225-
self.channel_names = tuple(f"channel{i}" for i in range(1, self.size + 1)) # type: ignore
226-
self.__pydantic_fields_set__.remove("channel_names")
229+
if isinstance(channel_names, str) and "{i}" in channel_names and isinstance(size, int):
230+
data["channel_names"] = tuple(channel_names.format(i=i) for i in range(1, size + 1))
227231

228-
self.model_config["frozen"] = True
229-
return super().model_post_init(__context)
232+
return data
230233

231234
@model_validator(mode="after")
232235
def validate_size_is_known(self):

0 commit comments

Comments
 (0)