1- from typing import Any , ClassVar , Dict , FrozenSet , List , Literal , NewType , Optional , Set , Tuple , Union
1+ import collections .abc
2+ from typing import Any , ClassVar , Dict , FrozenSet , List , Literal , NewType , Optional , Sequence , Set , Tuple , Union
23
34from annotated_types import Ge , Gt , Interval , MaxLen , MinLen , Predicate
45from pydantic import (
@@ -204,7 +205,6 @@ class BatchAxis(AxisBase, frozen=True):
204205 otherwise (the default) it may be chosen arbitrarily depending on available memory"""
205206
206207
207- CHANNEL_NAMES_PLACEHOLDER = ("channel1" , "channel2" , "etc" )
208208ChannelName = Annotated [IdentifierStr , StringConstraints (min_length = 3 , max_length = 16 , pattern = r"^.*\{i\}.*$" )]
209209
210210
@@ -214,19 +214,22 @@ class ChannelAxis(AxisBase, frozen=True):
214214 channel_names : Union [Tuple [ChannelName , ...], ChannelName ] = "channel{i}"
215215 size : Union [Annotated [int , Gt (0 )], SizeReference , Literal ["#channel_names" ]] = "#channel_names"
216216
217- def model_post_init (self , __context : Any ):
218- self .model_config ["frozen" ] = False
219- if self .size == "#channel_names" :
220- self .size = len (self .channel_names ) # type: ignore
221- self .__pydantic_fields_set__ .remove ("size" )
217+ @model_validator (mode = "before" )
218+ @classmethod
219+ def set_size_or_channel_names (cls , data : Dict [str , Any ]):
220+ channel_names : Union [Any , Sequence [Any ]] = data .get ("channel_names" , "channel{i}" )
221+ size = data .get ("size" , "#channel_names" )
222+ if (
223+ size == "#channel_names"
224+ and not isinstance (channel_names , str )
225+ and isinstance (channel_names , collections .abc .Sequence )
226+ ):
227+ data ["size" ] = len (channel_names )
222228
223- if self .channel_names == CHANNEL_NAMES_PLACEHOLDER :
224- assert isinstance (self .size , int )
225- self .channel_names = tuple (f"channel{ i } " for i in range (1 , self .size + 1 )) # type: ignore
226- self .__pydantic_fields_set__ .remove ("channel_names" )
229+ if isinstance (channel_names , str ) and "{i}" in channel_names and isinstance (size , int ):
230+ data ["channel_names" ] = tuple (channel_names .format (i = i ) for i in range (1 , size + 1 ))
227231
228- self .model_config ["frozen" ] = True
229- return super ().model_post_init (__context )
232+ return data
230233
231234 @model_validator (mode = "after" )
232235 def validate_size_is_known (self ):
0 commit comments