@@ -238,6 +238,7 @@ def is_long_tensor(self) -> bool:
238238
239239class _FileFeatureFlags (enum .IntFlag ):
240240 encrypted = enum .auto ()
241+ versioned_headers = enum .auto ()
241242
242243
243244@dataclasses .dataclass
@@ -310,7 +311,7 @@ def from_io(
310311 feature_flag_bytes , "little" , signed = False
311312 )
312313 feature_flags = _FileFeatureFlags (feature_flag_int )
313- if not (0 <= feature_flags <= max (_FileFeatureFlags )):
314+ if not (0 <= feature_flags < ( max (_FileFeatureFlags ) << 1 )):
314315 raise ValueError (
315316 f"Unsupported feature flags: { _FileFeatureFlags !r} "
316317 )
@@ -1838,8 +1839,17 @@ def __init__(
18381839 "Tensor is encrypted, but decryption was not requested"
18391840 )
18401841 self ._has_versioned_metadata : bool = (
1841- version_number >= LONG_TENSOR_TENSORIZER_VERSION
1842+ _FileFeatureFlags . versioned_headers in self . _file_flags
18421843 )
1844+ if (
1845+ self ._has_versioned_metadata
1846+ and version_number < LONG_TENSOR_TENSORIZER_VERSION
1847+ ):
1848+ raise ValueError (
1849+ "Invalid feature flag present in file header for a file"
1850+ f" with version { version_number :d} "
1851+ f" (flags: { self ._file_flags !s} )"
1852+ )
18431853
18441854 # The total size of the file.
18451855 # WARNING: this is not accurate. This field isn't used in the
@@ -3904,6 +3914,10 @@ def _synchronize_metadata(self):
39043914 if pos + total_length > self ._metadata_end :
39053915 raise RuntimeError ("Metadata overflow" )
39063916 self ._pwrite_bulk (buffers , pos , total_length )
3917+ if self ._metadata_handler .is_versioned :
3918+ self ._file_header .feature_flags |= (
3919+ _FileFeatureFlags .versioned_headers
3920+ )
39073921
39083922 def _pwrite_bulk (
39093923 self , buffers : Sequence [bytes ], offset : int , expected_length : int
@@ -3937,7 +3951,7 @@ class _MetadataHandler:
39373951 metadata entries need to be rewritten because of other entries written
39383952 in subsequent batches.
39393953
3940- Internally, it functions like a state machine .
3954+ Internally, it has two states .
39413955
39423956 In its initial state, it tracks pending metadata entries to be written,
39433957 as well as past metadata entries that were already written. It stays
@@ -3948,55 +3962,32 @@ class _MetadataHandler:
39483962 Once it is given any tensor using a metadata scheme newer than V1,
39493963 it transitions to its second state. In this state, all
39503964 previously-written metadata entries are moved back into a pending state,
3951- and version tags are prepended to every entry. It stays in this state
3952- until the next write operation (i.e. the next call to ``commit()``),
3953- after which it moves into its final state.
3954-
3955- In its final state, no more history is saved for previously-written
3956- metadata entries, as historical entries will at this point never again
3957- need to be rewritten. Version tags continue to be prepended
3958- to new entries. It remains in this state forever.
3965+ and version tags are prepended to every entry. No more history is saved
3966+ for newly-written metadata entries, as historical entries will at
3967+ this point never again need to be rewritten.
39593968 """
39603969
3961- __slots__ = ("pending" , "past" , "version" , "_pos" , "_state " )
3970+ __slots__ = ("pending" , "past" , "version" , "_pos" , "_is_updated " )
39623971 pending : list
39633972 past : list
39643973 version : int
39653974 _pos : int
39663975
3967- class _MetadataHandlerState (enum .Enum ):
3968- TRACKING_PAST = 1
3969- STAGING_PAST = 2
3970- NO_PAST = 3
3971-
3972- _state : _MetadataHandlerState
39733976 V1_TAG : ClassVar [bytes ] = b"\x01 \x00 \x00 \x00 "
39743977
3975- @property
3976- def _tracking_past (self ) -> bool :
3977- return self ._state is self ._MetadataHandlerState .TRACKING_PAST
3978-
3979- @property
3980- def _staging_past (self ) -> bool :
3981- return self ._state is self ._MetadataHandlerState .STAGING_PAST
3982-
3983- @property
3984- def _no_past (self ) -> bool :
3985- return self ._state is self ._MetadataHandlerState .NO_PAST
3986-
39873978 def __init__ (self ):
39883979 self .pending = []
39893980 self .past = []
39903981 self .version = 1
39913982 self ._pos = 0
3992- self ._state = self . _MetadataHandlerState . TRACKING_PAST
3983+ self ._is_updated = False
39933984
39943985 def submit (self , metadata : bytes , version : int ):
39953986 if version > self .version :
39963987 if self .version == 1 :
39973988 self ._update ()
39983989 self .version = version
3999- if not self ._tracking_past :
3990+ if self ._is_updated :
40003991 self .pending .append (version .to_bytes (4 , byteorder = "little" ))
40013992 self .pending .append (metadata )
40023993
@@ -4005,10 +3996,8 @@ def commit(self):
40053996 # Successive write positions are not a monotone sequence
40063997 pending = self .pending
40073998 self .pending = []
4008- if self ._tracking_past :
3999+ if not self ._is_updated :
40094000 self .past .extend (pending )
4010- elif self ._staging_past :
4011- self ._state = self ._MetadataHandlerState .NO_PAST
40124001 total_length = sum (len (d ) for d in pending )
40134002 pos = self ._pos
40144003 self ._pos += total_length
@@ -4017,7 +4006,7 @@ def commit(self):
40174006 def _update (self ):
40184007 # This is only called the one time that self.version is updated
40194008 # up from 1, so this should always be in the initial state
4020- assert self ._tracking_past
4009+ assert not self ._is_updated
40214010 # At the time this is called, everything in self.past and
40224011 # self.pending must be version 1, so no complicated checking is
40234012 # needed to figure out what needs to be tagged with a v1 tag
@@ -4029,7 +4018,11 @@ def _update(self):
40294018 self .pending = pending
40304019 self .past .clear ()
40314020 self ._pos = 0
4032- self ._state = self ._MetadataHandlerState .STAGING_PAST
4021+ self ._is_updated = True
4022+
4023+ @property
4024+ def is_versioned (self ) -> bool :
4025+ return self .version > 1
40334026
40344027 def write_tensor (
40354028 self ,
0 commit comments