Skip to content

Commit 2d5768e

Browse files
committed
fix generic channel_names in ChannelAxis
1 parent 8a441de commit 2d5768e

File tree

1 file changed

+51
-24
lines changed

1 file changed

+51
-24
lines changed

bioimageio/spec/model/v0_5.py

Lines changed: 51 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -21,16 +21,16 @@
2121
from bioimageio.spec._internal.types import DeprecatedLicenseId as DeprecatedLicenseId
2222
from bioimageio.spec._internal.types import FileSource as FileSource
2323
from bioimageio.spec._internal.types import Identifier as Identifier
24-
from bioimageio.spec._internal.types import IdentifierStr, LowerCaseIdentifierStr
2524
from bioimageio.spec._internal.types import LicenseId as LicenseId
25+
from bioimageio.spec._internal.types import LowerCaseIdentifierStr
2626
from bioimageio.spec._internal.types import NotEmpty as NotEmpty
2727
from bioimageio.spec._internal.types import RdfContent as RdfContent
2828
from bioimageio.spec._internal.types import RelativeFilePath as RelativeFilePath
2929
from bioimageio.spec._internal.types import Sha256 as Sha256
3030
from bioimageio.spec._internal.types import Unit as Unit
3131
from bioimageio.spec._internal.types import Version as Version
3232
from bioimageio.spec._internal.types.field_validation import AfterValidator
33-
from bioimageio.spec._internal.validation_context import InternalValidationContext
33+
from bioimageio.spec._internal.validation_context import InternalValidationContext, get_internal_validation_context
3434
from bioimageio.spec.dataset.v0_3 import Dataset as Dataset
3535
from bioimageio.spec.dataset.v0_3 import LinkedDataset as LinkedDataset
3636
from bioimageio.spec.generic.v0_3 import Attachment as Attachment
@@ -185,6 +185,20 @@ class AxisBase(NodeWithExplicitlySetFields, frozen=True):
185185
id: AxisId
186186
"""An axis id unique across all axes of one tensor."""
187187

188+
@model_validator(mode="before")
189+
@classmethod
190+
def convert_name_to_id(cls, data: Dict[str, Any], info: ValidationInfo):
191+
context = get_internal_validation_context(info.context)
192+
if (
193+
"name" in data
194+
and "id" not in data
195+
and "original_format" in context
196+
and context["original_format"].release[:2] == (0, 4)
197+
):
198+
data["id"] = data.pop("name")
199+
200+
return data
201+
188202
description: Annotated[str, MaxLen(128)] = ""
189203

190204
__hash__ = NodeWithExplicitlySetFields.__hash__
@@ -205,20 +219,23 @@ class BatchAxis(AxisBase, frozen=True):
205219
otherwise (the default) it may be chosen arbitrarily depending on available memory"""
206220

207221

208-
ChannelName = Annotated[IdentifierStr, StringConstraints(min_length=3, max_length=16, pattern=r"^.*\{i\}.*$")]
222+
GenericChannelName = Annotated[str, StringConstraints(min_length=3, max_length=16, pattern=r"^.*\{i\}.*$")]
209223

210224

211225
class ChannelAxis(AxisBase, frozen=True):
212226
type: Literal["channel"] = "channel"
213227
id: AxisId = AxisId("channel")
214-
channel_names: Union[Tuple[ChannelName, ...], ChannelName] = "channel{i}"
215-
size: Union[Annotated[int, Gt(0)], SizeReference, Literal["#channel_names"]] = "#channel_names"
228+
channel_names: Union[Tuple[Identifier, ...], Identifier, GenericChannelName] = "channel{i}"
229+
size: Union[Annotated[int, Gt(0)], SizeReference] = "#channel_names" # type: ignore
216230

217231
@model_validator(mode="before")
218232
@classmethod
219233
def set_size_or_channel_names(cls, data: Dict[str, Any]):
220234
channel_names: Union[Any, Sequence[Any]] = data.get("channel_names", "channel{i}")
221235
size = data.get("size", "#channel_names")
236+
if size == "#channel_names" and channel_names == "channel{i}":
237+
raise ValueError("Channel dimension has unknown size. Please specify `size` or `channel_names`.")
238+
222239
if (
223240
size == "#channel_names"
224241
and not isinstance(channel_names, str)
@@ -231,13 +248,6 @@ def set_size_or_channel_names(cls, data: Dict[str, Any]):
231248

232249
return data
233250

234-
@model_validator(mode="after")
235-
def validate_size_is_known(self):
236-
if self.size == "#channel_names":
237-
raise ValueError("Channel dimension has unknown size. Please specify `size` or `channel_names`.")
238-
239-
return self
240-
241251

242252
class IndexTimeSpaceAxisBase(AxisBase, frozen=True):
243253
size: Annotated[
@@ -948,14 +958,6 @@ def check_entries(self, info: ValidationInfo) -> Self:
948958
return self
949959

950960

951-
# def get_default_partial_inputs():
952-
# return (
953-
# InputTensor(axes=(BatchAxis(),), test_tensor=HttpUrl("https://example.com/test.npy")).model_dump(
954-
# exclude_unset=False, exclude={"axes", "test_tensor"}
955-
# ),
956-
# )
957-
958-
959961
class Model(GenericBaseNoSource, frozen=True, title="bioimage.io model specification"):
960962
"""Specification of the fields used in a bioimage.io-compliant RDF to describe AI models with pretrained weights.
961963
These fields are typically stored in a YAML file which we call a model resource description file (model RDF).
@@ -1070,8 +1072,21 @@ def _validate_axis(
10701072
f"Self-referencing not allowed for {field_name}[{i}].axes[{a}].size.reference: "
10711073
f"{axis.size.reference}"
10721074
)
1073-
if axis.type == "channel" and valid_independent_refs[axis.size.reference][1].type != "channel":
1074-
raise ValueError("A channel axis' size may only reference another fixed size channel axis.")
1075+
if axis.type == "channel":
1076+
if valid_independent_refs[axis.size.reference][1].type != "channel":
1077+
raise ValueError("A channel axis' size may only reference another fixed size channel axis.")
1078+
if isinstance(axis.channel_names, str) and "{i}" in axis.channel_names:
1079+
ref_size = valid_independent_refs[axis.size.reference][2]
1080+
assert isinstance(
1081+
ref_size, int
1082+
), "channel axis ref (another channel axis) has to specify fixed size"
1083+
generated_channel_names = tuple(
1084+
Identifier(axis.channel_names.format(i=i)) for i in range(1, ref_size + 1)
1085+
)
1086+
axis.model_config["frozen"] = False
1087+
axis.channel_names = generated_channel_names # type: ignore
1088+
axis.model_config["frozen"] = True
1089+
10751090
if (ax_unit := getattr(axis, "unit", None)) != (
10761091
ref_unit := getattr(valid_independent_refs[axis.size.reference][1], "unit", None)
10771092
):
@@ -1090,8 +1105,20 @@ def _validate_axis(
10901105
raise ValueError(f"Invalid tensor axis reference at {field_name}[{i}].axes[{a}].size: {axis.size}.")
10911106
if axis.size in (axis.id, f"{tensor_id}.{axis.id}"):
10921107
raise ValueError(f"Self-referencing not allowed for {field_name}[{i}].axes[{a}].size: {axis.size}.")
1093-
if axis.type == "channel" and valid_independent_refs[axis.size][1].type != "channel":
1094-
raise ValueError("A channel axis' size may only reference another fixed size channel axis.")
1108+
if axis.type == "channel":
1109+
if valid_independent_refs[axis.size][1].type != "channel":
1110+
raise ValueError("A channel axis' size may only reference another fixed size channel axis.")
1111+
if isinstance(axis.channel_names, str) and "{i}" in axis.channel_names:
1112+
ref_size = valid_independent_refs[axis.size][2]
1113+
assert isinstance(
1114+
ref_size, int
1115+
), "channel axis ref (another channel axis) has to specify fixed size"
1116+
generated_channel_names = tuple(
1117+
Identifier(axis.channel_names.format(i=i)) for i in range(1, ref_size + 1)
1118+
)
1119+
axis.model_config["frozen"] = False
1120+
axis.channel_names = generated_channel_names # type: ignore
1121+
axis.model_config["frozen"] = True
10951122

10961123
license: Annotated[
10971124
Union[LicenseId, DeprecatedLicenseId],

0 commit comments

Comments
 (0)