diff --git a/CHANGELOG.md b/CHANGELOG.md index 28b29ee..c0795b1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,18 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +## [Unreleased] + +### Added + +- Tensors with long dimensions (≥ `2 ** 32` elements in a single dimension) + can now be serialized and deserialized + +### Fixed + +- `tensorizer.torch_compat` can now serialize and deserialize tensors that have + storages with sizes ≥ `2 ** 32` + ## [2.11.1] - 2025-08-05 ### Fixed @@ -492,6 +504,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - `get_gpu_name` - `no_init_or_tensor` +[Unreleased]: https://github.com/coreweave/tensorizer/compare/v2.11.1...HEAD [2.11.1]: https://github.com/coreweave/tensorizer/compare/v2.11.0...v2.11.1 [2.11.0]: https://github.com/coreweave/tensorizer/compare/v2.10.1...v2.11.0 [2.10.1]: https://github.com/coreweave/tensorizer/compare/v2.10.0...v2.10.1 diff --git a/tensorizer/_version.py b/tensorizer/_version.py index 200c236..00eeb58 100644 --- a/tensorizer/_version.py +++ b/tensorizer/_version.py @@ -1 +1 @@ -__version__ = "2.11.1" +__version__ = "2.12.0a0" diff --git a/tensorizer/serialization.py b/tensorizer/serialization.py index 124f562..9da33d6 100644 --- a/tensorizer/serialization.py +++ b/tensorizer/serialization.py @@ -158,7 +158,11 @@ class TensorType(IntEnum): # Current version -TENSORIZER_VERSION = 4 +TENSORIZER_VERSION = 5 + +# To serialize tensors with individual dimensions longer than 2^32 elements, +# data version 5 is required. +LONG_TENSOR_TENSORIZER_VERSION = 5 # To serialize meta tensors into metadata-only tensors # that deserialize back into zeroed-out buffers, data version 4 is required. @@ -197,6 +201,7 @@ class TensorHash: @dataclasses.dataclass(order=True) class TensorEntry: __slots__ = ( + "metadata_version", "name", "type", "dtype", @@ -207,6 +212,7 @@ class TensorEntry: "hashes", "header_hashes", ) + metadata_version: int name: _TensorPath type: TensorType dtype: str @@ -225,6 +231,10 @@ def deserialized_length(self): num_elements: int = numpy.prod(self.shape) return element_size * num_elements + @property + def is_long_tensor(self) -> bool: + return self.metadata_version >= LONG_TENSOR_TENSORIZER_VERSION + class _FileFeatureFlags(enum.IntFlag): encrypted = enum.auto() @@ -313,6 +323,7 @@ class _TensorHeaderSerializer: # other fields are calculated per-instance buffer: bytearray size: int + has_long_dimensions: bool start_segment: ClassVar[struct.Struct] = struct.Struct( "<" # Little-endian @@ -330,7 +341,7 @@ class _TensorHeaderSerializer: "B" # Tensor dtype length "{dtype_len:d}s" # Tensor dtype UTF-8 bytes "B" # Tensor shape length - "{shape_len:d}I" # Tensor shape I array + "{shape_len:d}{shape_type:.1s}" # Tensor shape int array ) variable_length_segment: struct.Struct variable_length_offset: ClassVar[int] = start_segment.size @@ -383,7 +394,7 @@ class _TensorHeaderSerializer: "B" # Dtype length "{dtype_len:d}s" # Dtype "B" # Shape length - "{shape_len:d}I" # Shape + "{shape_len:d}{shape_type:.1s}" # Shape "Q" # Header start (relative to the file) "Q" # Tensor data start (relative to the file) "Q" # Tensor length @@ -409,6 +420,11 @@ def __init__( # NB: shape_len is the number of dimensions, # not the encoded byte length shape_len = len(shape) + short_dimension_limit: int = 1 << 32 + self.has_long_dimensions: bool = any( + dim >= short_dimension_limit for dim in shape + ) + shape_type: str = "Q" if self.has_long_dimensions else "I" self.crypt_info = crypt_info if crypt_info is None: crypt_info_len = 0 @@ -419,6 +435,7 @@ def __init__( name_len=name_len, dtype_len=dtype_len, shape_len=shape_len, + shape_type=shape_type, ) ) crc32_len = sha256_len = self.hash_count = 0 @@ -498,6 +515,7 @@ def __init__( name_len=name_len, dtype_len=dtype_len, shape_len=shape_len, + shape_type=shape_type, ) ) @@ -587,12 +605,23 @@ class _TensorHeaderDeserializer: " Optional["_TensorHeaderDeserializer"]: # We read the entire header into memory rather than reading # it piecewise to avoid the overhead of many small reads, @@ -617,7 +647,10 @@ def from_io( with memoryview(buffer) as mv: reader.readinto(mv[offset:]) return cls( - buffer, zero_hashes=zero_hashes, check_crypt_info=check_crypt_info + buffer, + zero_hashes=zero_hashes, + check_crypt_info=check_crypt_info, + long_shape_tensors=long_shape_tensors, ) def __init__( @@ -625,6 +658,7 @@ def __init__( buffer: bytearray, zero_hashes: bool = True, check_crypt_info: bool = False, + long_shape_tensors: frozenset = frozenset(), ): self.buffer = buffer offset = self.header_len_segment.size @@ -650,7 +684,11 @@ def __init__( self.dtype: str = str(dtype_slice, "utf-8") # Read the shape. - self.shape, offset = self.read_shape(buffer, offset) + long_shapes: bool = self.name in long_shape_tensors + read_shape_func: callable = ( + self.read_shape_long if long_shapes else self.read_shape_short + ) + self.shape, offset = read_shape_func(buffer, offset) # Read our hashes in. hashes_slice, offset = self.read_hash_block(buffer, offset) @@ -765,10 +803,20 @@ def _zero_hashes(b: memoryview) -> None: class _MetadataDeserializer(dict): _total_len_segment: ClassVar[struct.Struct] = struct.Struct(" Tuple["_MetadataDeserializer", _TensorPathRegistry, bytes]: raw = reader.read(cls._total_len_segment.size) total_len: int = cls._total_len_segment.unpack(raw)[0] @@ -787,24 +839,59 @@ def from_io( else: encoded_metadata: bytes = reader.read(total_len) raw += encoded_metadata - return cls.from_buffer(encoded_metadata, count) + (raw,) + return cls.from_buffer( + encoded_metadata, count, versioned, accepted_versions + ) + (raw,) @classmethod def from_buffer( - cls, buffer: bytes, count: int + cls, + buffer: bytes, + count: int, + versioned: bool, + accepted_versions: Optional[Iterable[int]], ) -> Tuple["_MetadataDeserializer", _TensorPathRegistry]: offset = 0 entries = cls() registry = _TensorPathRegistry() for i in range(count): - entry, offset = cls._read_entry(buffer, offset, registry) + entry, offset = cls._read_entry( + buffer, offset, registry, versioned, accepted_versions + ) entries[entry.name] = entry return entries, registry @classmethod def _read_entry( - cls, buffer: bytes, offset: int, registry: _TensorPathRegistry + cls, + buffer: bytes, + offset: int, + registry: _TensorPathRegistry, + versioned: bool, + accepted_versions: Optional[Iterable[int]], ) -> Tuple[TensorEntry, int]: + long_shapes: bool = False + version: int = NON_OPAQUE_TENSORIZER_VERSION + if versioned: + version = cls._version_segment.unpack_from(buffer, offset)[0] + offset += cls._version_segment.size + if version not in accepted_versions: + # This shouldn't come up in a valid file, because a newer + # metadata version implies a newer file version, so this ought + # to be rejected by the file-level version check first. + # Nonetheless, for an invalid file violating this assumption, + # give a descriptive error message + accepted_versions_str: str = ", ".join( + map(str, sorted(set(accepted_versions))) + ) + message = ( + "Unsupported version: this data stream uses tensorizer" + f" metadata version {version}, which is not supported" + " in this release of tensorizer." + f"\nSupported metadata versions: {accepted_versions_str}" + ) + raise ValueError(message) + long_shapes = version >= LONG_TENSOR_TENSORIZER_VERSION # Read the name. name_slice, offset = cls._read_name(buffer, offset) with name_slice: @@ -820,7 +907,10 @@ def _read_entry( dtype: str = str(dtype_slice, "utf-8") # Read the shape. - shape, offset = cls._read_shape(buffer, offset) + read_shape_func: callable = ( + cls._read_shape_long if long_shapes else cls._read_shape_short + ) + shape, offset = read_shape_func(buffer, offset) ( header_offset, @@ -831,6 +921,7 @@ def _read_entry( return ( TensorEntry( + metadata_version=version, name=name, type=tensor_type, dtype=dtype, @@ -1702,6 +1793,7 @@ def __init__( OPAQUE_TENSORIZER_VERSION, ENCRYPTION_TENSORIZER_VERSION, META_TENSOR_TENSORIZER_VERSION, + LONG_TENSOR_TENSORIZER_VERSION, TENSORIZER_VERSION, ) encryption_ver: int = ENCRYPTION_TENSORIZER_VERSION @@ -1745,6 +1837,9 @@ def __init__( raise CryptographyError( "Tensor is encrypted, but decryption was not requested" ) + self._has_versioned_metadata: bool = ( + version_number >= LONG_TENSOR_TENSORIZER_VERSION + ) # The total size of the file. # WARNING: this is not accurate. This field isn't used in the @@ -1757,13 +1852,28 @@ def __init__( # This is a list of offsets into the file where the per-tensor data # is stored. self._metadata: Dict[_TensorPath, TensorEntry] + accepted_metadata_versions: Optional[Tuple[int, ...]] = ( + (NON_OPAQUE_TENSORIZER_VERSION, LONG_TENSOR_TENSORIZER_VERSION) + if self._has_versioned_metadata + else None + ) self._metadata, structure, self._metadata_raw = ( _MetadataDeserializer.from_io( - self._file, self._file_header.tensor_count + self._file, + self._file_header.tensor_count, + self._has_versioned_metadata, + accepted_metadata_versions, ) ) if not self._metadata: raise ValueError("Tensor index in the file is empty") + self._long_shape_tensors: typing.FrozenSet[_TensorPath] = frozenset( + { + path + for path, entry in self._metadata.items() + if entry.is_long_tensor + } + ) # filter_func is a test that determines the tensor names to read. # If filter_func is None, all tensors are read. if filter_func is not None: @@ -2961,6 +3071,7 @@ def _copy_thread( file_, zero_hashes=True, check_crypt_info=unsafe_self._has_crypt_info, + long_shape_tensors=unsafe_self._long_shape_tensors, ) if header is None: @@ -3545,8 +3656,8 @@ def __init__( self._metadata_loc = self._file.tell() self._write(bytes(metadata_size)) self._flush() - self._metadata_cur = self._metadata_loc self._metadata_end = self._metadata_loc + metadata_size + self._metadata_handler = self._MetadataHandler() @property def total_tensor_bytes(self): @@ -3739,6 +3850,9 @@ def _shutdown_thread_pools(self): thread_pool.shutdown(wait=False) def _synchronize_pools(self): + # Synchronizing metadata should happen at the same time as synchronizing + # pools, although it is not itself executed in a separate thread. + self._synchronize_metadata() for j in self._jobs: j.result(timeout=_TIMEOUT) self._jobs.clear() @@ -3784,6 +3898,139 @@ def _new_nonces(self, count: int) -> Tuple[bytes, ...]: self._used_nonces.update(nonces) return nonces + def _synchronize_metadata(self): + buffers, total_length, relative_pos = self._metadata_handler.commit() + pos: int = self._metadata_loc + relative_pos + if pos + total_length > self._metadata_end: + raise RuntimeError("Metadata overflow") + self._pwrite_bulk(buffers, pos, total_length) + + def _pwrite_bulk( + self, buffers: Sequence[bytes], offset: int, expected_length: int + ): + # This doesn't bother using os.pwritev because it's only called for + # small amounts of data here, though if larger buffers get involved, + # os.pwritev would be helpful + staged_length: int = 0 + with io.BytesIO() as staging_buffer: + for b in buffers: + staged_length += staging_buffer.write(b) + assert staged_length == expected_length + self._pwrite(staging_buffer.getvalue(), offset, expected_length) + + @dataclasses.dataclass(init=False) + class _MetadataHandler: + """ + Tracks and buffers metadata to be written to the file header. + + This class takes care of the logic around writing version tags + on metadata headers, as well as adding version tags to already-written + metadata headers if necessary. + + The interface to this class is two functions: + ``submit()`` and ``commit()``. To buffer a metadata entry to be written + later, call ``submit()``. When a batch of metadata entries are ready + to be written, call ``commit()`` to get a sequence of ``bytes`` objects + and the offset to which to write them. State persists across multiple + ``commit()`` calls. The writing offset typically advances on successive + calls to ``commit()``, but it may go backwards if previously-committed + metadata entries need to be rewritten because of other entries written + in subsequent batches. + + Internally, it functions like a state machine. + + In its initial state, it tracks pending metadata entries to be written, + as well as past metadata entries that were already written. It stays + in this state as long as the tensors being written all use the V1 + metadata scheme (i.e. the original metadata scheme from tensorizer + data versions 1 through 4). + + Once it is given any tensor using a metadata scheme newer than V1, + it transitions to its second state. In this state, all + previously-written metadata entries are moved back into a pending state, + and version tags are prepended to every entry. It stays in this state + until the next write operation (i.e. the next call to ``commit()``), + after which it moves into its final state. + + In its final state, no more history is saved for previously-written + metadata entries, as historical entries will at this point never again + need to be rewritten. Version tags continue to be prepended + to new entries. It remains in this state forever. + """ + + __slots__ = ("pending", "past", "version", "_pos", "_state") + pending: list + past: list + version: int + _pos: int + + class _MetadataHandlerState(enum.Enum): + TRACKING_PAST = 1 + STAGING_PAST = 2 + NO_PAST = 3 + + _state: _MetadataHandlerState + V1_TAG: ClassVar[bytes] = b"\x01\x00\x00\x00" + + @property + def _tracking_past(self) -> bool: + return self._state is self._MetadataHandlerState.TRACKING_PAST + + @property + def _staging_past(self) -> bool: + return self._state is self._MetadataHandlerState.STAGING_PAST + + @property + def _no_past(self) -> bool: + return self._state is self._MetadataHandlerState.NO_PAST + + def __init__(self): + self.pending = [] + self.past = [] + self.version = 1 + self._pos = 0 + self._state = self._MetadataHandlerState.TRACKING_PAST + + def submit(self, metadata: bytes, version: int): + if version > self.version: + if self.version == 1: + self._update() + self.version = version + if not self._tracking_past: + self.pending.append(version.to_bytes(4, byteorder="little")) + self.pending.append(metadata) + + def commit(self): + # Return a buffer array, total length, and relative write position + # Successive write positions are not a monotone sequence + pending = self.pending + self.pending = [] + if self._tracking_past: + self.past.extend(pending) + elif self._staging_past: + self._state = self._MetadataHandlerState.NO_PAST + total_length = sum(len(d) for d in pending) + pos = self._pos + self._pos += total_length + return pending, total_length, pos + + def _update(self): + # This is only called the one time that self.version is updated + # up from 1, so this should always be in the initial state + assert self._tracking_past + # At the time this is called, everything in self.past and + # self.pending must be version 1, so no complicated checking is + # needed to figure out what needs to be tagged with a v1 tag + pending = [] + v1_tag: typing.Final[bytes] = self.V1_TAG + for metadata in itertools.chain(self.past, self.pending): + pending.append(v1_tag) + pending.append(metadata) + self.pending = pending + self.past.clear() + self._pos = 0 + self._state = self._MetadataHandlerState.STAGING_PAST + def write_tensor( self, idx, @@ -3972,20 +4219,16 @@ def _write_tensor( # Add our tensor metadata to the index. metadata = header.metadata_entry - # Check for overflow - if self._metadata_cur + len(metadata) > self._metadata_end: - raise RuntimeError("Metadata overflow") - - metadata_pos = self._metadata_cur - metadata_len = len(metadata) - self._metadata_cur += metadata_len - - # This task is I/O-bound and has no prerequisites, - # so it goes into the regular writer pool. - def write_metadata(): - self._pwrite(metadata, metadata_pos, verify=metadata_len) - - self._jobs.append(self._writer_pool.submit(write_metadata)) + metadata_version: int + if header.has_long_dimensions: + metadata_version = LONG_TENSOR_TENSORIZER_VERSION + self._file_header.version_number = max( + self._file_header.version_number, + LONG_TENSOR_TENSORIZER_VERSION, + ) + else: + metadata_version = NON_OPAQUE_TENSORIZER_VERSION + self._metadata_handler.submit(metadata, metadata_version) # Calculate the hashes. diff --git a/tests/test_serialization.py b/tests/test_serialization.py index 8c6d657..ff50bdf 100644 --- a/tests/test_serialization.py +++ b/tests/test_serialization.py @@ -20,6 +20,8 @@ import torch +import tensorizer + os.environ["TOKENIZERS_PARALLELISM"] = ( "false" # avoids excessive warnings about forking after using a tokenizer ) @@ -268,6 +270,18 @@ def temporary_file(*args, **kwargs): class TestSerialization(unittest.TestCase): + @staticmethod + def get_version(deserializer: tensorizer.TensorDeserializer) -> int: + return deserializer._file_header.version_number + + @staticmethod + def free_cpu_ram() -> None: + gc.collect() + if empty_cache := getattr(torch._C, "_host_emptyCache", None): + # Clear up pinned memory held by PyTorch's caching allocator + empty_cache() + gc.collect() + def test_serialization(self): for device, method in itertools.product( ("cuda", "cpu"), @@ -292,6 +306,10 @@ def test_serialization(self): deserialized = TensorDeserializer( serialized_model, device="cpu" ) + self.assertEqual( + self.get_version(deserialized), + serialization.NON_OPAQUE_TENSORIZER_VERSION, + ) check_deserialized( self, deserialized, @@ -311,7 +329,7 @@ def test_large_unbuffered_tensor(self): num_elements: int = 36000 * 36000 bytes_required: int = num_elements * 4 assert bytes_required > 1 << 32 - gc.collect() + self.free_cpu_ram() free_mem = utils.CPUMemoryUsage.now().free working_space: int = 10 << 20 if free_mem < bytes_required + working_space: @@ -353,6 +371,111 @@ def test_large_unbuffered_tensor(self): del deserializer, tensor, deserialized_tensor gc.collect() + def test_long_dimensions(self): + # Test serializing tensors with individual dimensions longer than + # 2^32 elements, only supported in tensorizer data version 5 and up + # (corresponding to tensorizer code version 2.12 and up) + # This test takes a lot of RAM, so free up as much as possible first + self.free_cpu_ram() + + tensor_length: int = (1 << 32) + 128 + free_mem = utils.CPUMemoryUsage.now().free + working_space: int = 10 << 20 + if free_mem < tensor_length + working_space: + self.skipTest( + reason="Insufficient RAM to test long dimension serialization" + ) + plentiful_ram: bool = free_mem > (tensor_length + working_space) * 2 + long_tensor = torch.empty( + (tensor_length,), dtype=torch.int8, device="cpu" + ) + # Insert some arbitrary fixed values for an easy integrity check later + long_tensor[0] = 62 + long_tensor[tensor_length - 64] = 72 + long_tensor[-1] = 82 + + def validate_long_tensor(t: torch.Tensor) -> None: + self.assertEqual(t[0], 62) + self.assertEqual(t[tensor_length - 64], 72) + self.assertEqual(t[-1], 82) + + def rand_tensor() -> torch.Tensor: + return torch.rand((16, 16), dtype=torch.float, device="cpu") + + state_dict: "typing.TypeAlias" = typing.Dict[str, torch.Tensor] + + # First, serialize three normal tensors + # These should have their headers rewritten later, + # even if the long-dimension tensor is written separately + sd1: state_dict = {i: rand_tensor() for i in "123"} + + # Then, try serializing a very long tensor betwixt two normal tensors + # This ensures that metadata buffering is working properly + sd2: state_dict = { + "4": rand_tensor(), + "5": long_tensor, + "6": rand_tensor(), + } + del long_tensor + + # Then, in a third write operation, do that last part again + # This ensures that the internal state is still usable after + # the previous write operation has ended + sd3: state_dict = dict(zip("789", sd2.values())) + + # Finally, in a fourth write, serialize three normal tensors + # This ensures that even if a later write operation doesn't contain + # a tensor with long dimensions, it will continue writing with the + # newer format anyway because previous ones did + sd4: state_dict = {i: rand_tensor() for i in "ABC"} + + with temporary_file(mode="wb+") as file: + with file: + serializer = TensorSerializer(file) + serializer.write_state_dict(sd1) + serializer.write_state_dict(sd2) + serializer.write_state_dict(sd3) + serializer.write_state_dict(sd4) + serializer.close() + # Keys 5 and 8 are the (same) long tensor, so remove references + # to them to free up memory for deserialization + del serializer, sd2["5"], sd3["8"] + gc.collect() + rand_tensors: state_dict = {**sd1, **sd2, **sd3, **sd4} + if plentiful_ram: + with TensorDeserializer( + file.name, device="cpu" + ) as deserializer: + self.assertEqual( + self.get_version(deserializer), + serialization.LONG_TENSOR_TENSORIZER_VERSION, + ) + for k, rt in rand_tensors.items(): + self.assertTrue(torch.equal(deserializer[k], rt)) + validate_long_tensor(deserializer["5"]) + validate_long_tensor(deserializer["8"]) + del deserializer + gc.collect() + else: + # Check this twice, but only loading one of the long tensors + # each time, to be light on RAM + for check, skip in ("58", "85"): + with TensorDeserializer( + file.name, + num_readers=1, + device="cpu", + filter_func=lambda name: name != skip, + ) as deserializer: + self.assertEqual( + self.get_version(deserializer), + serialization.LONG_TENSOR_TENSORIZER_VERSION, + ) + for k, rt in rand_tensors.items(): + self.assertTrue(torch.equal(deserializer[k], rt)) + validate_long_tensor(deserializer[check]) + del deserializer + gc.collect() + def test_bfloat16(self): shape = (50, 50) tensor = torch.normal(0, 0.5, shape, dtype=torch.bfloat16)